HuggingFace 的这篇文章总结了一系列节约显存的方法,非常全面。
训练时显存占用的组成:
- 模型参数
- 优化器状态
- 输入张量和其他临时张量
- 激活值
- 梯度
- 通信缓冲
“激活值” 可能有点难理解。这是指像是 dropout 的 mask、LayerNorm 的 等,不是梯度但参加到梯度计算的张量。
除了用混合精度等方法降低整体显存占用,从 降低显存占用峰值 入手也是有效的。
通常的训练过程:计算 loss、反向传播、使用优化器 然后 清除梯度。
这就意味着,我们一次性计算了所有梯度,然后一并应用优化器参数更新。
如果能边算梯度边更新参数,就不需要用大量空间去存储梯度数据了。这就是融合 backward pass 和 optimizer step 的原理,能够有效降低显存占用峰值。
对于 PyTorch Lightning,需要借助 处理优化器逻辑:
若需要原生 PyTorch 实现,可以借助 :
本节参考:
- https://lightning.ai/pages/community/tutorial/faster-pytorch-training-by-reducing-peak-memory/
- https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
AdamW 优化器最为常用,调参简单效果好。要说缺点,就是每个参数都需要额外 8 字节的显存。
Adafactor 优化器改变 Adam 的动量思路,将空间占用降低到了 4 字节。但实际使用中发现 Adafactor 可能会导致训练不稳定。
Bitsandbytes 库提供了一系列 8-bit 优化器。其实现的 AdamW8bit 只需占用 2 字节空间。
这个 issue 是包含各种优化器的 benchmark。可以看出,各优化器的训练损失都差不多。这么说,大胆使用 AdamW8bit 节省显存是个不错的主意。
对于参数少、激活多的网络(例如卷积网络),8-bit 优化器的效果不是很明显。
Bitsandbytes 库推荐在使用 8-bit 优化器训练 NLP 模型时,将 embedding 层换为 以保证训练稳定性。对于其他不稳定的参数,也可以使用 这个文档 提到的方法对那些参数单独使用 32-bit 优化器。
这个知乎问题下 提到 8-bit 优化器可能会让模型容易过拟合。注意一下。
PyTorch Lightning 对 Bitsandbytes 库有支持,可以自动替换用上 Bitsandbytes 的 8-bit 线性层。具体可看官方文档。
PyTorch 的优化器默认启用了一个叫 foreach 的 trick,能加快训练。但随之而来的是额外的优化器中间变量占用,会导致峰值显存占用变高。若要关闭 foreach,在定义优化器时传入参数 即可。