(二)梯度消失和梯度爆炸

Transformer 的发展以及精简实现中有两个地方始终让我非常困惑,其一是现代架构抛弃 Post-LN 的原因;其二是初始化模型权重时针对每个 Module 最后的全连接层(Project Layer)的权重会做一个缩放。

1. 背景

1.1 LayerNorm

在传统的 Bert 以及 Attention Is All You Need 中,Transformer 采用的都是 PostLN 的结构,即形如:

x = LayerNorm(x + f(x))

然而更现代的做法是 Pre-LN,即形如:

x = x + f(LayerNorm(x))

我大概知道这里是因为梯度异常,Pre-LN 更为稳定。

1.2 模型初始化

在模型初始化的时候,虽然都是用符合高斯分布的随机数进行填充,例如,均值为 0,标准差为 0.02:

        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

但是每个 module 最终输出的全连接层会做一个特殊的处理,即让这些层的初始化权重离散程度更小,具体来说:

torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

我对这一层参数初始化在分布情况上的特殊处理感到困惑,无法理解为什么参数的初始化会影响后续的梯度反向传播。

1.3 深度学习的几种常见数据格式

目前深度学习通常会用到三种数据格式:torch.float(torch.float32), torch.half(torch.float16) 以及 torch.bfloat16。这三种数据格式的区别如下:

属性 FP32 (float32) FP16 (float16) BF16 (bfloat16)
位数 32 16 16
结构 1 位符号位 + 8 位指数位 + 23 位尾数位 1 位符号位 + 5 位指数位 + 10 位尾数位 1 位符号位 + 8 位指数位 + 7 位尾数位(牺牲精度换值域)
值域 \(10^{38}\) 约 65504 \(10^{38}\)
精度
应用场景 高精度科研计算 兼容性/应用最广 训练稳定性极高,无溢出困扰

不加任何类型声明时,Python 声明一个浮点数默认类型为 float64(double),而由于 64 位浮点数在深度学习中没有太大必要,因此在使用 torch 来声明一个浮点型 tensor 时,默认类型为 float32。除非修改了默认配置,否则torch.float 只是 float32的别名。

为了节省显存和加速矩阵乘法,几乎不使用 FP32,而是使用 FP16 或者 BF16。FP16 最大值只有 65504,非常容易超出以至于得到 NaN。

2. Transformer 的反向传播与梯度消失

2.1 反向传播

反向传播的目的是为了更新权重,而非更新输入数据。Transformer 的一层网络的经典结构为:\(x_{n+1} = x_n + f(x_n, W_n)\)。假设最终损失函数为 \(L\),我们要计算损失 \(L\) 对第 \(n\) 层节点权重 \(W_n\) 的梯度。于是:

\[\begin{aligned}\frac{\partial L}{\partial W_n} &= \frac{\partial L}{\partial x_{n+1}} \cdot \frac{\partial x_{n+1}}{\partial W_n} \\ &= \frac{\partial L}{\partial x_{n+1}} \cdot \frac{\partial (x_n + f(x_n, W_n))}{\partial W_n} \\ &= \frac{\partial L}{\partial x_{n+1}} \cdot \frac{\partial f(x_n, W_n)}{\partial W_n} \\ &= \frac{\partial L}{\partial x_{n+1}} \cdot x_n \end{aligned}\]

为了说明链式法则的连乘效应,此处省略了矩阵求导的转置操作,仅作标量形式的示意

其中,\(L\)\(W_n\) 通过 \(x_{n+1}\) 建立间接联系, 第 \(n+1\) 层的权重 \(W_{n+1}\) 与 第 \(n\) 层的权重 \(W_n\) 并没有关系,所以不会体现到偏导数的连乘上。另外,虽然 \(f\) 可能是 Attention 或者 FFN,但是从宏观上,是权重矩阵与输入的相乘,即 \(x_n \cdot W_n\),所以 \(f'\) 最终可以简化为 \(x_n\)

注意这里出现了 \(\frac{\partial L}{\partial x_{n+1}}\);那么进一步让我们讨论损失 \(L\) 对第 \(n\) 层节点 \(x_n\) 的梯度。我们将这个梯度记为 \(\delta_n\),于是:

\[\delta_n = \frac{\partial L}{\partial x_n}\]

我们以下标增大的方向为网络更靠近输出层的方向,根据链式法则:

\[\delta_n = \frac{\partial L}{\partial x_n} = \frac{\partial L}{\partial x_{n+1}}\frac{\partial x_{n+1}}{\partial x_{n}} = \delta_{n+1}\frac{\partial x_{n+1}}{\partial x_n}\]

从形式上可以看出第 n 层的梯度(\(\delta_n\),损失对这一层输入数据 \(x_n\) 的偏导)就等于更深一层的梯度乘以当前层的偏导。

  • 稍微岔开一下话题,让我们回到对 \(W_n\) 的推导:
\[\begin{aligned}\frac{\partial L}{\partial W_n} &= \frac{\partial L}{\partial x_{n+1}} \cdot x_n \\ &= \delta_{n+1} \cdot x_n \\ &= (\delta_{n+2}\frac{\partial x_{n+2}}{\partial x_{n+1}})\cdot x_n\end{aligned}\]
  • 注意,\(\frac{\partial L}{\partial W_n}\) 中的 \(x_n\) 只对这一层权重 \(W_n\) 的更新产生影响,不会随着反向传播而传递,即从 \(x_n\) 的角度考虑,各个层的权重更新是平行的(\(W_{n+1}\)\(W_n\) 无关),虽然正比于 \(x_n\),但更浅层的权重更新不会出现 \(x_n\) 的连乘。

我们把前面得到的这个公式 \(\delta_n = \delta_{n+1}\frac{\partial x_{n+1}}{\partial x_n}\) 当做一个通项,接着来观察其中的 \(\frac{\partial x_{n+1}}{\partial x_n}\) 部分。

2.2 Post-LN

当我们考虑 LayerNorm 时,对于 Post-LN:\(x_{n+1} = LayerNorm(x_n + f(x_n))\)。进一步:

\[\begin{aligned}\frac{\partial{x_{n+1}}}{\partial x_n} &= \frac{\partial{LayerNorm(x_n + f(x_n))}}{\partial x_n} \\ &= ...\qquad\text{(这里有一些非常复杂的计算,但是根据偏导的乘积法则最终能拆解出一项$\frac{1}{\sigma}$)}\\&\approx\frac{1}{\sigma}(I + f'(x_n))\end{aligned}\]

代入通项:

\[\begin{aligned}\delta_n &= \delta_{n+1}\frac{\partial x_{n+1}}{\partial x_n} \\ &=\delta_{n+1}\cdot\frac{1}{\sigma_n}(I + f'(x_n)) \\ &=\frac{1}{\sigma_n}[\delta_{n+1}(I + f'(x_n))]\end{aligned}\]

不难看出,这个 \(\frac{1}{\sigma}\) 会直接作用在每一层的梯度计算上,且在远离输出层的方向上,这个\(\frac{1}{\sigma}\)会不断连乘,当层数较多时,将使得最终回传到靠近输入层的梯度消失!

2.3 Pre-LN

而对于 Pre-LN:\(x_{n+1} = x_n + f(LayerNorm(x_n))\)。进一步:

\[\begin{aligned}\frac{\partial{x_{n+1}}}{\partial x_n} &= \frac{\partial{(x_n + f(LayerNorm(x_n)))}}{\partial x_n} && \text{注意: f 是嵌套函数,所以后续使用“链式”法则}\\ \\&= I + f'(\cdot) \cdot LayerNorm'(\cdot) \end{aligned}\]

代入通项:

\[\begin{aligned}\delta_n &= \delta_{n+1}\frac{\partial x_{n+1}}{\partial x_n} \\ &= \delta_{n+1}\cdot(I + f'(\cdot)LayerNorm'(\cdot)) \\ &= \delta_{n+1} + \delta_{n+1}f'(\cdot)LayerNorm'(\cdot)\end{aligned}\]

其实 \(LayerNorm'(x)\) 和 Post-LN 一样依然会产生一个连乘的 \(\frac{1}{\sigma}\)(等号右边的残差分支上),但由于单位矩阵 \(I\) 的存在,加号左边有一项完整的\(\delta_{n+1}\),这使得更深一层的梯度一定会完整传递到靠近输入层的梯度,而不会因 \(\frac{1}{\sigma}\) 的影响而消失!

这也就是很多文章提到的 梯度回传的主干道/高速公路,通过这种方式,即便是极其深的网络也能够收敛。

3. Transformer 的前向传播与梯度爆炸

3.1 前向传播时的数据方差

Transformer 是一种残差网络,这意味着无论是 Attention 部分还是 FFN 部分,计算过程整体上都是一样(假设没有任何归一化处理)。第 n 层的前向计算公式为:

\[x_{n+1} = x_n + f_n(x_n)\]

其中 f(x) 是 Attention 或者 FFN 操作,而这两个操作本质上都是矩阵乘法,也就是 \(f(x) = Wx\)

对于 Transformer 中的一层,输出向量 \(y\),输入向量 \(x\),它们之间的关系为:

\[\begin{aligned}&y = Wx\\ \\ &\text{对于 y 的具体节点 } y_i\text{,}y_i = w_1 x_1 + w_2 x_2 + \dots + w_d x_d \end{aligned}\]

深度学习初始化的基本假设是 \(w\)\(x\) 相互独立,且均值都为 0。所以两个向量的乘积的方差等于方差的乘积:

\[D(w_k x_k) = D(w_k) \cdot D(x_k)\]

\(y_i\) 等式右侧是 \(d\) 个乘积的累加,因此进一步:

\[D(y_i) = d\cdot D(w_k) \cdot D(x_k)\]

代入原来的前向公式:

\[\begin{aligned}D(x_{n+1}) &= D(x_n) + D(f_n(x_n)) \\ &=D(x_n) + d\cdot D(w_k) \cdot D(x_k)\end{aligned}\]

由于离群值并不会轻易改变整体数据的方差,因此当方差变大时意味着整体数据就是在往更大的方向变化。我们可以从这个公式看出两个问题:

  • 问题一:由于常系数 \(d\) 的存在,在什么都不做的情况下,仅仅经过一层网络,方差便会变为原来的 \(d\) 倍。随着网络的深入,方差会迅速增长;
  • 问题二:更深一层的方差与当前层的方差有关,呈累加关系,随着层数的增加,这里的方差会不断累加。

在神经网络中,生成 \(y_i\) 的所有权重 \(w_{ik}\) 都是用相同的分布随机采样出来的,又因为这些权重相乘的输入都是相同的,所以在统计学意义上,这 \(d\) 个输出节点是独立同分布的。因为每一层的每一个节点 \(y_i\) 服从相同的分布,所以这一层整体数据的方差 \(D(y)\),在数学期望上 \(E(D(y))\) 就等于其内部任意一个单节点 \(y_i\) 的方差。或者简单来说,这一层整体数据背后的理论方差就等于单个节点的方差。当数据量足够大的时候,样本就等于整体,无数次样本方差后得到的平均数(数学期望)就等于整体方差。所以这里对单节点 \(y_i\) 的讨论不影响整体的正确性。

对于问题一,如果我们不考虑常系数 \(d\),假设某一层的输入数据方差为 \(D_i\),我们也需要控制这一层的输出数据方差为 \(1.0 \times D_i\),这里强调 \(1.0\) 是为了说明无论是略微大一点点达到 \(1.1\) 或者略微小一点点达到 \(0.9\),根据我们上面得到的连乘关系,这个变化都会产生指数级的效果,即便是较浅的 \(99\) 层,根据指数关系(\(0.9^{99}=0.00002951266543\), \(1.1^{99}=12,527.8293998384\))都会引起方差急剧的变化。

3.2 前向传播对梯度影响

我们在 2.1 节提到,梯度反向传播的目的是为了更新权重参数,有公式:

\[\begin{aligned}\frac{\partial L}{\partial W_n} &= \frac{\partial L}{\partial x_{n+1}} \cdot x_n \\ &= \delta_{n+1} \cdot x_n \end{aligned}\]

根据第 2 章的讨论,我们知道 \(\delta_{n+1}\) 中由于包含连乘的关系,所以需要加以控制。现在让我们讨论乘法的另一项 \(x_n\),这一项不会跟随反向传播而传递到浅层(不会导致连乘关系),而是只参与当前这一层的权重更新计算。

真正的问题在于,随着层数的加深,根据 3.1 节的讨论,如果不对方差加以控制,由于常系数 \(d\) 的存在以及累加的影响,\(x_n\) 是趋向于越来越大的。这便会导致靠近输出层的梯度在开始反向传播时爆炸。

3.3 Post-LN

对于 Post-LN:\(x_{n+1} = LayerNorm(x_n + f(x))\)。 最后的 LayerNorm 完美地解决了这个问题,每次经过一个 Module,LayerNorm 都会对最终输出的方差加以控制。

虽然 Post-LN 对方差进行了控制,但是相比靠近输入层的梯度(受到 \(\frac{1}{\sigma}\) 的影响),梯度在输出层依然很大,出现“头重脚轻”的问题,即靠近输出层的梯度会远大于靠近输入层的梯度。如果不用 Warmup 压制学习率会导致模型输出层的参数剧烈变化,导致模型发散,所以早期 Transformer 非常依赖 Warmup 来保证模型的稳定。

3.4 Pre-LN

而对于 Pre-LN:\(x_{n+1} = x_n + f(LayerNorm(x))\)。这里有一个重要的差点把我绕进去的误区,看起来和 Post-LN 类似,不过是一个在 Module 结尾增加 LayerNorm,另一个在 Module 开始增加 LayerNorm,那么似乎应该都能解决这个问题?

让我们展开来看,Pre-LN 中的 LayerNorm 发生在非常早期,我们依然假设 \(f\) 是一个纯粹的矩阵乘法:\(f(x_n) = W\cdot LN(x_n)\),所以公式变为:

\[x_{n+1} = x_n + f(x_n) = x_n + W \cdot \text{LN}(x_n)\]

方差:由于 LayerNorm 会对方差进行控制,继续 3.1 的推导,所以进一步有:

\[\begin{aligned} D(x_{n+1}) &= D(x_n) + D(W \cdot \text{LN}) \\ &= D(x_n) + d\cdot D(w_k) \cdot D(\text{LN}) \\ &= D(x_n) + d\cdot D(w_k) \end{aligned}\]
  • 既没有影响 \(d\) 的存在性,也没有影响残差连接带来的累加效应。但是此时加号右侧已经没有了 \(D(x_n)\),即 LayerNorm 消除了连乘的影响,不再有指数爆炸,但是依然存在较大的线性增长导致可能溢出。

梯度:同样是由于 LayerNorm 的控制,让我们回顾 2.1 处得到的 公式:\(\frac{\partial L}{\partial W_n} = \delta_{n+1} \cdot x_n\) ,当时的假设是不包含任何归一化,而对于 Pre-LN,最后会变为:

\[\begin{aligned}\frac{\partial L}{\partial W_n} &= \frac{\partial L}{\partial y_n} \cdot \frac{\partial y_n}{\partial W_n} \\ &= \frac{\partial L}{\partial y_n} \cdot \frac{\partial{(W_n \cdot \text{LN}(x_n))}}{\partial W_n} \\ &= \delta_{n+1} \cdot \text{LN}(x_n) \end{aligned}\]
  • 我们前面 2.3 得到 Pre-LN 的 \(\delta_{n+1}\) 是由两部分组成,没有连乘的影响,并且最终回传的梯度正比于当前层的 “\(x_n\)”,而没有 “\(x_n\)” 的连乘
  • 考虑到 LayerNorm 之后,实际上 “\(x_n\)” 是经过 LayerNorm 控制的 \(LN(x_n)\),所以我们实际上需要控制乘号左侧的 \(\delta_{n+1}\) 的影响

业界针对方差的线性增长额外做了两项努力:

1. 消除\(d\)的影响
既然更深一层的数据方差与当前层的数据方差依然具有 \(d\cdot D(w_k)\) 的关系,那么直接令每一层的权重参数方差为 \(\frac{1}{d}\) 就好了,这样 \(d\cdot D(w_k) = 1\),在 Pre-LN 的情况下 \(D(x_{n+1}) = D(x_n) + 1\),此时大幅降低了 \(d\) 带来的与维度相关的线性增长影响。torch 提供了对应的 Xavier 参数初始化方法来便捷地达到这一效果。

在 GPT 系列中,OpenAI 的工程师做了一个工程上的简化:他们发现对于常用的维度大小,直接把所有权重的初始标准差(std)定死 0.02,就能起到和 Xavier 初始化几乎一样的“保持方差不变”的效果。这也是为什么 nanoGPT 初始化时会使用 0.02 作为标准差。

其实 0.02 恰好比常见的维度 \(d\) (768、1024、1280、2048 等)下 \(\frac{1}{\sqrt{d}}\) 的结果略小一点,属于一个经验折中值。

2. 消除残差累加的影响
残差连接是重要的一环,也就是前面说的“主干道”的 \(x_n\) 部分,当层数变大时,虽然经过 1 的努力,加号右侧已经变为常数了,但依然会带来层数等同的变化量级。即,对于第 n 层:\(D(x_n)=D(x_0)+n\)。为了消除这种影响,使得最终的方差尽可能保持不变,那么解法也很直接,将每一层累加的方差从 1 降低为 \(\frac{1}{n}\),这样前向传播到最后一层的时候,方差变化也只有常数级别(理想情况下得到 1)。

注意,这里的缩放使用的 \(n\) 并不是跟随当前层数变化的值,而是整个网络的总层数。nanoGPT 中之所以选择降低 \(\frac{1}{2N}\)(torch 对应函数指定的是标准差,所以实现上会开方),是因为声明了 n 个 Block,但是每个 Block 的 forward 中包含两次残差连接。

class Block(nn.Module):
    # ...
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        # ...
        self.transformer = nn.ModuleDict(dict(
            # ...
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            # ...
        ))

    # ...

4. 小结

  • 前向传播的方差也会对梯度的传播造成影响,整体的数据变化一定来自于具体数据的变化
  • 残差天然会带来累加,层数过深时就需要注意线性增长的影响
  • 反向传播的目的是更新权重参数 \(w\),而不是输入输出 \(x\),但是根据偏导的链式法则,权重参数的梯度与 \(x\) 有关
  • 梯度包含了本轮学习的结果,要保证层数较大时也能够平稳传递到更早期的层

0 条评论

发表评论

您的邮箱地址不会被公开,仅用于通知回复。必填项已用 * 标注

© 2026 云朝野 · Powered by Wagtail