又搬了个蒸馏相关~~ 神经网络中的蒸馏技术
“模型集成是一个相当有保证的方法,可以获得2%的准确性。“ —— Andrej Karpathy我绝对同意!然而,部署重量级模型的集成在许多情况下并不总是可行的。有时,你的单个模型可能太大(例如GPT-3),以至于通常不可能将其部署到资源受限的环境中。这就是为什么我们一直在研究一些模型优化方法 ——量化和剪枝。在这个报告中,我们将讨论一个非常厉害的模型优化技术 —— 知识蒸馏。
Softmax告诉了我们什么?
当处理一个分类问题时,使用softmax作为神经网络的最后一个激活单元是非常典型的用法。这是为什么呢?因为softmax函数接受一组logit为输入并输出离散类别上的概率分布。比如,手写数字识别中,神经网络可能有较高的置信度认为图像为1。不过,也有轻微的可能性认为图像为7。如果我们只处理像[1,0]这样的独热编码标签(其中1和0分别是图像为1和7的概率),那么这些信息就无法获得。人类已经很好地利用了这种相对关系。更多的例子包括,长得很像猫的狗,棕红色的,猫一样的老虎等等。正如Hinton等人所认为的
一辆宝马被误认为是一辆垃圾车的可能性很小,但被误认为是一个胡萝卜的可能性仍然要高很多倍。
这些知识可以帮助我们在各种情况下进行极好的概括。这个思考过程帮助我们更深入地了解我们的模型对输入数据的想法。它应该与我们考虑输入数据的方式一致。所以,现在该做什么?一个迫在眉睫的问题可能会突然出现在我们的脑海中 —— 我们在神经网络中使用这些知识的最佳方式是什么?让我们在下一节中找出答案。
使用Softmax的信息来教学 —— 知识蒸馏
softmax信息比独热编码标签更有用。在这个阶段,我们可以得到:
-
训练数据
-
训练好的神经网络在测试数据上表现良好
我们现在感兴趣的是使用我们训练过的网络产生的输出概率。考虑教人去认识MNIST数据集的英文数字。你的学生可能会问 —— 那个看起来像7吗?如果是这样的话,这绝对是个好消息,因为你的学生,肯定知道1和7是什么样子。作为一名教师,你能够把你的数字知识传授给你的学生。这种想法也有可能扩展到神经网络。
知识蒸馏的高层机制
所以,这是一个高层次的方法:
-
训练一个在数据集上表现良好神经网络。这个网络就是“教师”模型。
-
使用教师模型在相同的数据集上训练一个学生模型。这里的问题是,学生模型的大小应该比老师的小得多。
本工作流程简要阐述了知识蒸馏的思想。为什么要小? 这不是我们想要的吗?将一个轻量级模型部署到生产环境中,从而达到足够的性能。
用图像分类的例子来学习
对于一个图像分类的例子,我们可以扩展前面的高层思想:
-
训练一个在图像数据集上表现良好的教师模型。在这里,交叉熵损失将根据数据集中的真实标签计算。
-
在相同的数据集上训练一个较小的学生模型,但是使用来自教师模型(softmax输出)的预测作为ground-truth标签。这些softmax输出称为软标签。稍后会有更详细的介绍。
我们为什么要用软标签来训练学生模型?
请记住,在容量方面,我们的学生模型比教师模型要小。因此,如果你的数据集足够复杂,那么较小的student模型可能不太适合捕捉训练目标所需的隐藏表示。我们在软标签上训练学生模型来弥补这一点,它提供了比独热编码标签更有意义的信息。在某种意义上,我们通过暴露一些训练数据集来训练学生模型来模仿教师模型的输出。
希望这能让你们对知识蒸馏有一个直观的理解。在下一节中,我们将更详细地了解学生模型的训练机制。
知识蒸馏中的损失函数
为了训练学生模型,我们仍然可以使用教师模型的软标签以及学生模型的预测来计算常规交叉熵损失。学生模型很有可能对许多输入数据点都有信心,并且它会预测出像下面这样的概率分布:
扩展Softmax
这些弱概率的问题是,它们没有捕捉到学生模型有效学习所需的信息。例如,如果概率分布像[0.99, 0.01]
,几乎不可能传递图像具有数字7的特征的知识。
Hinton等人解决这个问题的方法是,在将原始logits传递给softmax之前,将教师模型的原始logits按一定的温度进行缩放。这样,就会在可用的类标签中得到更广泛的分布。然后用同样的温度用于训练学生模型。
我们可以把学生模型的修正损失函数写成这个方程的形式:
其中,pi是教师模型得到软概率分布,si的表达式为:
def get_kd_loss(student_logits, teacher_logits,
true_labels, temperature,
alpha, beta):
teacher_probs = tf.nn.softmax(teacher_logits / temperature)
kd_loss = tf.keras.losses.categorical_crossentropy(
teacher_probs, student_logits / temperature,
from_logits=True)
return kd_loss