模型训练阶段,每张卡中显存占用可以分为两类:
模型状态
模型参数(fp16)、模型梯度(fp16)和 Adam 优化器状态(fp32 的模型参数备份,fp32 的 momentum 和 fp32 的 variance )。 假设模型参数量 x,则共需要 $ 2x + 2x + (4x + 4x + 4x) = 16x $ 字节存储。
小技巧
全量微调时,每增加 1B 参数,需要增加 16GB 的显存来存储模型状态
剩余状态
除了模型状态之外的显存占用,包括激活值、各种临时缓冲区以及无法使用的显存碎片。
ZeRO 策略只优化模型状态显存占用, 从 ZeRO-1 到 ZeRO-3 优化等级越来越高。
-
ZeRO-1 策略针对优化器状态进行分片,模型参数和梯度仍旧是每张卡保持一份,此时,每张卡的模型状态所需显存是 $4x + \frac{12x}{N}$( N 为 GPU 数目)
-
ZeRO-2 策略针对模型梯度进行分片,模型参数仍旧是每张卡保持一份,此时,每张卡的模型状态所需显存是 $2x + \frac{14x}{N}$ ( N 为 GPU 数目)
-
ZeRO-3 策略针对模型参数进行分片,此时每张卡的模型状态所需显存是 $\frac{16x}{N}$ ( N 为 GPU 数目)
小技巧
以 7B 模型 + 8 GPUs 全量微调为例:
- ZeRO-1 模式下,每张卡上模型状态显存占用约为
$$ 2*7 + 2*7 + \frac{4*7 + 4*7 + 4*7}{8} = 38.5 GB $$
- ZeRO-2 模式下,每张卡上模型状态显存占用约为
$$ 2*7 + \frac{2*7 + 4*7 + 4*7 + 4*7}{8} = 26.25 GB $$
- ZeRO-3 模式下,每张卡上模型状态显存占用约为
$$ \frac{2*7 + 2*7 + 4*7 + 4*7 + 4*7}{8} = 14 GB $$
小技巧
由于不同的优化方案不会影响模型训练结果,因此在不会导致 OOM 的前提下,建议使用优化等级较低的 ZeRO 策略。