KV Cache 与 MQA/GQA:从推理优化看注意力机制的工程化演进

KV Cache 与 MQA/GQA:从推理优化看注意力机制的工程化演进

本文将从自回归推理的工程需求出发,系统阐述 KV Cache 的本质、设计原理及其如何自然引出多查询注意力(MQA)与分组查询注意力(GQA)的优化路径。

一、自回归推理的瓶颈与 KV Cache 的诞生

在 Transformer 的自注意力机制中,第 (t) 个位置的输出需要与历史所有位置进行交互:

[text{Attn}(t) = text{softmax}left(frac{Q_t K_{1:t}^{top}}{sqrt{d_h}}right) V_{1:t} ]

其中 (Q_t = x_t W_Q)(K_t = x_t W_K)(V_t = x_t W_V)

训练与推理在计算模式上存在根本差异。训练阶段采用并行计算:整个序列一次性输入,所有位置的注意力同步完成,无需保留中间状态。而自回归推理本质上是串行过程,每步仅生成一个新 token。若不加优化,第 (t) 步需要重新计算前 (t-1) 个位置的 (K)(V) 矩阵,导致大量重复计算。

KV Cache 的核心思想是利用注意力计算的增量特性:对于已生成的历史 token,其 (K)(V) 向量在后续步骤中保持不变。因此推理时在每一层维护一个缓存结构,逐步追加新 token 的键值对。第 (t) 步的计算流程简化为:

  1. 仅为当前 token 计算 (Q_t)(K_t)(V_t)
  2. (K_t)(V_t) 追加到缓存的末尾
  3. (Q_t) 对完整的 (K_{1:t})(V_{1:t}) 执行一次注意力操作

这一机制将时间复杂度从每步 (mathcal{O}(t cdot D)) 的键值计算降至 (mathcal{O}(D)),显著减少计算量和内存带宽消耗。
KV Cache 与 MQA/GQA:从推理优化看注意力机制的工程化演进
KV Cache 与 MQA/GQA:从推理优化看注意力机制的工程化演进


二、KV Cache 的内存结构与规模分析

设模型具有 (L) 层、隐藏维度 (D)、注意力头数 (h)(单头维度 (d_h = D/h)),推理批量大小为 (B),当前序列长度为 (T)。在标准的多头注意力(MHA)中,每一层缓存的形状为:

[K: [B, h, T, d_h], quad V: [B, h, T, d_h] ]

全模型的 KV Cache 总量为 (2 cdot B cdot L cdot h cdot T cdot d_h = 2BLTD) 个浮点数。以 LLaMA-7B 为例((L=32)(D=4096),单精度浮点),生成长度 (T=2048) 时,单个样本的 KV Cache 约占用 1GB 显存。这一规模随序列长度和批量大小线性增长,成为长文本生成和高并发推理的主要瓶颈。

需要明确的是,KV Cache 存储的是经过线性投影后的连续向量表示,其规模与词表大小无关。缓存的增长完全由历史生成步数 (T) 驱动,这也是长上下文场景下显存压力剧增的根源。


三、从 MHA 到 GQA/MQA:注意力头的冗余与共享

标准 MHA 为每个查询头配备独立的键值头,这在训练阶段有助于学习多样化的注意力模式。然而在推理场景中,这种设计带来两个问题:

显存与带宽压力:KV Cache 规模与头数 (h) 成正比,限制了批量大小和上下文长度。

键值冗余:实证研究表明,不同注意力头学到的键值表示存在高度相关性,维持 (h) 份独立副本的必要性存疑。

多查询注意力(MQA)通过极端共享来解决这一问题:所有查询头共用单一的键值头((g=1))。此时缓存形状变为 ([B, 1, T, d_h]),规模直接缩减至原来的 (1/h)。分组查询注意力(GQA)采用折中策略,将 (h) 个查询头分为 (g) 组((1 < g < h)),每组共享一对键值头。缓存规模降为 (2BLTD cdot (g/h))

这一演进的数学基础在于键值投影矩阵的参数压缩。MHA 具有 (h) 个独立的 (W_K^{(i)})(W_V^{(i)})(i=1,ldots,h)),而 GQA 仅保留 (g) 个键值投影,每个服务 (h/g) 个查询头。通过适当的训练策略(如从 MHA 检查点初始化后短暂微调),模型可以在保持大部分性能的前提下,大幅降低推理成本。


四、工程实践中的关键考量

量化与压缩:将 KV Cache 从 FP16 量化至 INT8 或 FP8 可进一步减半显存占用,配合动态缩放技术,精度损失通常在 1% 以内。

分页管理:借鉴操作系统的虚拟内存思想,vLLM 等框架将 KV Cache 切分为固定大小的块(如 16 个 token),动态分配物理显存,显著提升显存利用率和批处理吞吐。

卸载与重算:对超长上下文,可将早期 token 的 KV 缓存卸载至 CPU 内存,或在访问时按需重算。前者适用于内存充裕场景,后者在内存受限但计算资源充足时更优。

架构选择指南:通用高吞吐场景优先采用 GQA((g in {2, 4, 8})),平衡质量与效率;边缘部署或极限长上下文场景考虑 MQA;训练主导的应用保持 MHA 或较大的 (g) 值以充分利用模型容量。


五、面试高频问题解析

Q1:为什么只缓存 K 和 V,不缓存 Q?

查询向量 (Q_t) 仅在第 (t) 步与历史键值交互时使用,后续步骤不再需要。而 (K_t)(V_t) 需要被未来所有步骤访问,必须持久化保存。这是注意力计算的非对称性导致的自然结果。

Q2:KV Cache 如何影响模型的并行性?

训练时的数据并行和张量并行不受影响,KV Cache 仅在推理阶段激活。在张量并行部署中,每个 GPU 存储 (h/N) 个头的键值缓存((N) 为并行度),通信仅发生在最终的输出聚合阶段,缓存本身不跨卡传输。

Q3:GQA 的分组数 (g) 如何确定?

需要在实验中平衡性能与效率。常见实践是设定 (g)(h) 的约数(如 (h=32) 时取 (g in {4, 8})),使每组查询头数相等。元研究表明,(g geq 4) 时性能下降通常小于 2%,而显存节省达 75% 以上。

Q4:KV Cache 与 FlashAttention 是什么关系?

两者解决不同层面的问题。FlashAttention 通过分块计算和 IO 优化降低显存峰值和访问次数,在训练和推理中均有效。KV Cache 专注于推理阶段的增量计算优化。实际部署中两者通常结合使用,FlashAttention 负责单步注意力的高效执行,KV Cache 负责跨步的状态管理。

Q5:如何处理动态批处理中不同样本的序列长度差异?

采用 padding + mask 机制:将批内样本对齐到最大长度,通过 attention mask 屏蔽无效位置。现代框架(如 vLLM)进一步使用 PagedAttention,为每个样本独立分配块,避免 padding 浪费显存。批处理调度器会动态组合长度相近的请求,最大化硬件利用率。


六、总结与展望

KV Cache 是 Transformer 自回归推理的关键优化技术,通过缓存历史键值对将重复计算转化为内存查表,显著降低推理延迟。其内存开销随序列长度和头数线性增长,自然催生了 MQA 和 GQA 等共享机制。这一演进本质上是在模型表达能力与工程效率之间寻找最优平衡点。

随着上下文窗口扩展至百万 token 级别,KV Cache 的优化仍是活跃研究领域。未来的方向包括基于重要性的选择性缓存、低秩分解压缩、以及与检索增强生成(RAG)的深度融合。理解 KV Cache 的设计原理,是掌握大语言模型推理系统的必经之路。

发表评论

评论已关闭。

相关文章