训练时的显存优化

   日期:2024-12-27    作者:67v7y 移动:http://ljhr2012.riyuangf.com/mobile/quote/73001.html

HuggingFace 的这篇文章总结了一系列节约显存的方法,非常全面。

训练时的显存优化

训练时显存占用的组成:

  • 模型参数
  • 优化器状态
  • 输入张量和其他临时张量
  • 激活值
  • 梯度
  • 通信缓冲

“激活值” 可能有点难理解。这是指像是 dropout 的 mask、LayerNorm 的 等,不是梯度但参加到梯度计算的张量。

除了用混合精度等方法降低整体显存占用,从 降低显存占用峰值 入手也是有效的。

通常的训练过程:计算 loss、反向传播、使用优化器 然后 清除梯度。


这就意味着,我们一次性计算了所有梯度,然后一并应用优化器参数更新。

如果能边算梯度边更新参数,就不需要用大量空间去存储梯度数据了。这就是融合 backward pass 和 optimizer step 的原理,能够有效降低显存占用峰值。

对于 PyTorch Lightning,需要借助 处理优化器逻辑:


若需要原生 PyTorch 实现,可以借助 :


本节参考:

  • https://lightning.ai/pages/community/tutorial/faster-pytorch-training-by-reducing-peak-memory/
  • https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html

AdamW 优化器最为常用,调参简单效果好。要说缺点,就是每个参数都需要额外 8 字节的显存。

Adafactor 优化器改变 Adam 的动量思路,将空间占用降低到了 4 字节。但实际使用中发现 Adafactor 可能会导致训练不稳定。

Bitsandbytes 库提供了一系列 8-bit 优化器。其实现的 AdamW8bit 只需占用 2 字节空间。

这个 issue 是包含各种优化器的 benchmark。可以看出,各优化器的训练损失都差不多。这么说,大胆使用 AdamW8bit 节省显存是个不错的主意。

对于参数少、激活多的网络(例如卷积网络),8-bit 优化器的效果不是很明显。

Bitsandbytes 库推荐在使用 8-bit 优化器训练 NLP 模型时,将 embedding 层换为 以保证训练稳定性。对于其他不稳定的参数,也可以使用 这个文档 提到的方法对那些参数单独使用 32-bit 优化器。

这个知乎问题下 提到 8-bit 优化器可能会让模型容易过拟合。注意一下。

PyTorch Lightning 对 Bitsandbytes 库有支持,可以自动替换用上 Bitsandbytes 的 8-bit 线性层。具体可看官方文档。

PyTorch 的优化器默认启用了一个叫 foreach 的 trick,能加快训练。但随之而来的是额外的优化器中间变量占用,会导致峰值显存占用变高。若要关闭 foreach,在定义优化器时传入参数 即可。


特别提示:本信息由相关用户自行提供,真实性未证实,仅供参考。请谨慎采用,风险自负。


举报收藏 0评论 0
0相关评论
相关最新动态
推荐最新动态
点击排行
{
网站首页  |  关于我们  |  联系方式  |  使用协议  |  隐私政策  |  版权隐私  |  网站地图  |  排名推广  |  广告服务  |  积分换礼  |  网站留言  |  RSS订阅  |  违规举报  |  鄂ICP备2020018471号