2025-09-21 从零构建大模型—文本生成策略
发布于 2025年09月21日 • 1 分钟 • 146 字
在解码的时候,生成的词元是从词汇表的所有词元中选择概率分数最大的那一个,也就是argmax最大的词元id,但是这种形式让大模型失去丰富性,因为多次运行大模型生成的文本是相同的。
两种技术(温度缩放和Top-k采样)可以用于文本生成的优化。
温度缩放
用一个从概率分布(这里是大语言模型在每个词元生成步骤为每个词汇条目生成的概率分数)中采样的函数来取代argmax。
这个概率采样函数Multinomial按照词汇表的概率分数采样下一个词元。
- argmax 永远挑概率最大的那个;
- multinomial 按概率分布随机抽签——大概率事件只是“中签率高”,并非 100 %。
torch.manual_seed(123)
next_token_id = torch.multinomial(probas, num_samples=1).item()
print(inverse_vocab[next_token_id])
通过一个称为温度缩放的概念,可以进一步控制分布和选择过程。温度缩放指的是将logits除以一个大于0的数。温度大于1会导致词元概率更加均匀分布,而小于1的温度将导致更加自信(更尖锐或更陡峭)的分布。
def softmax_with_temperature(logits, temperature):
scaled_logits = logits / temperature
return torch.softmax(scaled_logits, dim=0)
应用非常小的温度(如0.1)会导致更集中的分布,使得multinomial函数几乎100%选择最可能的词元,接近于argmax函数的行为。温度增大会导致更均匀的分布,使得其他词元更容易被选中。这可以为生成的文本增加更多变化,但也更容易生成无意义的文本。
设置温度
vocab = {
"closer": 0,
"every": 1,
"effort": 2,
"forward": 3,
"inches": 4,
"moves": 5,
"pizza": 6,
"toward": 7,
"you": 8,
}
inverse_vocab = {v: k for k, v in vocab.items()}
next_token_logits = torch.tensor(
[4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
def softmax_with_temperature(logits, temperature):
scaled_logits = logits / temperature
return torch.softmax(scaled_logits, dim=0)
probas = softmax_with_temperature(next_token_logits, temperature=5)
def print_sampled_tokens(probas):
torch.manual_seed(123)
sample = [torch.multinomial(probas, num_samples=1).item()
for i in range(1_000)]
sampled_ids = torch.bincount(torch.tensor(sample))
for i, freq in enumerate(sampled_ids):
print(f"{freq} x {inverse_vocab[i]}")
print_sampled_tokens(probas)
Top K 采样
较高的温度值会让下一个词元的概率分布更加均匀,从而产生更加多样化的输出。但是有时候会产生语法不正确或者完全无意义的输出。
通过概率采样和温度缩放相结合,可以改善文本生成结果。
在Top-K 采样中,可以将采样的词元限制在前K个最可能的词元上。运用的方法是掩码:用负无穷值(-inf)替换所有未选择的logits,在计算softmax值时,非前K词元的概率分数为0,剩余的概率总和为1。
def top_k_tokens(next_token_logits, top_k=5,temperature=0.1):
top_logits, top_pos = torch.topk(next_token_logits, top_k)
print("Top logits:", top_logits)
new_logits = torch.where(
condition=next_token_logits< top_logits.min(),
input=torch.tensor(float('-inf')),
other=next_token_logits
)
topk_probas =torch.softmax(new_logits/temperature,dim=0)
print(topk_probas)
return topk_probas
只设置temperature=5和同时设置top_k=5 temperature=5的对比,可以看出生成的结果是更集中的多元(或者更多元的集中)。