导读
大模型时代能够充分利用GPU的显存是一项非常有必要的技能。本文将在仅考虑单卡的情况下为大家讲明白大模型的内存占用机制,相信对大家后续训练、使用大模型都非常有帮助。
1.告诉你一个模型的参数量,你要怎么估算出训练和推理时的显存占用?
2.Lora相比于全参训练节省的显存是哪一部分?Qlora相比Lora呢?
3.混合精度训练的具体流程是怎么样的?
这是我曾在面试中被问到的问题,为了巩固相关的知识,打算系统的写一篇文章,帮助自己复习备战秋招的同时,希望也能帮到各位小伙伴。这篇文章将围绕大模型在单卡训练或推理时的显存占用进行系统学习分析,其中有的知识点可能不会涉及太过深入点到为止(因为我也不会),但尽量保证整个读下来逻辑通畅,通俗易懂(只有小白最懂小白!)。
1.数据精度
想要计算显存,从“原子”层面来看,就需要知道我们的使用数据的精度,因为精度代表了数据存储的方式,决定了一个数据占多少bit。我们都知道:
1 byte = 8 bits
1 KB = 1,024 bytes
1 MB = 1,024 KB
1 GB = 1,024 MB
由此可以明白,一个含有1G参数的模型,如果每一个参数都是32bit(4byte),那么直接加载模型就会占用4x1G的显存。
1.1常见的几种精度类型
个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:
各种精度的数据结构
可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa)三部分组成。符号位都是1位(0表示正,1表示负),指数位影响浮点数范围,小数位影响精度。其中TF32并不是有32bit,只有19bit不要记错了。BF16指的是Brain Float 16,由Google Brain团队提出。
1.2 具体计算例子
我硕士话,讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据:我以BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用克劳德大哥随机画的:
- 题目:
随机生成的BF16精度数据
- 先给出具体计算公式:
- 然后step by step地分析(不是,怎么还对自己使用上Cot了)
符号位Sign = 1,代表是负数
指数位Exponent = 17,中间一坨是
小数位Mantissa = 3,后面那一坨是
- 最终结果
三个部分乘起来就是最终结果 -8.004646331359449e-34
- 注意事项
中间唯一需要注意的地方就是指数位是的全0和全1状态是特殊情况,不能用公式,如果想要深入了解可以看这个博客: 彻底搞懂float16与float32的计算方式-CSDN博客 如果感兴趣想更加深入了解如何从FP32转换为BF16的,可以看这个博主的讲解: 从一次面试搞懂 FP16、BF16、TF32、FP32
2.全参训练和推理的显存分析
OK了我们知道了数据精度对应存储的方式和大小, 相当于我们了解了工厂里不同规格的机器零件,但我们还需要了解整个生产线的运作流程,我们才能准确估算出整个工厂(也就是我们的模型训练过程)在运行时所需的资源(显存)。
那么就以目前最常见的混合精度训练方法作为参考,来看一看显存都去哪了。