最近在做 LLM 训练相关的工作,反复被一个问题困住:明明模型参数只有 7B,为什么光是训练就要 100 多 GB 显存?排查下来,一半的”黑洞”藏在优化器状态里。这篇文章就是我梳理这块知识的记录,从 Adam 的基本原理出发,一路聊到分布式训练下的优化手段。


从 Adam 的两个动量说起

Adam 之所以好用,核心在于它对每个参数维护了两套独立的历史统计量——也就是常说的一阶动量 $m$ 和二阶动量 $v$。

一阶动量 $m$ 是梯度的指数移动平均,本质上是在平滑梯度信号,过滤掉单步的噪声,给优化方向一个”惯性”:

\[m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t\]

二阶动量 $v$ 是梯度平方的指数移动平均,记录的是历史梯度的”大小”:

\[v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2\]

最终的参数更新把两者结合起来:

\[\theta \leftarrow \theta - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \cdot \hat{m}_t\]

分母 $\sqrt{\hat{v}_t}$ 的作用很直观:历史梯度大的参数,分母大,学习率自动缩小;梯度一直很小的参数,学习率反而被放大。这就是所谓的”自适应学习率”。

为什么需要 Bias Correction

$m$ 和 $v$ 初始化为 0,训练早期会被零值拖累,估计严重偏低。Bias correction 通过除以 $(1 - \beta^t)$ 来修正这个偏差:

\[\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}\]

训练初期 $t$ 很小,$\beta_1^t$ 接近 1,修正系数很大,拉回了被压低的估计值。随着 $t$ 增大,修正项趋近于 1,自然消失。这个细节看起来不起眼,但不做修正会让前几千步的有效学习率远低于设定值,收敛明显变慢。


显存到底去哪了

理解了 $m$ 和 $v$ 之后,显存的账就很好算了。

在混合精度(AMP)训练下,标准做法是:前向/反向用 fp16 以节省计算和带宽,但优化器状态必须保持 fp32。原因很实际:$v$ 存的是梯度平方,值域可以跨越十几个数量级,fp16 的动态范围(约 $6 \times 10^{-8}$ 到 $6.5 \times 10^4$)根本压不住,容易溢出或下溢。$m$ 的累积更新也有类似问题,fp16 的精度不足会让小学习率场景下参数更新”消失”。

算下来,每个参数的显存占用是这样的:

存储项 精度 bytes/param
模型参数(fp16) fp16 2
梯度(fp16) fp16 2
参数主副本(fp32) fp32 4
一阶动量 $m$ fp32 4
二阶动量 $v$ fp32 4
合计   16 bytes

7B 模型:$7 \times 10^9 \times 16 \approx$ 112 GB,还没算激活值。而 inference 时只需要 fp16 权重,只要约 14 GB。训练和推理的显存差距接近 8 倍,绝大部分都被优化器状态和梯度贡献了。

这也解释了一件让很多人困惑的事:为什么断点续训必须保存优化器状态?$m$ 和 $v$ 是训练过程积累的”记忆”,丢掉它们重启,等于把历史统计清零,bias correction 会让前几千步的有效学习率极低,需要重新”预热”才能恢复,白白浪费算力。


Adam 的已知缺陷

用 Adam 做大规模训练久了,会踩到一些坑值得提前知道:

内存重是最直接的问题,上面已经说了。7B 模型光优化器状态就 56 GB,100B+ 的模型基本不可能在单机上跑。

对稀疏梯度适应差。Embedding 层的梯度极度稀疏,大量位置的梯度在大多数 step 都是 0,$v$ 的估计会长期维持在很小的值,等到真正有梯度时,分母过小导致步长爆炸。

超参数敏感。$\beta_1, \beta_2, \epsilon, \eta$ 四个参数的组合空间不小,$\epsilon$ 的选取对数值稳定性影响很大,不同任务和模型规模下最优组合差异明显。

为什么是 AdamW 而不是 Adam

现在主流 LLM 训练几乎清一色 AdamW,区别在于 Weight Decay 的处理方式。

原版 Adam 把 L2 正则加在梯度上,但梯度会被 $v$ 的自适应缩放处理,实际施加的正则化强度因参数而异,效果不稳定。AdamW 把 Weight Decay 从梯度中解耦出来,直接作用于参数:

\[\theta \leftarrow \theta(1 - \eta\lambda) - \frac{\eta \hat{m}}{\sqrt{\hat{v}}+\epsilon}\]

这样正则化强度与梯度幅度无关,对所有参数统一施加,大模型训练中这个差异对泛化性能的影响相当显著。


分布式训练下的优化器状态管理

单机放不下,就要考虑怎么切。ZeRO(Zero Redundancy Optimizer)是目前最主流的方案,核心思路是:既然每张卡都存一份完整的优化器状态是冗余的,为什么不切分?

ZeRO 分三个阶段,每个阶段切的东西不同:

Stage 1 只切优化器状态($m$ 和 $v$)。每张卡只维护 $1/N$ 的状态,更新时各卡计算自己负责的参数,用 reduce-scatter 同步梯度,用 all-gather 拿到完整参数。内存节省约 4 倍。

Stage 2 在 Stage 1 基础上再切梯度。梯度不再全量存储,只保留各卡负责部分的梯度。内存节省约 8 倍。

Stage 3 把参数本身也切了。每张卡只存 $1/N$ 的参数。内存节省约 16 倍,但代价是 forward 和 backward 都需要 all-gather 重建完整参数,通信量上升约 1.5 倍。网络带宽不够的集群上,Stage 3 的训练速度会明显下降,通常需要配合 prefetch 来掩盖通信延迟。

实践中,Stage 1/2 是性价比最高的选择,Stage 3 留给真的放不下的场景。


更省内存的替代优化器

如果 Adam 太重,有几个方向可以探索:

Adafactor 是 Google 提出的内存高效方案,不存完整的 $v$ 矩阵,而是用行向量和列向量的秩一分解来近似:$V \approx r \cdot c^T$。对于大型权重矩阵,内存从 $O(n)$ 降到 $O(\sqrt{n})$,优化器状态减少约一半。代价是近似引入了噪声,收敛稳定性略逊于 Adam,学习率调度也需要特殊处理。T5 家族就是用 Adafactor 训练的。

Lion 是 Google Brain 2023 年用进化搜索发现的优化器,只存一阶动量 $m$(省去 $v$),更新方向用 sign 函数:

\[\theta \leftarrow \theta - \eta \cdot \text{sign}(\beta_1 m + (1-\beta_1)g)\]

优化器状态减半至 4 bytes/param,计算量也更小,在图像和语言任务上效果与 AdamW 相当。使用时需要更小的学习率和更强的 Weight Decay,两者都与 Adam 有所不同,直接替换需要重新调参。


几个容易混淆的点

最后梳理几个经常被搞混的问题。

Gradient Checkpointing 和优化器状态有关系吗? 没有直接关系。Gradient Checkpointing 省的是激活值显存,通过在反向传播时重新计算部分激活值来换取内存,与优化器状态是两个独立的维度。所以Gradient Checkpointing 通常也叫做 Activation Checkpointing 或 Activation Recomputation(激活重算),两者可以同时使用,完整的省内存组合通常是 ZeRO Stage 2/3 + Gradient Checkpointing + bf16。

Inference 时优化器状态去哪了? 全部丢掉。推理只需要模型权重,$m$、$v$、fp32 主副本、梯度全部不需要。这也是为什么 GGUF、AWQ 这类推理格式只存权重(甚至量化权重),体积可以比训练 checkpoint 小一个数量级。

为什么现在很多训练用 bf16 而不是 fp16? bf16 与 fp32 有相同的指数位数(8 位),动态范围与 fp32 一致,只是精度(尾数位)低一些。fp16 动态范围小,大模型训练中梯度和激活值的数值范围经常超出 fp16 上限,需要 loss scaling 来稳定训练,bf16 则基本不需要这个麻烦。A100 和 H100 都对 bf16 有硬件加速支持,目前基本成了 LLM 训练的默认选择。


参考资料

以下是这些内容的原始出处,值得深读:

核心论文

工程向文章

博客 / 文档