位置编码的理解和梳理
引言
Transformer的悖论
自2017年横空出世以来,Transformer架构已然成为序列建模领域的一场革命,因其卓越的并行计算能力和捕捉全局依赖的强大性能而备受赞誉。它构成了当今最先进的大型语言模型(LLMs)的基石。然而,在这座宏伟的架构丰碑之下,隐藏着一个核心的悖论:在其最纯粹的形式下,这个强大的架构对其处理的数据的顺序一无所知。
应对置换不变性
这种“顺序盲视”源于其核心组件——自注意力(Self-Attention)机制固有的置换不变性(更准确地说是置换等变性)。举一个简单而有力的例子:对于一个原始的Transformer模型,“猫追逐狗”这句话与一堆无序的词汇集合 {猫, 追逐, 狗} 在语义上是无法区分的。这种情况明确地揭示了一个不可或缺的需求:必须有一种机制来为模型注入序列的顺序信息。这个至关重要的角色,正是由位置编码(Positional Encoding, PE)来扮演的。
深度探索路线图
本文旨在对Transformer中的位置编码技术进行一次全面而深入的剖析。我们将探讨四种主流的位置编码技术,将它们视为一场思想的演进之旅。这场旅程始于为每个词元(token)分配静态、绝对的“地址”,逐步发展到捕捉动态、相对的“关系”。通过这次探索,我们将揭示这些技术如何从根本上解决了Transformer的顺序感知问题,并最终成就了其在自然语言处理及更广泛领域中的霸主地位。
第一章:根本性问题——自注意力机制的置换不变性
要理解位置编码的必要性,首先必须深入剖析Transformer架构的心脏——自注意力机制,并揭示其对顺序信息的“漠不关心”是如何产生的。
解构自注意力机制
自注意力机制的计算核心是缩放点积注意力(Scaled Dot-Product Attention),其数学表达式如下:
其中,查询(Query, Q)、键(Key, K)和值(Value, V)这三个矩阵均由同一个输入序列通过不同的线性投影(权重矩阵 (W_q, W_k, W_v))生成。关键在于,对于输入序列中的每一个词元,模型都使用相同的权重矩阵进行变换,无论该词元处于序列的哪个位置。注意力得分是通过计算一个查询向量与所有键向量的点积来获得的,这个过程本质上是在衡量序列中任意两个词元之间的兼容性,而与它们的绝对或相对位置无关。
顺序无关性的数学证明
自注意力机制的顺序无关性可以通过一个简单的数学论证来揭示。假设我们有一个输入序列矩阵 (X),其中每一行代表一个词元的嵌入向量。如果我们用一个置换矩阵 (P) 来打乱 (X) 中各行的顺序,得到新的输入 (PX)。将这个新的输入送入一个不包含位置编码的Transformer层 (T),我们会发现其输出同样是被置换过的原始输出,即 (T(PX)=PT(X))。
这个性质被称为置换等变性(Permutation Equivariance)。它意味着模型对输入序列的处理方式,就如同处理一个无序的集合。核心的矩阵乘法((QK^T))和后续的softmax操作,其计算结果仅依赖于向量本身的内容,而与它们在输入矩阵 (X) 中的初始行索引(即位置)无关。
确立位置信息的不可或缺性
为了打破这种对称性,使模型能够感知并利用词序信息,我们必须将每个词元的位置信息明确地注入到其表示中。正是这次注入,将模型从处理“词袋”(Bag of Words)的模式,转变为理解结构化序列的模式。尽管有研究表明,多层的自回归模型可以在没有显式位置编码的情况下隐式地学习到顺序信息,但对于通用和基础的Transformer架构而言,显式的位置编码被认为是获得稳健性能的必要条件。
这种对位置编码的需求并非对模型缺陷的修补,而是为了优先实现并行化而做出的设计权衡的必然结果。在Transformer出现之前,循环神经网络(RNNs)和长短期记忆网络(LSTMs)等主流模型通过其固有的顺序处理结构,自然地编码了时序信息。然而,这种递归结构也造成了计算瓶颈,阻碍了大规模并行化。《Attention Is All You Need》这篇开创性论文的核心创新,正是彻底摒弃了递归结构,转而采用完全基于注意力的架构。这一架构选择在训练速度和直接建模远程依赖方面取得了巨大成功。然而,移除递归结构也同时移除了内置的顺序跟踪机制。因此,置换不变性这个“问题”,实际上是为了换取并行化优势而付出的、经过深思熟虑的代价。位置编码,则是为了重新获得在这种权衡中丢失的序列信息而精心设计的工程解决方案。
第二章:绝对位置编码(APE)——为词元赋予固定地址
本章将探讨第一类解决方案,它们为序列中的每个位置分配一个唯一的、绝对的位置向量。
2.1 正余弦编码:经典优雅的初始蓝图
背景与动机
这是在《Attention Is All You Need》论文中提出的最初的、无需参数的解决方案。设计者需要一种方法,它能为每个位置生成唯一的编码,能够处理可变长度的序列,并且有潜力泛化到比训练时所见序列更长的序列。
工作原理与公式
正余弦编码为每个位置 (pos) 和每个维度 (i) 生成一个位置编码值。其公式分为偶数维度和奇数维度两种情况:
这里的 (pos) 是词元在序列中的位置,(i) 是编码向量中的维度索引,(d_{text{model}}) 是模型的嵌入维度。
其背后的直觉是:每个位置 (pos) 都被映射到一个向量上,该向量的每个维度对应一个不同频率的正弦或余弦波。这为每个位置创造了一个独特的“数字指纹”。这些波的波长构成了一个从 (2pi) 到 (10000 cdot 2pi) 的几何级数,使得模型能够在不同尺度上捕捉位置信息。
隐藏的巧思:编码相对位置
正余弦编码的一个关键且精妙之处在于,它使得模型能够轻易地学习到相对位置信息。论文作者假设,对于任意固定的偏移量 (k),(PE_{pos+k}) 可以表示为 (PE_{pos}) 的线性函数。这源于三角函数中的和角公式:
由于这一特性,即使模型只被赋予了绝对位置信息,它也能通过学习一个线性变换来推断出词元之间的相对位置关系,这为注意力机制捕捉相对依赖提供了便利。
分析:优缺点与适用场景
优点:
- 确定性且无需训练,节省了模型参数和计算资源。
- 理论上,它允许模型外推到比训练期间遇到的更长的序列,因为它是一个连续函数,可以为任何位置生成编码。
缺点:
- 在实践中,其外推能力非常有限,当序列长度远超训练长度时,模型性能会显著下降。
- 它是一种静态的、“一刀切”的解决方案,对于所有任务都使用相同的编码方式,可能并非最优选择。
适用场景:
- 原始的Transformer模型实现。
- 用于教学目的的模型,因其简单明了。
- 对参数效率要求极高的场景。
2.2 可学习的绝对编码:数据驱动的方法
背景与动机
这种方法在BERT和GPT等模型中得到普及。其核心动机是让模型根据其特定的预训练任务和数据,自主学习出最优的位置表示,而不是依赖于一个固定的、人为设计的正余弦函数。
工作原理
可学习绝对编码在概念上非常简单。它通过创建一个位置嵌入矩阵 (P) 来实现,该矩阵的大小为 ((text{max_position_embeddings}, text{hidden_size})),其中 (text{max_position_embeddings}) 是模型支持的最大序列长度。
对于输入序列中位置为 (i) 的词元,模型会从该矩阵中检索出对应的向量 (P[i])。然后,将这个位置向量与其词元嵌入向量(以及在BERT中的段嵌入向量)相加,形成最终的输入表示。这个位置嵌入矩阵像词元嵌入矩阵一样,被随机初始化,并通过反向传播在训练过程中不断更新。
分析:优缺点与“外推之墙”
优点:
- 灵活性与表达能力:由于位置嵌入是学习得来的,模型可以根据语言的细微差别和特定任务的需求,定制出更有效的位置信息表示。
缺点:
- 硬性的外推限制:这是该方法最致命的缺陷。模型只学习了从0到 (text{max_position_embeddings} - 1) 这些位置的嵌入。它没有任何机制来为超出此范围的位置生成表示,从而为序列长度设置了一道无法逾越的“硬墙”。这使其成为Transformer长度泛化能力的主要瓶颈。
- 增加参数量:为模型引入了大量额外的可训练参数,增加了模型的存储和计算成本。
适用场景:
- 适用于最大序列长度固定且已知的模型,如BERT和早期的GPT系列模型。
- 不适合需要处理动态长度或超长上下文窗口的应用。
从正余弦编码到可学习绝对编码的转变,反映了深度学习领域一个核心的哲学思辨:是选择精巧的手工特征工程(如正余弦函数),还是拥抱端到端的学习。正余弦编码是一种特征工程的体现,其三角函数形式是一种精心设计,旨在以线性方式表示相对位置的归纳偏置。而可学习绝对编码则遵循了“让数据说话”的深度学习理念,假设模型能够从零开始学习到一种更优的、任务特定的位置表示。
然而,这种完全依赖数据驱动的方法也带来了新的问题。通过摒弃手工设计的数学结构,它失去了内在的、能够泛化的数学属性。模型学会了位置0到511的具体向量表示,但没有学到任何关于位置512应该是什么样子的基本原理。可学习绝对编码在外推能力上的失败,成为了一个重要的研究驱动力,推动社区去寻找那些既能结合学习的灵活性,又能具备数学结构的泛化能力的全新方法。这直接催生了相对位置编码(RPE)和旋转位置编码(RoPE)的兴起。
第三章:相对位置编码(RPE)——核心在于关系
随着模型处理的序列越来越长,绝对位置编码的局限性日益凸显。研究者们开始意识到,对于注意力机制而言,更重要的或许不是每个词元的绝对“门牌号”,而是它们之间的相对“距离”。
背景:向相对性的概念飞跃
相对位置编码的核心洞见在于:对于一个位于位置 (i) 的查询词元和一个位于位置 (j) 的键词元,它们之间的相对距离 (i-j) 往往比它们的绝对位置 (i) 和 (j) 更具信息量。原始Transformer在处理长文本时,通常会将其分割成固定长度的、互不关联的片段进行处理,这会导致上下文碎片化(Context Fragmentation)问题,即一个片段末尾的词元无法注意到前一个片段的信息,严重限制了模型捕捉长期依赖的能力。
Transformer-XL方案:实现范式的转变
Transformer-XL模型为了解决上下文碎片化问题,引入了片段级循环(Segment-Level Recurrence)机制,即缓存并重用前一个片段的隐藏状态作为当前片段的“记忆”。这一架构创新使得绝对位置编码不再适用。试想,如果继续使用绝对位置编码,那么片段 (tau) 中位置为 (j) 的词元将与片段 (tau+1) 中位置为 (j) 的词元拥有完全相同的位置编码,这会导致模型无法区分它们的时间顺序,造成时间混淆(Temporal Confusion)。
为了解决这个由新架构带来的新问题,Transformer-XL提出了一种全新的相对位置编码方案。
核心思想与公式解构
其核心思想是,不再将位置信息添加到输入嵌入中,而是直接在计算注意力分数时进行修正。在Transformer-XL中,查询 (q_i) 和键 (k_j) 之间的注意力得分 (A_{i,j}) 被分解为四个部分:
这个公式可以直观地解释为:
- (a) 内容-内容:与标准自注意力相同,基于内容的寻址。
- (b) 内容-位置:内容相关的位置偏置,即查询的内容会影响它对不同相对位置的偏好。
- (c) 位置-内容:全局内容偏置,衡量了键的内容本身的重要性。
- (d) 位置-位置:全局位置偏置,衡量了不同相对距离的重要性。
其中,(R_{i-j}) 是一个代表相对距离 (i-j) 的可学习的嵌入向量。通过这种方式,注意力得分直接取决于相对距离,从而完美地解决了片段级循环机制中的时间混淆问题,因为相对距离在跨越片段边界时是保持一致的。
分析:优缺点与架构影响
优点:
- 卓越的长度泛化能力:模型能够更好地处理比训练时更长的序列,有效捕捉长期依赖。
- 架构的赋能者:它是实现Transformer-XL等长上下文模型的关键技术,解决了上下文碎片化问题。
缺点:
- 实现复杂性:相比绝对位置编码,实现起来更为复杂。
- 计算开销:原始的实现方式需要一个大小为 (O(L^2 cdot D)) 的相对位置矩阵,带来了较大的内存开销,尽管后续研究提出了更高效的计算方法。
适用场景:
- 为长文档处理、长序列语言建模(如Transformer-XL)和文本摘要等任务设计的模型。
- T5模型家族也采用了一种简化的、将相对位置信息作为标量偏置项(scalar bias)添加到注意力分数中的RPE变体。
相对位置编码不仅仅是一种替代性的位置编码方法,它更是一种架构的赋能者。它的诞生与解决原始Transformer固定长度处理的局限性紧密相连,从而催生了像Transformer-XL这样的新架构。这一发展历程揭示了一个深刻的因果链条:对更长上下文的追求,推动了架构的创新(片段级循环);而架构的创新,又反过来催生了位置编码的革新(相对位置编码)。因此,RPE的成功不仅在于它能更好地编码位置信息本身,更在于它与一种更先进的模型架构之间形成的共生关系。
第四章:旋转位置编码(RoPE)——优雅的几何解决方案
在绝对和相对位置编码的探索之后,研究界寻求一种能够将二者优点集于一身的、更为优雅的统一框架。旋转位置编码(RoPE)应运而生,它通过一种新颖的几何视角,为位置信息的编码提供了全新的思路。
背景与动机
RoPE旨在将绝对位置信息和相对位置信息无缝地融合在同一个公式中。其设计目标是实现序列长度的灵活性,并使词元间的依赖关系能随着相对距离的增加而自然衰减,这是一种符合直觉的归纳偏置。
工作原理:位置的几何学
核心直觉
RoPE的核心思想是,不再通过向量相加的方式注入位置信息,而是根据词元的绝对位置,对其查询(Query)和键(Key)向量进行旋转。
数学公式
为了理解其工作原理,我们首先从二维空间入手。假设有一个二维向量 (x = (x_1, x_2)),RoPE会根据其位置 (m) 将其旋转一个角度 (mtheta)。这个旋转操作通过一个旋转矩阵 (R_{mtheta}) 实现。
RoPE的“魔法”在于,经过旋转后的查询向量 (q_m) 和键向量 (k_n) 之间的点积,其结果只与原始向量 (q)、(k) 以及它们的相对位置 (m-n) 有关。具体来说:
这意味着,尽管我们通过绝对位置 (m) 和 (n) 对向量进行了操作,但最终的注意力分数却天然地蕴含了相对位置信息 (m-n)。
在实际的高维实现中,RoPE会将 (d) 维的嵌入向量两两分组,形成 (d/2) 个二维子空间。然后,对每个子空间应用二维旋转,但每个子空间的旋转频率 (theta_i) 各不相同。这些频率通常设置为一个递减的几何级数,例如 (theta_i = 10000^{-2i/d})。这样,不同的维度对(子空间)就能在不同的频率上捕捉位置信息,从高频(捕捉邻近词元关系)到低频(捕捉远距离词元关系)。
分析:当前位置编码的王者
优点
- 卓越的外推能力:由于相对位置的依赖关系是通过三角函数的周期性属性来编码的,RoPE能够平滑地泛化到训练时未见过的位置,展现出极强的外推性能。
- 隐式的相对编码:它将相对位置依赖直接、优雅地融入自注意力公式中,无需像RPE那样显式地修改注意力得分矩阵。
- 自然的注意力衰减:旋转的几何特性天然地导致了注意力分数会随着相对距离的增加而衰减,这是一种非常有用的归纳偏置,符合语言中局部性原理。
- 无额外参数:与正余弦编码一样,RoPE是无参数的,不会增加模型的体积。
缺点
- 概念复杂性:相比于简单的加法操作,旋转变换在概念上更难直观理解。
- 超参数敏感性:模型的性能可能对旋转基底 (theta) 这个超参数的选择较为敏感。
适用场景
RoPE已成为当今许多性能最先进的大型语言模型(LLMs)中的事实标准,包括Llama系列、Gemma和PaLM等。
RoPE的出现标志着位置编码设计理念的成熟。它从加性/拼接式方法(Additive/Concatenative)演进到了乘性/变换式方法(Multiplicative/Transformative)。绝对位置编码的操作是 (text{输入} = text{词元嵌入} + text{位置嵌入}),位置信息是一个附加的偏移量。相对位置编码的操作是 (text{得分} = text{内容得分} + text{位置得分}),位置信息是注意力得分空间中的一个偏置项。而RoPE的操作是 (q' = f(q, pos)) 和 (k' = f(k, pos)),其中函数 (f) 是一个旋转变换。位置信息通过一个乘性的矩阵运算(旋转)在注意力点积之前施加。
这种转变是深刻的。RoPE不再是简单地给词元“贴上”位置标签,而是根据其位置在向量空间中重新定向其表示。向量空间的几何结构现在与序列的顺序交织在一起。这是一种对位置和内容更根本的融合,也正是这种深度融合,为其卓越的性能和泛化能力奠定了基础。
第五章:综合比较与未来展望
在深入探讨了四种主流的位置编码技术后,本章将进行一次全面的综合比较,并展望该领域的未来发展方向。
全景比较:一览无余的总结
为了帮助实践者在不同场景下做出明智的架构选择,下表将前述四种方法的关键特性进行了梳理和对比。这个表格将复杂的权衡关系提炼成一个易于理解的参考,清晰地展示了从外推能力到参数成本等各个维度的优劣。
| 特性 | 正余弦绝对位置编码 | 可学习绝对位置编码 | 相对位置编码 | 旋转位置编码 |
|---|---|---|---|---|
| 集成方式 | 加性(作用于输入嵌入) | 加性(作用于输入嵌入) | 加性(作用于注意力分数) | 乘性(旋转变换) |
| 可学习参数 | 否 | 是 | 是(用于嵌入或偏置) | 否 |
| 外推能力 | 理论可行,实践较差 | 否(存在硬性上限) | 良好 | 优秀(当前最佳) |
| 核心思想 | 固定的、确定性的函数映射 | 数据驱动的、可学习的映射 | 编码词元间的成对距离 | 通过旋转来编码位置 |
| 主要优点 | 简单、无参数 | 灵活性、高表达能力 | 强大的长度泛化能力 | 顶尖性能、卓越外推能力 |
| 主要缺点 | 静态、泛化能力有限 | 无外推能力、参数量大 | 实现相对复杂 | 概念相对复杂 |
| 代表模型 | 原始 Transformer | BERT, GPT-2 | Transformer-XL, T5 | Llama, Gemma, PaLM |
位置编码的演进轨迹
回顾整个发展历程,我们可以清晰地看到一条从静态到可学习、从绝对到相对、从加性到乘性的演进脉络。这个叙事弧线为我们理解该领域的历史和未来方向提供了一个强有力的心智模型。最初,研究者们试图用一个固定的数学公式来解决问题;随后,他们相信模型可以自己学到更好的表示;接着,他们发现相对关系比绝对坐标更重要;最终,他们通过一种更根本的几何变换,实现了性能和泛化能力的统一。
地平线之外:Transformer中秩序的未来
位置编码领域的研究远未结束,新的思想和方法仍在不断涌现。
- 无位置编码(NoPE):一些研究惊奇地发现,非常深的、仅包含因果注意力的解码器模型有时可以隐式地学习到位置信息,尽管这尚未成为主流的替代方案。
- 混合与高级方法:新的相对位置编码变体,如ALiBi、KERPLE和FIRE,正在不断优化相对偏置的思想,以期获得更好的性能和外推能力。
- 超越一维文本:位置编码的概念正在被扩展到二维或三维数据中,例如在视觉Transformer(ViTs)或图神经网络中,如何为图像块或图节点设计有效的位置编码,已成为一个活跃的研究领域。
结论
回顾:不可或缺的秩序缔造者
位置编码并非Transformer架构的一个辅助组件,而是其成功的基石之一。它赋予了原本对顺序无感的自注意力机制感知和利用序列结构的能力,是模型从理解“什么”到理解“如何排列”的关键。
最终提炼
我们可以用一句话来概括每种技术的核心贡献:
- 正余弦绝对位置编码:提供了一个基于数学原理的、优雅的初始概念验证。
- 可学习绝对位置编码:以牺牲泛化能力为代价,引入了任务特定的灵活性。
- 相对位置编码:通过关注关系而非地址,解锁了长上下文建模的能力。
- 旋转位置编码:通过将位置信息编织到嵌入空间的几何结构中,达到了性能和优雅的新高度。
代码实现
正余弦绝对位置编码代码实现
import torch import torch.nn as nn import math class SinusoidalPositionalEncoding(nn.Module): """ 标准正余弦位置编码 """ def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1): """ Args: d_model (int): 模型的维度(嵌入向量的维度) max_len (int): 序列的最大长度 dropout (float): Dropout 的概率 """ super().__init__() self.dropout = nn.Dropout(p=dropout) # 创建一个足够大的位置编码矩阵,形状为 (max_len, d_model) pe = torch.zeros(max_len, d_model) # 创建一个位置张量,形状为 (max_len, 1) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # 计算分母部分,div_term 的形状为 (d_model/2,) # 10000^(2i/d_model) = e^(2i * -log(10000) / d_model) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # 计算偶数维度的 sin 编码 pe[:, 0::2] = torch.sin(position * div_term) # 计算奇数维度的 cos 编码 pe[:, 1::2] = torch.cos(position * div_term) # 增加一个 batch 维度,使其形状变为 (1, max_len, d_model) # register_buffer 使得 pe 成为模型的一部分,但不是模型的参数,不会被梯度更新 self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x (torch.Tensor): 输入的嵌入向量,形状为 (batch_size, seq_len, d_model) Returns: torch.Tensor: 增加了位置信息的嵌入向量 """ # x 的形状是 (batch_size, seq_len, d_model) # self.pe 的形状是 (1, max_len, d_model) # 我们截取序列长度部分的位置编码并加到 x 上 x = x + self.pe[:, :x.size(1), :] return self.dropout(x) # --- 使用示例 --- d_model = 512 max_seq_len = 100 batch_size = 32 # 实例化位置编码模块 pos_encoder = SinusoidalPositionalEncoding(d_model=d_model, max_len=max_seq_len) # 创建一个假的输入张量 input_tensor = torch.randn(batch_size, max_seq_len, d_model) # 应用位置编码 output_tensor = pos_encoder(input_tensor) print("输入张量形状:", input_tensor.shape) print("输出张量形状:", output_tensor.shape)
旋转位置编码实现
import torch import torch.nn as nn import math # ============================================================================== # 步骤 1: RoPE (旋转位置编码) 的实现 # 这部分与之前的回答相同,用于生成和应用旋转矩阵。 # ============================================================================== def precompute_rope_frequencies(head_dim: int, max_len: int, theta: float = 10000.0): """ 预计算 RoPE 的频率(角度)。 这是在模型初始化时调用一次即可,无需在每次前向传播时都计算。 Args: head_dim (int): 每个注意力头的维度。RoPE 是在头的维度上应用的。 max_len (int): 模型的最大序列长度。 theta (float): RoPE 中的基数 theta,通常是 10000。 Returns: torch.Tensor: 复数形式的频率,形状为 (max_len, head_dim / 2)。 """ # 确保头的维度是偶数 assert head_dim % 2 == 0, "head_dim must be even" # 计算 theta_i = 1 / (theta^(2i/d)) # 形状: (head_dim / 2,) inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) # 计算位置 m # 形状: (max_len,) t = torch.arange(max_len, dtype=torch.float) # 计算 m * theta_i # 使用外积 (outer product) 得到形状为 (max_len, head_dim / 2) 的矩阵 freqs = torch.einsum('i,j->ij', t, inv_freq) # 将 freqs 转换为复数形式 e^(i * m * theta_i) # torch.polar 创建一个复数张量,其元素大小为第一个参数,角度为第二个参数 # 形状: (max_len, head_dim / 2) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) return freqs_cis def apply_rotary_embeddings(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: """ 将旋转位置编码应用到输入张量上(Query 或 Key)。 Args: x (torch.Tensor): 输入张量 (Q 或 K),形状为 (B, H, L, D_h), 其中 B=batch_size, H=num_heads, L=seq_len, D_h=head_dim。 freqs_cis (torch.Tensor): 预计算的复数频率,形状为 (L, D_h / 2)。 Returns: torch.Tensor: 应用了 RoPE 后的张量,形状与输入 x 相同。 """ # 将 x 的最后一维 (head_dim) 转换为复数形式 # (B, H, L, D_h) -> (B, H, L, D_h/2, 2) x_shaped = x.float().reshape(*x.shape[:-1], -1, 2) # (B, H, L, D_h/2, 2) -> (B, H, L, D_h/2) (复数) x_complex = torch.view_as_complex(x_shaped) # freqs_cis 需要扩展维度以匹配 x_complex 的形状进行广播 # (L, D_h/2) -> (1, 1, L, D_h/2) freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(0) # 复数乘法应用旋转 # (B, H, L, D_h/2) * (1, 1, L, D_h/2) -> (B, H, L, D_h/2) x_rotated_complex = x_complex * freqs_cis # 将结果转换回实数形式 # (B, H, L, D_h/2) -> (B, H, L, D_h/2, 2) x_rotated = torch.view_as_real(x_rotated_complex) # (B, H, L, D_h/2, 2) -> (B, H, L, D_h) x_out = x_rotated.reshape(*x.shape) return x_out.type_as(x) # ============================================================================== # 步骤 2: 实现集成了 RoPE 的多头注意力模块 # ============================================================================== class MultiHeadAttentionWithRoPE(nn.Module): def __init__(self, d_model: int, num_heads: int): super().__init__() assert d_model % num_heads == 0, "d_model must be divisible by num_heads" self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads # 线性投影层 self.W_q = nn.Linear(d_model, d_model, bias=False) self.W_k = nn.Linear(d_model, d_model, bias=False) self.W_v = nn.Linear(d_model, d_model, bias=False) self.W_o = nn.Linear(d_model, d_model, bias=False) # Dropout 层 self.dropout = nn.Dropout(0.1) def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: torch.Tensor = None): """ Args: x (torch.Tensor): 输入张量,形状 (B, L, D_m), 其中 B=batch_size, L=seq_len, D_m=d_model。 freqs_cis (torch.Tensor): 预计算的 RoPE 频率,形状 (L, D_h/2)。 mask (torch.Tensor, optional): 注意力掩码。Defaults to None. Returns: torch.Tensor: 注意力模块的输出,形状 (B, L, D_m)。 """ batch_size, seq_len, _ = x.shape # 1. 线性投影: (B, L, D_m) -> (B, L, D_m) query = self.W_q(x) key = self.W_k(x) value = self.W_v(x) # 2. 拆分成多个头: (B, L, D_m) -> (B, H, L, D_h) query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # 3. 对 Query 和 Key 应用 RoPE # 我们只截取当前序列长度所需的频率 query = apply_rotary_embeddings(query, freqs_cis=freqs_cis[:seq_len]) key = apply_rotary_embeddings(key, freqs_cis=freqs_cis[:seq_len]) # 4. 计算注意力分数 # (B, H, L, D_h) @ (B, H, D_h, L) -> (B, H, L, L) attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim) if mask is not None: attention_scores = attention_scores.masked_fill(mask == 0, float('-inf')) attention_weights = torch.softmax(attention_scores, dim=-1) attention_weights = self.dropout(attention_weights) # 5. 应用注意力权重到 Value # (B, H, L, L) @ (B, H, L, D_h) -> (B, H, L, D_h) attention_output = torch.matmul(attention_weights, value) # 6. 合并多头并进行最终线性投影 # (B, H, L, D_h) -> (B, L, H, D_h) -> (B, L, D_m) attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) # (B, L, D_m) -> (B, L, D_m) output = self.W_o(attention_output) return output # ============================================================================== # 步骤 3: 构建一个完整的 Transformer Block # ============================================================================== class RoPETransformerBlock(nn.Module): def __init__(self, d_model: int, num_heads: int, d_ff: int): super().__init__() self.attention = MultiHeadAttentionWithRoPE(d_model, num_heads) self.ffn = nn.Sequential( nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model) ) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(0.1) self.dropout2 = nn.Dropout(0.1) def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: torch.Tensor = None): # 注意力层 + 残差连接和层归一化 attn_output = self.attention(self.norm1(x), freqs_cis, mask) x = x + self.dropout1(attn_output) # 前馈网络层 + 残差连接和层归一化 ffn_output = self.ffn(self.norm2(x)) x = x + self.dropout2(ffn_output) return x # ============================================================================== # 步骤 4: 运行示例 # ============================================================================== if __name__ == '__main__': # --- 模型超参数 --- d_model = 512 # 模型总维度 num_heads = 8 # 注意力头的数量 d_ff = 2048 # 前馈网络中间层维度 max_len = 1024 # 模型支持的最大序列长度 # --- 输入数据参数 --- batch_size = 4 seq_len = 256 # 当前输入的序列长度 (可以小于 max_len) # --- 实例化模型 --- model = RoPETransformerBlock(d_model=d_model, num_heads=num_heads, d_ff=d_ff) print("模型结构:n", model) # --- 预计算 RoPE 频率 --- # 这一步通常在模型初始化后、训练开始前完成 head_dim = d_model // num_heads freqs_cis = precompute_rope_frequencies(head_dim, max_len) print(f"n预计算 RoPE 频率,形状: {freqs_cis.shape}") # --- 创建一个假的输入张量 --- # 形状: (batch_size, seq_len, d_model) input_tensor = torch.randn(batch_size, seq_len, d_model) print(f"输入张量形状: {input_tensor.shape}") # --- 创建一个上三角掩码 (用于自回归任务,如 GPT) --- # 如果是 BERT 类型的任务,则不需要这个掩码 # 形状: (1, 1, seq_len, seq_len) mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0) print(f"注意力掩码形状: {mask.shape}") # --- 前向传播 --- output_tensor = model(input_tensor, freqs_cis=freqs_cis, mask=mask) # --- 打印输出 --- print(f"输出张量形状: {output_tensor.shape}") # --- 验证输出形状是否正确 --- assert output_tensor.shape == input_tensor.shape print("n✅ 范例运行成功,输出形状正确!")