(一)MHA 的维度变化原理

只要涉及到矩阵运算,就离不开线性代数的各种操作。如果不了解原理只是死记硬背,一开始看起来似乎会比较轻松,而实际结果却是记了又忘,反反复复。慢即是快,静下心找时间理了一下这里(死去的线性代数在不停追杀我 🙉

img-6ab0beafb41b41368d53aaab9500f5db

1. Attention 对于矩阵计算的优化

1.1 \(Q\)\(K\)\(V\) 矩阵融合成单个矩阵

对应的原始操作为:

# ...
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)

# ...
q = self.w_q(x)
k = self.w_k(x)
v = self.w_v(x)

常规的优化操作为:

# ...
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)

# ...
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

1.2 多头矩阵融合成单个矩阵

对应的原始操作为:

assert config.n_embd % config.num_heads == 0

# --- Head 1 计算 ---
# x: (B, L, D) -> q1: (B, L, head_dim)
q1, k1, v1 = self.q1(x), self.k1(x), self.v1(x)
# --- 每个头内部进行 Attention 计算 ---

# --- Head 2 计算 ---
q2, k2, v2 = self.q2(x), self.k2(x), self.v2(x)
# --- 每个头内部进行 Attention 计算 ---

# --- 合并 (Concat) ---
# 将两个 (B, L, 64) 拼成 (B, L, 128)
combined = torch.cat([head1_out, head2_out], dim=2)

output = self.w_o(combined)

常规的优化操作为:

assert config.n_embd % config.num_heads == 0

q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
# 变换过程: (B, L, D) -> (B, L, Heads, Head_Dim)
q = q.view(B, L, self.num_heads, self.head_dim)
k = k.view(B, L, self.num_heads, self.head_dim)
v = v.view(B, L, self.num_heads, self.head_dim)

# ...

# 变换: (B, L, Heads, Head_Dim) -> (B, Heads, L, Head_Dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# --- 融合的 Attention 计算 ---

context = context.transpose(1, 2)
context = context.contiguous().view(B, L, self.d_model)
output = self.w_o(context)

2. 底层数学逻辑

2.1 矩阵算子融合

矩阵融合其实非常简单,按照传统的做法,\(Q\)\(K\)\(V\) 分别对应三个独立的权重矩阵,且由于得到三个矩阵所用到的输入数据 \(x\) 都是同一个,因此如同 1.1 中操作的,每个输入都要与矩阵相乘:

  1. \(q = x W_q\)
  2. \(k = x W_k\)
  3. \(v = x W_v\)

此时,GPU 需要进行三次小规模的矩阵乘法,而融合算子的做法是将这三个独立矩阵横向(dim = 1)拼成一个大矩阵\(W_{qkv}\),此时列数变为原始的 3 倍,而 GPU 只需要进行一次乘法:

\[W_{qkv} = \begin{bmatrix} W_q & W_k & W_v \end{bmatrix}\]

\(x\) 经过这个 Linear 层(如 1.1 所示的 nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias))的时候,实际进行的运算为:

\[输出 = x W_{qkv} = x \begin{bmatrix} W_q & W_k & W_v \end{bmatrix} = \begin{bmatrix} xW_q & xW_k & xW_v \end{bmatrix} = \begin{bmatrix} q & k & v \end{bmatrix}\]

为了进一步得到独立的 \(q, k, v\) 矩阵,所以需要在最后的维度上进行切分:q, k, v = self.c_attn(x).split(self.n_embd, dim=2)。由于 \(x\) 的实际 shape 为 [B, L, D],所以列是下标为 2 的维度,而不是下标为 1 的维度。

算子通常对应计算图的一个节点,粒度非常细。算子融合对应的操作就是将多个小节点“坍缩”成一个大节点,目的不是减少工作量而是减少对显存的访问次数。现代 GPU 的计算能力极强,但是访存带宽/数据传输速度往往是瓶颈。

2.2 多头注意力的优化

无论是哪种实现方式,多头注意力开始的数据维度均为 [B, L, D],多头注意力计算结束的数据维度均为 [B, L, D]。换句话说,每个头只能且必定分到均匀的一部分 embedding 空间,在计算结束之后再拼接和融合。

注意,为了更纯粹地集中在线性变换的矩阵融合上,下文忽略了 Scaled Attention Score 的计算,不影响推导的正确性。

2.2.1 Type 1 - 独立小矩阵

对于独立小矩阵实现的多头注意力,如 1.2 所示,我们一开始就声明了 \(H=2\) 个独立的小矩阵 [D, d]\(d = D // head\_num\)),并独立计算完了 Attention,此时,我们手里有两个独立的输出张量,shape 都是 [L, d]。我们以每个头中 \(Q\) 的权重矩阵 \(W_q\) 为例,回顾 Attention 前的处理:

  • Head 1 输出 (\(O_1\)),\(XH_1\):
\[O_1 = XW_{q1} = \begin{bmatrix} x_{1} \\ x_{2} \end{bmatrix} W_{q1} = \begin{bmatrix} x_{1}W_{q1} \\ x_{2}W_{q1} \end{bmatrix} = \underbrace{\begin{bmatrix} o_{11} & o_{12} \\ o_{31} & o_{32} \end{bmatrix}}_{d = head\_dim} \quad\]

第 1 行是 Token 1 (\(x_1\))的结果,第 2 行是 Token 2 (\(x_2\))的结果;

  • Head 2 输出 (\(O_2\)),\(XH_2\):
\[O_2 = XW_{q2} = \begin{bmatrix} x_{1} \\ x_{2} \end{bmatrix} W_{q2} = \begin{bmatrix} x_{1}W_{q2} \\ x_{2}W_{q2} \end{bmatrix} = \begin{bmatrix} o_{21} & o_{22} \\ o_{41} & o_{42} \end{bmatrix}\]

接着,我们会使用拼接操作:torch.cat([O_1, O_2], dim=-1),由于行数完全一样,在数值分布上,这相当于\(O_2\) 的每一行,硬生生地衔接到 \(O_1\) 对应行的后面。拼接后的结果是一个 shape 为 [L, D] (\(D = 2d\)) 的矩阵:

\[O_{cat} = \left[\begin{array}{cc|cc} \overbrace{o_{11}, o_{12}}^{\text{Head 1}} & \overbrace{o_{21}, o_{22}}^{\text{Head 2}} \\ o_{31}, o_{32} & o_{41}, o_{42} \end{array}\right] \begin{matrix} \leftarrow \text{Token 1} \\ \leftarrow \text{Token 2} \end{matrix}\]

不难看出,\(O_{cat}\) 其实就等于:

\[\begin{aligned} O_{cat} &= \begin{bmatrix} XW_{q1} & XW_{q2} \end{bmatrix} \\ &= \begin{bmatrix} \begin{bmatrix} x_{1} \\ x_{2} \end{bmatrix}W_{q1} & \begin{bmatrix} x_{1} \\ x_{2} \end{bmatrix}W_{q2} \end{bmatrix} \\ &= \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} \begin{bmatrix} W_{q1} & W_{q2} \end{bmatrix} \\ &= X W_{q} \end{aligned}\]

2.2.2 Type 2 - 融合的矩阵

既然目标是得到 [L, D]\(O_{cat}\),我们进一步看看融合的矩阵是如何在不拆分小矩阵的情况下实现的。

q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

阶段一:Linear 计算

大矩阵 \(W_q\)\(W_k\)\(W_v\) 同理)的 shape 为 [D, D],而不是小矩阵的 [D, d],因为不需要切割,此时:

\[XQ = \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}Q = \begin{bmatrix} x_1W_q \\ x_2W_q \end{bmatrix}\]

让我们依然沿用 1.2 的设定,假设 \(L=2, H=2\)。在刚完成计算时,维度 \(L\) 的每个元素对应的结果依然是完整且水平展开的 embedding 向量(长度为 \(D\)):

\[\begin{aligned} &x_1Q = \begin{bmatrix} x_1W_{q1} & x_1W_{q2} \end{bmatrix} \\ &x_2Q = \begin{bmatrix} x_2W_{q1} & x_2W_{q2} \end{bmatrix} \end{aligned}\]

阶段二:View 升维

紧接着,执行q = q.view(B, L, self.num_heads, self.head_dim)。按照 head_num 在最后一个维度上将上一步得到的结果进行拆分(每个头都均匀地分享一部分 embedding 空间),此时形状的变化为:由 [L, D] 变为 [L, H, d]\(d = D // head\_num\))。在上面的矩阵乘法中,列的方向(维度 \(D\))由 \(Q\) 控制,因此相当于将 \(Q\) 进行了切分。

此时,维度 \(L\) 的每个元素依然有完整的 embedding 信息(\(D\)),只是内部从 \(H\) 的维度进行了进一步细分:

\[\begin{aligned} &L[0] = x_1Q = \begin{bmatrix} x_1W_{q1} \\ x_1W_{q2} \end{bmatrix},L[1] = x_2Q = \begin{bmatrix} x_2W_{q1} \\ x_2W_{q2} \end{bmatrix} \\ \\ &\text{其中:} \begin{bmatrix} x_1W_{q1}, & x_1W_{q2} \end{bmatrix} = \begin{bmatrix} x_1W_q \end{bmatrix}\end{aligned}\]

这里有非常关键的一个细节,也是我踩坑了很久的点,\(L[0]\) 此时的矩阵结构中,\(x_1W_{q1}\)\(x_1W_{q2}\) 不再是水平展开在同一“行”的相邻元素,而是变成了沿新增的 Head 维度垂直堆叠的不同的“行”。

升维时,很容易关注到的是轴数的增加,非常容易忽视随着特征维度的拆解,原本沿着单一维度水平展开的 \(D\) 个元素,会被折叠成由新增的 \(H\)\(d\) 共同描述的一个二维矩阵切面。这导致单独一行所能承载的特征大小从 \(D\) 缩小成了 \(d\) 个元素。

我们从 [H, d] 的角度再来观察 \(L[0]\)

\[\begin{align*}&L[0] = \underbrace{\begin{bmatrix} x_1W_{q1} \\ x_1W_{q2} \end{bmatrix}}_{d = 2}, \\ & L[0].H[0] = \begin{bmatrix} x_1 \cdot W_{q1}\end{bmatrix}, \text{我们通过升维划定的 Head 1,每个 Head 下元素为 d 个} \\ & L[0].H[1] = \begin{bmatrix} x_1 \cdot W_{q2}\end{bmatrix}, \text{我们通过升维划定的 Head 2,每个 Head 下元素为 d 个}\\ \end{align*}\]

\(L[1]\)同理:

\[\begin{align*}&L[1] = \underbrace{\begin{bmatrix} x_2W_{q1} \\ x_2W_{q2} \end{bmatrix}}_{d = 2}, \text{按照维度 L,此时是第二个 token 参与计算}\\ & L[1].H[0] = \begin{bmatrix} x_2 \cdot W_{q1}\end{bmatrix}, \text{$L_1$ 同样有 Head 1} \\ & L[1].H[1] = \begin{bmatrix} x_2 \cdot W_{q2}\end{bmatrix}, \text{$L_1$ 同样有 Head 2}\\ \end{align*}\]

阶段三:Transpose 换轴

执行 q = q.transpose(1, 2)进行转置,shape 将从 [L, H, d] 变化为 [H, L, d](省略 \(B\) 维度的变化,此时 \(H\) 即 dim = 1)。

这一步的含义是,不再按照 Token 的顺序(即维度 \(L\))进行第一优先级的分组,而是将 Head 的顺序作为第一优先级。所以 \(H[0]\) 此时就有了完整的 \(x_1\)\(x_2\) 的信息,相当于:第一顺位 head,第二顺位 Token,第三顺位 Embedding 值。\(H[0]\) 将“抽调”所有的 \(\text{Head 1}\) 的数据组成新的矩阵(\(H[1]\)同理):

\[H[0] = \underbrace{\begin{bmatrix} x_1W_{q1} \\ x_2W_{q1} \end{bmatrix}}_{d = 2} = XW_{q1}, \quad H[1] = \underbrace{\begin{bmatrix} x_1W_{q2} \\ x_2W_{q2} \end{bmatrix}}_{d = 2} = XW_{q2},\]

阶段四:多头独立计算(Batch Matmul)

至此,\(Q, K, V\) 三个矩阵都通过升维实现了多头的融合,shape 为 [H, L, d]

在实际进行 Attention 计算时,需要用到一个基础知识:对于 PyTorch,当我们对两个多维张量进行矩阵乘法时,无论维数多大,永远是最后两维进行乘法,前面的维度均会被当成 Batch。换句话说,矩阵 \(Q\)\(\text{Head 1}\) 只能看到 \(K\)\(V\)\(\text{Head 1}\),根本不会看到 \(\text{Head 2}\),前置的维度会被隔离,所以不用担心会混合。这里的计算等价于:

for h in range(H): # 取出第 h 个头的 Q 和 K^T (形状都是二维的矩阵) 
    Q_head = Q[h, :, :] # 形状 [L, d] 
    K_head_T = K_T[h, :, :] # 形状 [d, L] # 只在当前这个头内部进行标准的二维矩阵乘法 
    Output[h, :, :] = torch.matmul(Q_head, K_head_T)

Attention 的计算不影响我们对维度的分析,大矩阵并行算 Attention 时,我们的输出张量形状会发生 [H, L, d][H, L, L][H, L, d] 的变化,最终依然是 [H, L, d]。虽然我们前面讲 \(H\) 这个维度拆解成了 \(H[0]\)\(H[1]\),但数据依然在同一个矩阵中,只是更高维的数据我们不好直观用文本描述。

事实上,在内存里这个大矩阵是这样排的:先把 Head 1 的所有 Token 排完,再排 Head 2 的所有 Token(竖线分隔两个 head)。

\[[ \underbrace{o_{11}, o_{12}}_{T1}, \underbrace{o_{31}, o_{32}}_{T2} \quad | \quad \underbrace{o_{21}, o_{22}}_{T1}, \underbrace{o_{41}, o_{42}}_{T2} ]\]

阶段五:逆向重排与还原

最后实施转置与融合 context = context.transpose(1, 2).contiguous().view(B, L, self.d_model)。这是前两步的逆过程,相当于重新以维度 \(L\) 优先于维度 \(B\) 的优先级进行数据排列,接着再将最后两个维度的数据“捋直”(从 [H, d] 变为 [D]):

  • 第一步:Transpose

    • 执行 .transpose(1, 2) 后,形状变回 [L, H, d](忽略维度 B)。
    • 在内存中,它被强制重排(更准确地说,transpose 只改变逻辑视图,contiguous 强制内存重新分配),成了以 Token 作为第一优先级(竖线分隔两个 token):

      \[[ \underbrace{o_{11}, o_{12}}_{H1}, \underbrace{o_{21}, o_{22}}_{H2} \quad | \quad \underbrace{o_{31}, o_{32}}_{H1}, \underbrace{o_{41}, o_{42}}_{H2} ]\]

    • 此时如果我们在一张二维的纸面上描述,按照维度 L 的顺序它的逻辑视图为:

\[\begin{align*}&L[0] = \begin{bmatrix} o_{11} & o_{12} \\ o_{21} & o_{22} \end{bmatrix}, \text{第一个 token block,列方向是维度 d,即 embedding 空间} \\ &L[1] = \begin{bmatrix} o_{31} & o_{32} \\ o_{41} & o_{42} \end{bmatrix},\text{第二个 token block} \end{align*}\]
  • 第二步:View 捋直 (取消换行)

    • View 只是改变最后一个维度的“换行规则”。原本的最后一个维度是 \(d=2\),也就是每读 2 个数就换行。所以同一个 Token 内部,Head 1 和 Head 2 是分两行垂直排列的。
    • 现在的最后一个维度变成了 \(D=4\)(即 self.d_model,原始 embedding 的空间大小),于是,垂直折叠的 Head 1 和 Head 2,就像被一只手“捋直”成了一根水平的面条:

      \[\begin{bmatrix} o_{11} & o_{12} & o_{21} & o_{22} \\ o_{31} & o_{32} & o_{41} & o_{42} \end{bmatrix}\]

    • 这个“捋直”后的矩阵,和上面 2.2.1 我们用 torch.cat 硬拼接出来的矩阵 \(O_{cat}\) 结构完全等价。

3. 重要知识点汇总

  • 矩阵算子融合其实有一个基础且直白的原理。横向拼接矩阵时,如:\(W_{qkv} = \begin{bmatrix} W_q & W_k & W_v \end{bmatrix}\) ,矩阵乘法一直是用 左侧的“行”与右侧的“列” 相乘,因此无论拼接多少个新矩阵,后续拼接的矩阵都不会对前置的运算结果造成影响;
  • 两个高维张量的矩阵乘法只看最后两个维度,即:前面的维度完全相同(或满足广播机制)时才会进行最后两个维度的二维计算;
  • View 之后不仅维度增加了一个,最后一个维度的空间也缩小了。我们可以从右到左去观察 shape,然后模拟 torch 的操作,按照张量最后一个维度的数值,每次读取对应数值的元素后进行“换行”;

0 条评论

发表评论

您的邮箱地址不会被公开,仅用于通知回复。必填项已用 * 标注

© 2026 云朝野 · Powered by Wagtail