大模型中top-p&top-k&temperature如何共同使用——Gemma为例子

参考:
大模型文本生成——解码策略(Top-k & Top-p & Temperature)
大模型源码理解-以Gemma为例子

摘要

之前系统学习了大模型的解码方式,Top-p, Top-k, Beam-search, Greedy, temperature等等,具体使用的时候,也清楚采用这些方式混合使用,但是具体怎么混合,有些模糊。看了一篇相关文章大模型文本生成——解码策略(Top-k & Top-p & Temperature),如下图所示,解决了我一些理解方面的问题,但是感觉还有有些模糊,仔细研究了一下Gemma,记录一下。

结论

先说一下结论,Gemma是怎么解码的,主要研究top-p & top-k &temperature是如何使用的,temperature>top-p>top-k(其实top-p和top-k可以算作并行,同时使用)。与上图的结论略有区别,可能不同得模型策略方便略有区别吧。

代码分析

通过代码进行分析,是如何实现三种策略混合使用。
我觉得,top-p和top-k,一起使用,保留同时满足top-p&top-k的概率值。

class Sampler(nn.Module):

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    @torch.no_grad()
    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: Union[torch.Tensor, None],
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Select the last element for each sequence.
        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
        hidden_states = hidden_states.index_select(#
            1, output_positions).squeeze(dim=1)
        logits = torch.matmul(hidden_states, embedding.t())#计算不同得Token得分情况
        if embedding_bias is not None:#是否增加bisa,这个无关紧要
            logits += embedding_bias

        if temperatures is None:#temperature为空则设置贪婪匹配则选择最大得概率,GPT中好像是0进行贪婪匹配
            return torch.argmax(logits, dim=-1).squeeze(dim=-1)

        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))#预测得结果去除以temperature,修改分布

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)#进行softmax归一化
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)#从大到小进行排序

        # Apply top-p, top-k.
		#这里进行top-p
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)
		
		#这里进行top-k
        top_ks_mask = torch.arange(probs_idx.shape[-1],
                                   device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
      	#这里top-p和top-k同时起作用,同时满足top-p和top-k得结果才有概率值,否则就为0
      	#因为top_ks_mask判断大于top_ks得为True, 大于得不是我们得范围,所以top_ks_mask就补0, 不为True,则补top-p得结果
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))

        next_token_ids = torch.multinomial(probs,
                                           num_samples=1,
                                           replacement=True).squeeze(dim=-1)
        return next_token_ids
### 大模型微调时参数温度和top_p的作用 #### 温度(Temperature) 温度是一个在应用 `softmax` 函数以获得概率之前对对数几率(logits)进行缩放的参数。它主要用于控制生成文本的随机性和创造性[^1]。 - **较高温度值**:当温度值较大(如 1 或 2),会增加对数几率的方差,使得概率分布更加均匀,从而导致生成的文本更具多样性和不可预测性。 - **较低温度值**:当温度值较小(如 0.5),会使概率分布更集中于高概率词元上,因此生成的文本更具针对性和确定性。 以下是基于不同温度值的效果对比代码示例: ```python import torch import torch.nn.functional as F def apply_temperature(logits, temperature): scaled_logits = logits / temperature probabilities = F.softmax(scaled_logits, dim=-1) return probabilities # 原始 logits 输出 logits = torch.tensor([1.0, 2.0, 3.0]) print("High Temperature (e.g., T=2):", apply_temperature(logits, 2)) print("Low Temperature (e.g., T=0.5):", apply_temperature(logits, 0.5)) ``` #### Top-p(Nucleus Sampling) Top-p 是一种采样策略,也称为核采样(nucleus sampling)。它的作用是从累积概率达到指定阈值 p 的最小候选集合中抽取下一个 token。这种方法可以有效平衡多样性与质量之间的关系[^2]。 - 当设定较大的 top-p 值(接近 1),意味着允许更多的低概率词元被选中,这增加了生成文本的多样性。 - 当设定较小的 top-p 值(如 0.8 或更低),则倾向于选择更高概率的词元,从而使生成的结果更加稳定和可控。 以下是一个实现 Top-p 采样的 Python 示例: ```python def nucleus_sampling(probs, p): sorted_probs, indices = torch.sort(probs, descending=True) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # 找到满足条件的最大索引位置 cutoff_index = torch.searchsorted(cumulative_probs, p).item() sampled_indices = indices[:cutoff_index] # 归一化剩余概率并返回 normalized_probs = sorted_probs[:cutoff_index] / torch.sum(sorted_probs[:cutoff_index]) chosen_token_idx = torch.multinomial(normalized_probs, num_samples=1).item() return sampled_indices[chosen_token_idx].item() probs = torch.tensor([0.1, 0.2, 0.3, 0.4]) chosen_token = nucleus_sampling(probs, 0.9) print(f"Chosen Token Index with P=0.9: {chosen_token}") ``` #### 实际应用场景中的调整建议 在实际应用中,温度和 top-p 参数的选择取决于具体的任务需求以及期望生成文本的特点[^3]。例如: - 如果希望生成创造性强但又不失连贯性的内容,则可以选择适中的温度值(如 0.7~1.0)和较高的 top-p 阈值(如 0.9)。 - 若目标是生成高度可预测且精确的内容,则应降低温度值(如 0.2~0.5)并减少 top-p 范围(如 0.5~0.7)。 --- ###
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值