稀疏Softmax(Sparse Softmax)

本文介绍了稀疏Softmax(SparseSoftmax)这一技术,源于《FromSoftmaxtoSparsemax》等,通过稀疏化Softmax分布以增强模型的可解释性和可能的性能提升。它避免了Softmax过度学习问题,特别是在分类任务中。文章详细解释了稀疏化原理并提供了两种实现版本,包括苏剑林简化版。但要注意, SparseSoftmax在预训练模型中适用以防过拟合,而在从零开始的模型中可能导致欠拟合。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文源自于SPACES:“抽取-生成”式长文本摘要(法研杯总结),原文其实是对一个比赛的总结,里面提到了很多Trick,其中有一个叫做稀疏Softmax(Sparse Softmax)的东西吸引了我的注意,查阅了很多资料以后,汇总在此

Sparse Softmax的思想源于《From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification》《Sparse Sequence-to-Sequence Models》等文章。里边作者提出了将Softmax稀疏化的做法来增强其解释性乃至提升效果

不够稀疏的Softmax

前面提到Sparse Softmax本质上是将Softmax的结果稀疏化,那么为什么稀疏化之后会有效呢?我们认稀疏化可以避免Softmax过度学习的问题。假设已经成功分类,那么我们有 s max = s t s_{\text{max}}=s_t smax=st(目标类别的分数最大),此时我们可以推导原始交叉熵的一个不等式:

log ⁡ ( ∑ i = 1 n e s i ) − s max = log ⁡ ( e s t + ∑ i ≠ t e s i ) − s max = log ⁡ ( e s max + ∑ i ≠ t e s i ) − log ⁡ ( e s max ) = log ⁡ ( e s max + ∑ i ≠ t e s i e s max ) = log ⁡ ( 1 + ∑ i ≠ t e s i − s max ) ≥ log ⁡ ( 1 + ( n − 1 ) e s min − s max ) (1) \begin{aligned} \log (\sum_{i=1}^n e^{s_i})-s_{\text{max}} &= \log (e^{s_t}+\sum_{i\neq t}e^{s_i})-s_{\text{max}}\\ &= \log (e^{s_{\text{max}}} + \sum_{i\neq t}e^{s_i})-\log (e^{s_{\text{max}}})\\ &= \log (\frac{e^{s_{\text{max}}} + \sum_{i\neq t}e^{s_i}}{e^{s_{\text{max}}}})\\ &= \log (1+ \sum_{i \neq t}e^{s_i - s_{\text{max}}})\\ & \ge \log (1+ (n - 1)e^{s_{\text{min}}-s_{\text{max}}}) \end{aligned}\tag{1} log(i=1nesi)smax=log(est+i=tesi)smax=log(esmax+i=tesi)log(esmax)=log(esmaxesmax+i=tesi)=log(1+i=tesismax)log(1+(n1)esminsmax)(1)

假设当前交叉熵值为 ε \varepsilon ε,那么有

ε ≥ log ⁡ ( 1 + ( n − 1 ) e s min − s max ) (2) \varepsilon \ge \log (1+ (n - 1)e^{s_{\text{min}}-s_{\text{max}}})\tag{2} εlog(1+(n1)esminsmax)(2)

解得

s max − s min ≥ log ⁡ ( n − 1 ) − log ⁡ ( e ε − 1 ) (3) s_{\text{max}} - s_{\text{min}} \ge \log (n - 1) - \log (e^{\varepsilon} - 1)\tag{3} smaxsminlog(n1)log(eε1)(3)

我们以 ε = ln ⁡ 2 = 0.69... \varepsilon = \ln2 = 0.69... ε=ln2=0.69...为例,这时候 log ⁡ ( e ε − 1 ) = 0 \log (e^{\varepsilon} - 1)=0 log(eε1)=0,那么 s max − s min ≥ log ⁡ ( n − 1 ) s_{\text{max}} - s_{\text{min}}\ge \log (n-1) smaxsminlog(n1)。也就是说,为了要loss降到0.69,那么最大的logit与最小的logit的差就必须大于 log ⁡ ( n − 1 ) \log (n-1) log(n1),当 n n n比较大时,对于分类问题来说这是一个没有必要的过大的间隔,因为我们只希望目标类的logit比所有非目标类都要大一点就行,但是并不一定需要大 log ⁡ ( n − 1 ) \log (n-1) log(n1)那么多,因此常规的交叉熵容易过度学习从而导致过拟合

稀疏的Sparsemax

前面说了这么多关于Softmax的内容,那么Sparse Softmax或者说Sparsemax是如何做到稀疏化分布的呢?原文内容大家可以直接去看论文,写的非常复杂,这里我给出苏剑林大佬设计的一个更简单的版本

Origin Sparse Softmax p i = e s i ∑ j = 1 n e s j p i = { e s i ∑ j ∈ Ω k e s j ,   i ∈ Ω k 0 ,   i ∉ Ω k CrossEntropy log ⁡ ( ∑ i = 1 n e s i ) − s t log ⁡ ( ∑ i ∈ Ω k e s i ) − s t \begin{array}{c|c|c} \hline & \text{Origin} & \text{Sparse} \\ \hline \text{Softmax} & p_i = \frac{e^{s_i}}{\sum\limits_{j=1}^{n} e^{s_j}} & p_i=\left\{\begin{aligned}&\frac{e^{s_i}}{\sum\limits_{j\in\Omega_k} e^{s_j}},\,i\in\Omega_k\\ &\quad 0,\,i\not\in\Omega_k\end{aligned}\right.\\ \hline \text{CrossEntropy} & \log\left(\sum\limits_{i=1}^n e^{s_i}\right) - s_t & \log\left(\sum\limits_{i\in\Omega_k} e^{s_i}\right) - s_t\\ \hline \end{array} SoftmaxCrossEntropyOriginpi=j=1nesjesilog(i=1nesi)stSparsepi=jΩkesjesi,iΩk0,iΩklog(iΩkesi)st

其中 Ω k \Omega_k Ωk是将 s 1 , s 2 , . . . , s n s_1,s_2,...,s_n s1,s2,...,sn从大到小排列后前 k k k个元素的下标集合。说白了,苏剑林大佬提出的Sparse Softmax就是在计算概率的时候,只保留前 k k k个,后面的直接置零, k k k是人为选择的超参数

代码

首先我根据苏剑林大佬的思路,给出一个简单版本的PyTorch代码

import torch
import torch.nn as nn

class Sparsemax(nn.Module):
    """Sparsemax loss"""

    def __init__(self, k_sparse=1):
        super(Sparsemax, self).__init__()
        self.k_sparse = k_sparse
        
    def forward(self, preds, labels):
        """
        Args:
            preds (torch.Tensor):  [batch_size, number_of_logits]
            labels (torch.Tensor): [batch_size] index, not ont-hot
        Returns:
            torch.Tensor
        """
        preds = preds.reshape(preds.size(0), -1) # [batch_size, -1]
        topk = preds.topk(self.k_sparse, dim=1)[0] # [batch_size, k_sparse]
        
        # log(sum(exp(topk)))
        pos_loss = torch.logsumexp(topk, dim=1)
        # s_t
        neg_loss = torch.gather(preds, 1, labels[:, None].expand(-1, preds.size(1)))[:, 0]
        
        return (pos_loss - neg_loss).sum()

再给出一个Github上找到的一个PyTorch原版代码

"""Sparsemax activation function.
Pytorch implementation of Sparsemax function from:
-- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification"
-- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068)
"""

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Sparsemax(nn.Module):
    """Sparsemax function."""

    def __init__(self, dim=None):
        """Initialize sparsemax activation
        
        Args:
            dim (int, optional): The dimension over which to apply the sparsemax function.
        """
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input):
        """Forward function.
        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size
        Returns:
            torch.Tensor: [batch_size x number_of_logits] Output tensor
        """
        # Sparsemax currently only handles 2-dim tensors,
        # so we reshape to a convenient shape and reshape back after sparsemax
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=device, dtype=input.dtype).view(1, -1)
        range = range.expand_as(zs)

        # Determine sparsity of projection
        bound = 1 + range * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * range, dim, keepdim=True)[0]

        # Compute threshold function
        zs_sparse = is_gt * zs

        # Compute taus
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # Sparsemax
        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Reshape back to original shape
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)

        return output

    def backward(self, grad_output):
        """Backward function."""
        dim = 1

        nonzeros = torch.ne(self.output, 0)
        sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
        self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))

        return self.grad_input
*补充

经过苏剑林大佬的许多实验发现,Sparse Softmax只适用于有预训练的场景,因为预训练模型已经训练得很充分了,因此finetune阶段要防止过拟合;但是如果从零训练一个模型,那么Sparse Softmax会造成性能下降,因为每次只有 k k k个类别被学习到,反而会存在学习不充分的情况(欠拟合)

References
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

数学家是我理想

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值