探秘Transformer系列之(11)— 掩码

探秘Transformer系列之(11)--- 掩码

0x00 概述

机器学习领域中,掩码(Mask)本质是一个跟需要掩盖的目标张量大小一致的(大多数是0-1二值)张量,其思想最早起源于 word2vec 的CBOW的训练机制:通过上下文来预测中心词。掩码就相当于把中心词给遮掩住。不同的任务和应用场景可能需要不同类型的mask操作。在自注意力模型中,常见的mask操作有两种:Padding mask和Sequence mask。

  • Padding mask(填充掩码):在处理变长序列时,为了保持序列的长度一致,通常会在序列的末尾添加一些特殊的填充符号(如)。Padding mask的作用是将这些填充符号对应位置的注意力分数设为一个很小的值(如负无穷),从而使模型在计算注意力分数时忽略这些填充符号,避免填充符号对计算产生干扰。

  • Sequence mask(序列掩码):在某些任务中,为了避免模型在生成序列时看到未来的信息,需要对注意力分数进行掩码操作。Sequence mask的作用是通过构建下三角(或者上三角)的注意力分数矩阵,将当前位置之后位置的注意力分数设为一个很小的值,从而使模型只关注当前 token 与之前 token 的注意力关系,不理会它与后续 token 的关系。这样可以保证模型在生成序列时只依赖于已经生成的部分,不会受到未来信息的影响,即只”看”当前及前面的 tokens。也有把Sequence mask叫做Casual Mask的。

使用掩码的自注意力机制就叫做掩码自注意力机制,也被称为因果自注意力(Causal Self-Attention)。

探秘Transformer系列之(11)--- 掩码

0x01 需求

我们先来仔细分析一下为何需要掩码。

1.1 避免偏差

实际情况

在神经网络的训练过程中,同一个batch会包含有多个文本序列,不同的序列长度并不一定会一致。而神经网络的输入需要一个规整的张量。为了符合模型的输入方式,在数据集的生成过程中,我们要对输入序列进行对齐,使同一个batch内所有序列的长度一致。具体来说就是:

  • 但是如果输入的序列太长,我们会截取左边的内容,把多余的单词直接舍弃。
  • 在较短的序列后面用特殊符号来填充(比如填)。

具体参见下图。图上展示了将多个句子组成一个batch时会遇到的情况:句子的长度是不同的。我们要对所有的句子按照预先设定的最长长度做填充或者裁剪,形成多个长度一样的句子,才能组成batch(一个三维的张量),送入模型进行训练。

探秘Transformer系列之(11)--- 掩码

问题所在

上述方案在注意力计算时会遇到问题:因为如果在注意力的计算过程中考虑到填充位置上的信息,则会给最终结果带来误差。我们来具体分析下。

假设某一行向量是 ([x_1, x_2, ..., x_V]),行向量中某一个元素是(x_i),原生softmax的计算公式如下:

[softmax(x_i) = frac{e^{x_i}}{sum _{j=1}^V e^{x_j}} ]

算法流程需要两个循环,首先需要迭代计算分母的和,然后再迭代计算向量中每一个值对应的softmax值,即缩放每一个元素。因为填充词是人为添加的,其实没什么意义,在计算注意力分数时,我们通常不希望模型将注意力放在这些无关紧要的填充的词上,不要浪费计算资源。我们也不希望这些位置参与后期的反向传播过程。以此避免最后影响模型自身的效果。然而实际上,padding数值一般来说是0,(e^0)的数值为1,因此softmax中被padding的部分就参与了运算。这些无效部分参与运算会产生很大隐患,会导致注意力分数会出现偏差,影响全局概率值。所以我们需要进行一些处理。

解决方案

直觉的解决方案是:模型应该把注意力聚焦在实际有意义的词上,所以我们要找到所有非填充(nonpad)token,然后只计算这些非填充token的损失函数。当然我们也可以反向思考,用mask让这些无效区域不参与运算。这就是padding mask。

1.2 防止偷看

实际情况

首先,我们回忆下注意力计算公式如下,我们需要针对整个输入序列进行注意力计算。

[Attention(Q,K,V)=softmax(frac{QK^T}{sqrt d_k})V ]

其次,编码器和解码器的运行方式不同:

  • Encoder因为要编码整个句子,每个词都需要考虑上下文的关系。所以每个词在计算的过程中都是可以看到句子中所有词的。
  • 但是Decoder实质上是一个单向的自注意力结构,每个词都只能看到前面词的状态。原因如下:推理阶段是自回归模式,是一个词一个词输入的,Decoder是不知道下文信息的。所以每次decoder只能看到之前自己生成的token和prompt,因此自然也无法计算得到当前词和下文还没出现词的注意力。

解码器这种运行方式导致其在训练时候需要做特殊处理。因为训练阶段采用自回归模式,会导致训练速度过慢。如前文所述,为了加快训练速度,人们采用了Teacher Foring。即采用类似编码器中的矩阵并行算法,一步就把所有目标单词预测出来。这样做有两个好处,一是通过多样本并行计算能够加快网络的训练速度;二是在训练过程中直接喂入解码器正确的结果而不是上一时刻的预测值(因为训练时上一时刻的预测值可能是错误的),可以让训练更快收敛。

我们暂时先忘记Teacher Forcing,假定我们就是要进行并行计算。最朴素的训练方法应该是基于一个长为 n 的预测序列来构造 n 条样本(每个样本就是完整的预测序列),把这些样本并行输入到模型。对于第一个样本,模型就根据来预测第一个字符,对于最后一条样本,模型则根据前 n - 1 个字符去预测第 n 个字符。label中已经给出的这个时刻之后的teacher token。这样解码器就可以像编码器一样进行并行计算。即一次接收解码时所有时刻的输入,然后同步预测每个位置上的token。

探秘Transformer系列之(11)--- 掩码

问题所在

目前每个样本实际包括了整个句子。但是Decoder 在解码第 t 个时刻的时候只能使用 1...t 时刻的输入,而不能使用 t+1 时刻及其之后的输入,即模型只应该依据部分输入来进行预测。把整个句子(完整的目标序列)一次性输入给解码器就是问题所在。因为模型已经知道了全部句子内容。因此,在预测某个位置的词时,解码器可以使用该词之前的目标词以及该词之后的目标词。这使得解码器可以通过使用未来 "时间步 "的目标词来 "作弊"。比如基于”我爱“来预测”我爱中国“。在输出爱的时候,模型会用到后面“中国”的信息。

俗话说“天机不可泄露”。要是模型能未卜先知地知道自己下一步将要输出什么,模型很容易学会偷懒,它就不用费劲计算这个输出了,只需要把输入目标序列的下一元素作为输出就可以了,这样训练就没有效果。另外,因为attention layer是有很多层的。在第一层,当前token (X_{n})融合了下一个token (X_{n+1})的信息,在下一层attention layer计算时,token (X_{n+1})会看到(X_{n})中包含的(X_{n+1}), 这样在预测token (X_{n+1})的时候,使用自己的信息预测自己,这显然也是一种信息泄露。

所以我们在训练的时候,解码器不应该提前知道下文的信息,不能计算当前词和后面的词的注意力,而只能计算当前词和前面所有词的注意力。

解决方案

为了确保模型在这一时点上不会受到未来词汇的干扰,解码器采用了sequence mask。 其作用就是在 time_step 为 t 的时刻,把 t 之后的信息给隐藏起来。使得 decoder 只能看到目标序列的一部分(前缀),不能看见未来的信息。即对于一个序列,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。这就是Sequence mask。可以将这个过程想象为一个时间线:在预测一个特定的词时,你不能“预知”它之后的词汇,因为在实际情境中,之后的部分尚未发生。

探秘Transformer系列之(11)--- 掩码

总结一下,Padding Mask的作用是避免填充符号带来的偏差。Sequence mask的作用是屏蔽未来信息,防止偷看,保证每个位置只能看到前面的tokens。

0x02 Padding Mask

我们接下来看看Padding Mask如何实现。

2.1 逻辑

核心逻辑就是让填充词在经过softmax操作不应该有对应的输出。

掩码矩阵

研究人员找到的一个方法就是在训练时使用掩码矩阵。对于已经padding到同一长度的一个batch中的句子,在应用softmax函数之前,使用掩码矩阵把将补全的位置给掩盖掉。掩码矩阵有不同实现方式:

  • 矩阵每个值都是一个 Boolean,值为 false 的地方就是我们要进行处理的地方。
  • 在掩码矩阵中,填充词的对应位置放置一个非常大的负数(如-1e9),否则放置0。

在经过掩码矩阵处理之后,这些被掩盖位置在经过softmax激活函数后,得到的注意力分数(Attention Score)会归零或者接近于0,这样对应位置的token表征就不参与上文提到的按照权重加和的过程,即没有注意力分配到这个上面,不再影响全局概率的预测。

计算注意力步骤

加入mask之后的注意力计算的具体步骤如下:

  • 创建一个掩码矩阵。如果输入序列中的某个位置是填充词,则在掩码矩阵的对应位置放置一个非常大的负数(如-1e9),否则放置0。

  • 将掩码矩阵加到注意力分数上。因为掩码矩阵中填充词的位置是非常大的负数,加上它们之后,这些位置的注意力分数也会变成非常大的负数。

  • 应用softmax函数。在加了掩码的注意力分数上应用softmax函数。由于填充词位置的分数是非常大的负数,经过softmax函数后,这些位置的权重将接近于0,而其他位置的权重将保持不变(因为softmax是一个归一化函数)。

  • 计算加权和。使用softmax的输出作为权重,计算值(Value)的加权和。

下图中,上方是编码器输入对应的掩码操作,下方是解码器输入对应的掩码操作。

探秘Transformer系列之(11)--- 掩码

2.2 实现

我们来分析哈佛代码,为了更好的说明,我们把padding的代码一起加入进来。

设置填充符号

我们以目标句子为例,在数据加载时,collate_batch()函数会为目标句子加入掩码。

def collate_batch(     batch, # 句子对的列表     max_padding=128, # 句子最大长度     pad_id=2, ):         	# 省略其它代码                  processed_tgt = torch.cat( # 获取目标句子             [                 bs_id,                 torch.tensor(                     tgt_vocab(tgt_pipeline(_tgt)),                     dtype=torch.int64,                     device=device,                 ),                 eos_id,             ],             0,         )                  """         调用torch.pad()函数对processed_src进行处理,如果processed_src的长度小于max_padding,则使用pad_id进行填充,如果大于max_padding,则截断。         然后把处理后的processed_tgt加入到tgt_list。         """         tgt_list = []         tgt_list.append(             pad(                 processed_tgt,                 (0, max_padding - len(processed_tgt)),                 value=pad_id,             )         )  		# 省略其它代码 

建立mask

此处把Batch类中关于mask的部分拿出来再进行分析。生成src_mask的语句比较简单,只有self.src_mask = (src != pad).unsqueeze(-2) 这一行代码,其主要起到两个作用:

  • 把src中非pad的部分置为True,pad部分置为False。假设某个句子内容是[0, 3, 1, 2, 2],则其对应的掩码是[True, True, True, False, False]。“”、“”和“”算作句子成分,因此不做掩码处理。
  • 使用unsqueeze()函数增加一维度,因为后续src_mask要和注意力分数进行掩码计算,而注意力分数是三个维度,所以这里要保持一致。最终src_mask返回的是一个布尔矩阵,其形状是[批量大小,1,句子最长长度]。其中第i行第j列表示的是query的第i个词对key的第j个词的注意力是否无意义。若无意义则为True,有意义的为False(即被padding的位置是True)。后续在处理mask时,为False的位置是需要被mask掉的,True的位置是不需要动的。处理之后,占位符就无法吸收到query的注意力。
class Batch:     def __init__(self, src, tgt=None, pad=2):  # 2 = <blank>         self.src = src # 源语言句子列表,形状是[batch_size,Length]         # 创建源语言的掩码,这样可以忽略填充部分,unsqueeze()的作用是增加一维度,因为后续要和注意力分数进行掩码计算,而注意力分数是三个维度,所以这里要保持一致。         # (src != pad)返回一个等大的布尔张量,src元素等于pad的位置为False,否则为True         # unsqueeze(1)作用是增加了一个维度,变成pad_attn_mask: [batch_size,1,seq_len]         # 最终得到返回一个[batch_size, 1, seq_len]大小的布尔张量,False是需要mask掉的位置         self.src_mask = (src != pad).unsqueeze(-2)  

实施mask

具体应用掩码矩阵的代码位于 attention()函数中。注意,此时是把padding mask和sequence mask都在一起应用。

def attention(query, key, value, mask=None, dropout=None):     "Compute 'Scaled Dot Product Attention'"          # 先计算注意力分数     d_k = query.size(-1)     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)          # 在query和key的转置相乘得出(len_q,len_k)这个注意力分数矩阵以后,使用mask来掩盖相乘结果矩阵,此处把创建掩码矩阵和应用掩码矩阵合二为一     if mask is not None:         # 如果发现mask是0,就用-1e9来替换它         scores = scores.masked_fill(mask == 0, -1e9)              # 然后才开始实施softmax操作         p_attn = scores.softmax(dim=-1)     if dropout is not None:         p_attn = dropout(p_attn)     return torch.matmul(p_attn, value), p_attn 

0x03 Sequence mask

3.1 逻辑

Sequence mask的核心逻辑是:解码的时候掩盖掉当前时刻之后的信息。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。Sequence mask操作只针对自回归模型的训练过程和推理时的 prefill 阶段,推理时的 decode 阶段无需应用 mask 操作。但是因为方便实现,代码依然使用同一套。

掩码矩阵

我们需要产生一个Mask 矩阵,在计算注意力的时候,加入这个掩码(mask)。通过设计合适的mask,就可以实现在输出每一个元素的时候,切断它从未来获得信息的通路(把对应的注意力强制置零),这样就可以屏蔽或限制模型在计算注意力分数时对某些位置的关注。这个mask矩阵的特点如下:

  • 该矩阵的形状跟注意力分布矩阵一样,尺寸为 [seq_len, seq_len]。
  • 从矩阵内容上来看,这是一个下三角矩阵。内容依据实际情况而定,如果是布尔矩阵,可以上三角的值全为 0,下三角的值全为1,对角线也是 1。如果是浮点矩阵,可以把上三角的值赋值为负无穷。这样可以单独调节每一个源元素与每一个目标元素之间的注意力强度。
  • 在进行softmax计算之前,把这个矩阵作用在每一个序列上。即在 (QK^T) 点积上施加掩码,被屏蔽的元素被设置为负无穷大(表示它们是“无限不相似”的,即互不相关)。就是让query(t)和未来时刻的key的内积值为负无穷大(-inf)。
  • 在作Softmax操作时,模型会把这些负无穷大值所对应的权重变成零。后续再乘V的时候,当前的位置就无法看到后面的词信息了。所以计算t时刻概率时只用到了t-1以前时刻的key-value对的信息。

通过这个操作,我们就可以一次性计算整个Decoder输出序列的损失,而不用逐个Token计算,这个过程就是我们之前提到的Teacher Forcing。

Mask 矩阵示例如下,这是个10维度的下三角矩阵。当解码第一个字的时候,第一个字只能与第一个字计算相关性,当解出第二个字的时候,只能计算出第二个字与第一个字和第二个字的相关性。

[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],  [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],  [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],  [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],  [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],  [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],  [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],  [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] 

具体公式如下。

探秘Transformer系列之(11)--- 掩码

掩码自注意力

我们接下来再看看Masked Self-Attention。在解码Decoder Block中,输入序列首先遇到的是Masked Self-Attention(所谓Masked,即遮蔽的意思)。Masked Self-Attention的Q,K,V均来自同一个部分,满足max_len_q = max_len_k_v = max_len。masked multi-head self-attention与上面描述的multi-head Attention计算过程的不同之处在于score矩阵送入到softmax计算weight矩阵先进行一步mask操作。即句子中的每个词,都只能对包括自己在内的前面所有词进行Attention,这实质是单向Transformer。也向我们展示了Masked Self-Attention的设计动机:防止模型看到未来时刻的输入,也保证了训练时和预测时解码器运行的情况是一样的。

我们用第一个解码器层来解释其操作序列如下。

  • 经过Input embedding和位置编码之后,得到词嵌入(mathbf{X})

  • (mathbf{X})分别乘以三个权重矩阵,(mathbf{W^q}),(mathbf{W^k}),(mathbf{W^v}),经过三次线性变化,得到(mathbf{Q}),(mathbf{K}),(mathbf{V})矩阵。

  • (mathbf{Q})矩阵乘以(mathbf{K})矩阵的转置矩阵,得到(mathbf{QK^T}),即注意力分数分布。

  • (mathbf{QK^T})乘以一个Mask矩阵,按位相乘,得到遮蔽的注意力分数分布((Masked mathbf{QK^T})),保存此次解码应该看到的,隐藏看不到或者不应该看到的。即保持score矩阵的下三角部分不变,将上三角部分全部mask掉,置为负无穷。这样处理后,score矩阵的第i行,即q对应的第i个时间步,只保留了q与前i个时间步的k的关系得分,后面的部分都被mask掉了。

  • (Masked mathbf{QK^T})经过softmax操作,得到(mathbf{A} = text{softmax}(mathbf{Q}mathbf{K}^top / sqrt{d_k}))。显然被mask掉(置为-inf)的部分经过softmax处理都变成了0(无限接近0),即weight矩阵的第i行中,前i个权重之和为1,后面的权重都为0。

  • (mathbf{A})乘以(mathbf{V})矩阵,最终得到(mathbf{Z})矩阵。将mask后的weight与V矩阵相乘。前面的讨论已经知道,(mathbf{Z})矩阵的第i行,是V中的所有行基于weight矩阵第i行中的各个权重进行加权平均的结果,然而经过mask处理,weight矩阵的第i行中只剩下了前i个权重值,也就是说,context矩阵的第i行实际上是由V的前i行加权平均的结果。

    [mathbf{Z_i} = sum_{j=1}^{max_len} weight(i,j) cdot v(j) = sum_{j=1}^i weight(i,j) cdot v(j) ]

    此外,Y中每句话的第一个单词是开始符号的编码,所以Y中实际信息的时间步向前错了一位,因此,在masked multi-head self-attention结构中,计算第i个时间步的context信息时实际上只是使用前i-1个时间步的信息。

  • 以上说的是单一注意力头得到的矩阵(mathbf{Z_i}),如果是多头注意力,则把多个(mathbf{Z_i})拼接之后,经过线性变换,得到最终的(mathbf{Z})矩阵。

探秘Transformer系列之(11)--- 掩码

交叉注意力

现在思考一个问题:masked attention后面的cross-attention需要也加一个attention mask吗?答案是不需要。

解码器内部的带有mask的MultiHeadAttention的qkv向量输入来自目标单词嵌入或者前一个解码器输出,三者是相同的,但是后面的MultiHeadAttention的qkv向量中的kv来自最后一层编码器的输入,而q来自带有mask的MultiHeadAttention模块的输出。因为encoder可以看到整条输入序列,已经获得了全量信息,所以decoder这一条Q可以看到context vector全部的K和V。换句话说,在训练和预测的时候,我们是允许decoder看到目标序列输入的全部信息的,这些信息不需要 mask。但是在实际操作时还是需要加一个src_mask,就是源语言的padding mask。

总结下,对于解码器,实际操作会将两种掩码合并,每个位置取最小值,相当于两个掩码只要有任意一种情况需要被遮蔽,则就应该被遮蔽。具体可以参见下图。

探秘Transformer系列之(11)--- 掩码

3.2 实现

生成掩码

此处把Batch类中关于mask的部分拿出来再进行分析。

生成src_mask的语句比较简单,只有self.src_mask = (src != pad).unsqueeze(-2) 这一行代码。 具体如上面Padding mask实现中解析,这里不再赘述。

生成tgt_mask则比较复杂,具体逻辑在make_std_mask()函数中。tgt_mask与src_mask略有不同,除了需要盖住pad部分,还需要将对角线右上的也都盖住。就是要结合填充词对应的掩码和未来词汇相关的掩码。make_std_mask()函数的逻辑如下:

  • 首先生成填充词对应的掩码。假设某个句子内容是[0, 3, 1, 2, 2],则其对应的掩码是[True, True, True, False, False]。

  • 然后调用subsequent_mask()函数来生成未来词汇相关的掩码,这是一个对角线以及之下都是True的矩阵,具体掩码如下。

    [[   [ True, False, False, False, False ],   [ True, True, False, False, False ],   [ True, True, True, False, False ],   [ True, True, True, True, False ],   [ True, True, True, True, True ], ]] 
  • 最后填充词对应的掩码和未来词汇相关的掩码会做与操作,得到最终掩码如下

    [[   [ True, False, False, False, False ],   [ True, True, False, False, False ],   [ True, True, True, False, False ],   [ True, True, True, False, False ],   [ True, True, True, False, False ], ]] 

make_std_mask()函数的源码如下。

@staticmethod def make_std_mask(tgt, pad):     "Create a mask to hide padding and future words."          # 生成填充词对应的掩码,用于忽略填充部分     tgt_mask = (tgt != pad).unsqueeze(-2) # 创建目标语言的掩码,用于忽略填充部分          """     subsequent_mask()函数会生成未来词汇相关的掩码。然后再和tgt_mask进行与操作,得到最终掩码     tgt.size(-1) 表示的是序列的长度     """     tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(         tgt_mask.data     )     return tgt_mask 

subsequent_mask()函数的源码如下。

def subsequent_mask(size):     """     Mask out subsequent positions.     该方法在会在构建tgt的mask时使用。     """              # 首先需要定义掩码张量的形状,具体会生成一个Shape为(1, size, size)的矩阵     # 前面加个1是为了和tgt的维度保持一致,因为tgt的第一维是batch_size             attn_shape = (1, size, size)             # 首先使用torch.triu()函数产生一个上三角阵,几个注意点是:     # 1. diagonal=1意为不包含主对角线(从主对角线向上偏移1开始)        # 2. 使用np.ones方法向矩阵中添加1元素,形成上三角阵(左上角全为1)     # 3. 为了节约空间, 使上三角阵的数据类型变为unit8          subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(         torch.uint8     )      # subsequent_mask == 0其实是做了一个三角阵的反转, subsequent_mask中的每个元素都会被1减,这样将 0全部变为True, 1变为False     return subsequent_mask == 0 

我们打印输出看看。print(subsequent_mask(5))的结果如下。

tensor([[[ True, False, False, False, False],          [ True,  True, False, False, False],          [ True,  True,  True, False, False],          [ True,  True,  True,  True, False],          [ True,  True,  True,  True,  True]]]) 

它输出的是一个方阵,该方阵对角线与左下全为True,右上全为False。第一行只有第一列是 True,它的意思是时刻 1 只能 attend to 输入 1,第三行说明时刻 3 可以 attend to 1,2,3 而不能attend to 4,5 的输入,因为对于 Decoder 来说,这是属于未来的信息。

施加掩码

和前面padding mask是合并在一起施加的,此处不再赘述。

3.3 Transformer

我们再来看看Transformer的代码,基本和哈佛思路一致,只是加上了kv cache。

# Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask(     input_ids_shape: torch.Size,     dtype: torch.dtype,     device: torch.device,     past_key_values_length: int = 0, ):     """     Create a causal mask for bi-directional self-attention.      Args:         input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len).         dtype (torch.dtype): The data type of the mask.         device (torch.device): The device on which the mask will be placed.         past_key_values_length (int, optional): The length of past key values. Default is 0.      Returns:         torch.Tensor: The causal mask tensor.     """     bsz, tgt_len = input_ids_shape     mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)     mask_cond = torch.arange(mask.size(-1), device=device)     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)     mask = mask.to(dtype)      if past_key_values_length > 0:         mask = torch.cat(             [                 torch.zeros(                     tgt_len, past_key_values_length, dtype=dtype, device=device                 ),                 mask,             ],             dim=-1,         )     return mask[None, None, :, :].expand(         bsz, 1, tgt_len, tgt_len + past_key_values_length     ) 

0x04 数据流

哈佛代码中通过两个变量把两种掩码做了糅合,又加上编码器和解码器两个模块的排列组合,让人难以理解。我们再仔细梳理下数据流程。总的来说,对于两种掩码,其在编码器和解码器两个模块中的需求如下:

  • 对于Encoder来说,不应该注意的部分,因为这不属于句子成分。但是不需要防止“窥视未来信息”。
  • 对于Decoder来说,前面的词不应该注意后面的词,同时,也不能注意的部分。padding mask 和sequence mask是可以同时存在的。

我们再给出一个表格,大家可以看到在代码中两个变量的特性。

变量名称 mask类型 编码器Self-attention 解码器masked self-attention 解码器Cross-attention
src_mask Padding Mask 使用 不使用(padding的功能在tgt_mask中完成) 使用
tgt_mask Padding Mask + Sequence Mask 不使用 使用 不使用

4.1 如何应用于注意力

我们首先看看两种掩码在逻辑上应该用于哪个模块的哪种注意力。

Padding Mask。只要有padding的地方,都会用到padding mask,所以Encoder和Decoder都有padding mask。

  • 因为编码时不需要对当前时刻之后的信息进行掩盖,任何位置的信息都可以被任何位置的单词获取。所以编码器的掩码就只是padding mask。在自注意力中会用到。
  • 对于解码器来说:
    • 在交叉注意力中会用到padding mask。
    • 在自注意力中会用到padding mask。

Sequence Mask(Attention Mask)

  • 解码器 的 cross-attention不需要Sequence Mask。因为编码器的输出作为K和V,已经知道了序列的所有信息。

  • 在解码器的Self-Attention里面会用到Sequence Mask。在 Decoder 中的 Self-attention 中:掩蔽的作用是,防止解码器在当前时间步预测时 ,"偷看 "目标句余下几个时间步的部分。所以对于 decoder 的 self-attention里面使用到的 scaled dot-product attention,同时需要padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个mask相加作为attn_mask。

实际上在交叉自注意力中,如果我们想限制 Decoder 只能获取某一部分的 Encoder 信息,即 memory bandwidth,也可以使用mask。PyTorch里面就有memory mask,但一般场景下,我们允许 Decoder 获取全部的 Encoder 信息,所以 memory mask 不常用到。

4.2 变量说明

在代码中,有两个关于掩码的变量:src_mask和tgt_mask。Encoder只会看src_mask。Decoder会看src_mask和tgt_task。src_mask就是Padding Mask,而tgt_mask是包含了padding mask和sequence mask的融合mask。

Batch类的代码中设定掩码有两步,在这两步设定之后tgt_mask就是融合掩码。这两步分别是:

  • 第一步:设定padding mask;
  • 第二步,设定padding mask限定之下的sequence mask;

具体代码是:

def make_std_mask(tgt, pad):     "Create a mask to hide padding and future words."     # 一定要注意,这里有两步     tgt_mask = (tgt != pad).unsqueeze(-2) # 第一步,设定padding mask     tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(         tgt_mask.data     ) # 第二步,设定padding mask限定之下的sequence mask     return tgt_mask 

src_mask的形状是(batch size, 1, 1, seq_length),这是因为:

  • src要对句子中的填充词做mask,所以只需要在最后一维做掩码就行了。即其实用一个向量就够了。
  • 因为所有head的mask都一样,所以第二维是1,masked_fill 时使用 broadcasting 就可以了。
  • 这里是 self-attention 的mask,所以每个时刻都可以 attend 到所有其它时刻,所以第三维也是 1,也使用broadcasting。

tgt_mask形状是(batch size, 1, seq_length, seq_length)。tgt需要斜着进行mask,所以需要一个方阵来进行,这个矩阵代表若干个时刻。

Encoder数据流

我们可以举一个例子,为了简单,我们假设 batch=2,head=2,最大允许的序列长度为5, 第一个序列长度为 3,第二个为 5。分别如下:

  • [<bos>, 你,<eos>,<pad>, <pad>]
  • [<bos>, 你,好,吗,<eos>]

编码器中的掩码只是padding mask。因为padding位置的信息不需要带有权重去干扰有实词位置的embedding表征。掩码形状 为 (2, 1, 1, 5),我们可以用两个向量表示:

  • 第一个向量是$ begin{Bmatrix} 1 & 1 & 1 &0 & 0 end{Bmatrix} $。其含义是:第一个句子前3个是单词,后面2个是填充。而mask就是要对后面2个进行mask。因此本序列的任一时刻可以同前 3 个时刻交互来计算注意力。
  • 第二个向量是$ begin{Bmatrix} 1 & 1 & 1 & 1 & 1 end{Bmatrix} $。其含义是:本序列的任意单词可以同所有时刻的输入进行交互。

在实际运算中,因为是多头,所以对于第一个序列,首先会对两个头进行broadcast,得到如下。

[begin{Bmatrix} x1 & x2 & x3 &x4 & x5 \ x1 & x2 & x3 &x4 & x5 end{Bmatrix} ]

然后会施加掩码,得到

[begin{Bmatrix} x1 & x2 & x3 &-1e9 & -1e9 \ x1 & x2 & x3 &-1e9 & -1e9 end{Bmatrix} ]

对于第二个序列,也是两个头进行广播,在掩码前后序列的内容都是

[begin{Bmatrix} x1 & x2 & x3 &x4 & x5 \ x1 & x2 & x3 &x4 & x5 end{Bmatrix} ]

Decoder数据流

解码器的掩码自注意力中同时需要padding mask 和 sequence mask 的组合来作为 attn_mask。这是因为在解码器模块不仅要考虑padding导致的mask,还要考虑后词偷看问题。:

  • 答案是一起输入的,而实际的部署场景是步进预测的,理论上当前步长是看不到当前步长之后的词的信息的。

  • 答案本身会进行该批次下的统一padding,因此还需要再叠加padding的mask掩码,杜绝padding单词对实词的表征影响。

注意:上述信息仅仅对于训练有效,然而为了保持代码复用,所以推理时候也使用同样的代码。

具体实现就是将两个掩码合并,每个位置取最小值,相当于两个掩码只要有任意一种情况需要被遮蔽则就应该被遮蔽。而 Decoder 的 src-attention 的 mask 形状为 (2, 1, 5, 5)。

第一个序列的mask矩阵是两个mask做与操作,其结果作为attn_mask。第一个是padding mask,第二个是sequence mask。即:

[begin{Bmatrix} 1 & 1 & 1 &0 & 0 \ 1 & 1 & 1 & 0 & 0 \ 1 & 1 & 1 & 0 & 0 \ 1 & 1 & 1 & 0 & 0 \ 1 & 1 & 1 & 0 & 0 end{Bmatrix} ]

[begin{Bmatrix} 1 & 0 & 0 &0 &0 \ 1 & 1 & 0 & 0 &0 \ 1 & 1 & 1 &0 & 0\ 1 & 1 & 1 & 1 & 0 \ 1 & 1 & 1 & 1 & 1 end{Bmatrix} ]

相与,得到

[begin{Bmatrix} 1 & 0 & 0 &0 &0 \ 1 & 1 & 0 & 0 &0 \ 1 & 1 & 1 &0 & 0\ 1 & 1 & 1 & 0 & 0 \ 1 & 1 & 1 & 0 & 0 end{Bmatrix} ]

第二个序列的mask矩阵两个mask相加作为attn_mask。因为是5个单词,所以padding mask是全1。全1矩阵再与三角矩阵做与操作,得到如下。

[begin{Bmatrix} 1 & 0 & 0 &0 &0 \ 1 & 1 & 0 & 0 &0 \ 1 & 1 & 1 &0 & 0\ 1 & 1 & 1 & 1 & 0 \ 1 & 1 & 1 & 1 & 1 end{Bmatrix} ]

实际运算中,对于第一个序列

[ begin{Bmatrix} x1 & x2 & x3 &<pad> &<pad> \ x1 & x2 & x3 &<pad> &<pad> \ x1 & x2 & x3 &<pad> &<pad>\ x1 & x2 & x3 &<pad> &<pad> \ x1 & x2 & x3 &<pad> &<pad> end{Bmatrix} ]

掩码之后得到

[ begin{Bmatrix} x1 & -1e9 & -1e9 &-1e9 &-1e9 \ x1 & x2 & -1e9 &-1e9 &-1e9 \ x1 & x2 & x3 &-1e9 &-1e9\ x1 & x2 & x3 &-1e9 &-1e9 \ x1 & x2 & x3 &-1e9 &-1e9 end{Bmatrix} ]

对于第二个序列

[begin{Bmatrix} x1 & x2 & x3 &x4 & x5 \ x1 & x2 & x3 &x4 & x5 \ x1 & x2 & x3 &x4 & x5 \ x1 & x2 & x3 &x4 & x5 \ x1 & x2 & x3 &x4 & x5end{Bmatrix} ]

掩码之后得到

[begin{Bmatrix} x1 & -1e9 & -1e9 &-1e9 & -1e9 \ x1 & x2 & -1e9 &-1e9 & -1e9 \ x1 & x2 & x3 &-1e9 & -1e9 \ x1 & x2 & x3 &x4 & -1e9 \ x1 & x2 & x3 &x4 & x5end{Bmatrix} ]

4.3 使用

从掩码角度出发,训练和推理的最大不同之处在于每个时间步的输入区别。训练过程中每个时间步的输入是全部目标序列。推理过程中每个时间步的输入,是直到当前时间步所产生的整个输出序列。

为了在训练时候模拟实际推理的效果,需要借助掩码把后面单词的信息隐藏掉,以是确保解码器只能关注到它之前已经生成的词,而不能看到未来的词。此逻辑是为了训练特殊打造,因为训练使用Teacher Forcing模式,需要让前面的token不能观察到后面token的信息。虽然推理时候所有输入都是已知输入,可以互相看到,不需要掩码,但是为了保持一致,也保留了此处代码和模型结构。

训练

接下来我们来追溯一下训练时候的 mask 是怎么来的。我们最终构建的模块是 EncoderDecoder 类的实例,编码器的参数是src_mask,解码器的参数是src_mask和tgt_mask。

class EncoderDecoder(nn.Module):     def forward(self, src, tgt, src_mask, tgt_mask):         "Take in and process masked src and target sequences."         return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)      def encode(self, src, src_mask):         return self.encoder(self.src_embed(src), src_mask)      def decode(self, memory, src_mask, tgt, tgt_mask):         return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) 

我们接着深入到解码器中看看其参数。在DecoderLayer类的forward()函数可以看到:

  • 自注意力机制使用的是tgt_mask,作用是对目标语言做 padding mask。
  • 交叉注意力机制使用src_mask,作用是对目标语言做sequence mask。

在多层 Transformer 的解码过程中,每个 Decoder 在交叉注意力中所使用的 Memory 都是同一个。

class DecoderLayer(nn.Module):     "Decoder is made of self-attn, src-attn, and feed forward (defined below)"     def forward(self, x, memory, src_mask, tgt_mask):         "Follow Figure 1 (right) for connections."         m = memory         # 目标语言的自注意力, 这里 mask的作用就是用到上面所说的 softmax 之前的部分         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))         # m 是encoder的输出,x是decoder第一部分的输出,因为上面一部分的输出中, 未被预测单词的 query 其实是 0(padding), 在这里可以直接使用 src_mask         x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))         # 最后是两个线形层,          return self.sublayer[2](x, self.feed_forward) 

最终进入注意力函数attention()中,这里不再赘述。

推理

对于推理,只有 prefill 阶段需要 mask,用了 kv cache 优化的 decode 阶段不需要 mask 操作。在prefill时, 只有源语言输入的 Batch,因此在 class Batch 中, trg 为 None。从下面代码可以看出来,预测过程的 Attention Mask 设置了padding mask。

def example_simple_model():     V = 11     criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)     model = make_model(V, V, N=2)      optimizer = torch.optim.Adam(         model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9     )     lr_scheduler = LambdaLR(         optimizer=optimizer,         lr_lambda=lambda step: rate(             step, model_size=model.src_embed[0].d_model, factor=1.0, warmup=400         ),     )      batch_size = 80     for epoch in range(20):         model.train()         run_epoch(             data_gen(V, batch_size, 20),             model,             SimpleLossCompute(model.generator, criterion),             optimizer,             lr_scheduler,             mode="train",         )         model.eval()         run_epoch(             data_gen(V, batch_size, 5),             model,             SimpleLossCompute(model.generator, criterion),             DummyOptimizer(),             DummyScheduler(),             mode="eval",         )[0]      # 在这里进行配置         model.eval()     src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])     max_len = src.shape[1]     src_mask = torch.ones(1, 1, max_len) # padding mask          # 这里调用到     print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0)) 

我们直接来看预测过程中的 decoder 的实现。

def greedy_decode(model, src, src_mask, max_len, start_symbol):     memory = model.encode(src, src_mask)     # memory 是 encoder 的中间结果     batch_size = src.shape[0]     ys = torch.ones(batch_size, 1).fill_(start_symbol).type_as(src)     # 预测句子的初始化     for i in range(max_len-1):         out = model.decode(memory, src_mask, ys, transformer.subsequent_mask(ys.size(1)).type_as(src))         # ys 的维度是 batch_size * times, 所以target_mask 矩阵必须是 times * times         # 根据 decoder 的训练步骤, 这里的 out 输出就应该是 batch_size * (times+1) 的矩阵         prob = model.generator(out[:, -1])         # out[:, -1] 这里是最新的一个单词的 embedding 向量         # generator 就是产生最后的 vocabulary 的概率, 是一个全连接层         _, next_word = torch.max(prob, dim = 1)         # 返回每一行的最大值, 并且会返回索引         next_word = next_word.unsqueeze(1)         ys = torch.cat([ys, next_word.type_as(src)], dim=1)         # 将句子拼接起来     return ys 

上面代码的 transformer.subsequent_mask(ys.size(1)).type_as(src) 这一部分就很好的解释了 target_mask 矩阵的构造方法。

我们再看看Decoder的forward函数,发现还是进入到了attention()。但这次输入的x是tgt。

class Decoder(nn.Module):     "Generic N layer decoder with masking."      def __init__(self, layer, N):         super(Decoder, self).__init__()         self.layers = clones(layer, N)         self.norm = LayerNorm(layer.size)      def forward(self, x, memory, src_mask, tgt_mask):         for layer in self.layers:             x = layer(x, memory, src_mask, tgt_mask)          return self.norm(x) 

4.4 PyTorch

如果我们去看 Pytorch Transformer 的文档,会发现有六种掩码矩阵。我们可以把六种掩码矩阵分成两类。

第一类叫做 attention mask,用来定义输入序列的哪些部分被允许关注。对应了哈佛代码中的sequence mask。

  • source mask:Encoder 中的自注意力掩码,形状为 (source_len, source_len)
  • target mask:Decoder 中因果自注意力掩码,形状为 (target_len, target_len)
  • memory mask:交叉自注意力中用到的掩码矩阵,形状为 (target_len, source_len)。此掩码用于解码器中的交叉注意力,主要是为了综合编码器和解码器中的padding。交叉注意力中的Q来自解码器,需要和编码器中的key-value sets求相关性矩阵,这里就不涉及窥探未来信息的问题了,只需要考虑padding。

第二类叫做 key_padding mask,分别在 source seq,target seq,memory seq(即 Encoder 的输出序列) 中标注 token 的位置,从而让这些不被关注。对应了哈佛代码中的padding mask。

  • src_key_padding_mask: 形状为 (batch_size, source_len)
  • tgt_key_padding_mask: 形状为 (batch_size, target_len)
  • memory_key_padding_mask: 形状为 (batch_size, source_len)

从下面这个例子中可以看到,attention maskkey_padding mask是“各司其职”的。

# 生成一个下三角矩阵,即为 target mask def generate_square_subsequent_mask(sz):     mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)     mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))     return mask  # 或者等价地: def generate_square_subsequent_mask(sz):     mask = torch.triu(torch.full((sz, sz), float('-inf'), , device=DEVICE)), diagonal=1)  def create_mask(src, tgt):     src_seq_len = src.shape[0]     tgt_seq_len = tgt.shape[0]      # attention mask 部分     tgt_mask = generate_square_subsequent_mask(tgt_seq_len)     src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool) 	     # key_padding mask 部分     src_padding_mask = (src == PAD_IDX).transpose(0, 1)     tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)     return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask 

其实,就是把哈佛代码中的掩码给细化了。我们总结其联系如下。

探秘Transformer系列之(11)--- 掩码

4.5 小结

下面流程图梳理了代码逻辑,可以看到,Encoder只会看src_mask,Decoder则会看src_mask和tgt_task。

探秘Transformer系列之(11)--- 掩码

我们再从模型架构角度给出交互数据流图,具体如下。

探秘Transformer系列之(11)--- 掩码

0x05 进阶

5.1 Sample Packing和mask

当上下文长度增加时,batch对齐问题就会浮出水面。长文本训练在批大小大于一的情况下可能会因为 Pad tokens 浪费非常多的空间,这是因为长文本往往在长度分布上可以跨越多个数量级。下面的图是一个例子。

探秘Transformer系列之(11)--- 掩码

比如一个 4K 的样本和一个 64K 的样本可能会出现在同一个 batch 中。这种情况下 4K 的样本后面会被使用 pad token 补全到 batch 中最长的样本的长度。这意味着可能一个 4K 的样本被填充了 60K 的长度。造成了很大的浪费。

定义

所幸现在的精调框架大多能够通过 Sample Packing 技术解决这个问题。Sample Packing 实际上去除了batch size的概念。一个包含 3 样本的 batch 现在被拼接成一个更长的单个序列。三个样本头尾相接成一个序列,同时attention mask也会对应得发生改变,来防止同一个序列中的不同样本相互影响。这样的好处就是再也没有 pad token:一个输入可能包含 2 个长的样本,也可能包含 100 个短样本。

探秘Transformer系列之(11)--- 掩码

不过实践中,LongAlign 论文提到,长的样本和极短的样本出现在同一个 batch 中可能会影响模型收敛,为了解决这个问题,一般会在训练时让长度相近的样本出现在同一个batch中。

Attention mask

以Megatron-LM(DeepSpeed-Megatron)为例,预训练通常包含很多不同的数据集,每个数据集又包含许多 Document。为了提升训练效率,在实际训练的时候一个 Sample(Sequence)里面可能会包含多个不同的 Document(Sample Packing)。比如 8K 的预训练 Sequence Length,则一个 Sample 可以包含 8 个 1K 的 Document。

对于单个 Document 而言,Decoder Only 的 GPT 模型具有 Causal 特性,也就是每个 Token 不能看到之后的 Token,因此在实际训练中需要添加 Attention Mask。这种情况下 Attention Mask 是一个标准的下三角矩阵(Causal Mask)。也就是绿色部分为 1,其他部分为 0:

探秘Transformer系列之(11)--- 掩码

如果一个 Sample 里包含多个样本,则 Attention Mask 矩阵需要变成如下图所示的块对角矩阵形式(Block Diagonal Mask)。比如 Sequence Length 为 16,4 个 Document 的长度分别为 3,4,5,4,则对应 Attention Mask 矩阵如下图所示,对角线上的 4 个矩阵(红框)都是标准的下三角矩阵。按照这种方式可以保证和 4 个 Document 单独作为 Sample 训练是等价的:

探秘Transformer系列之(11)--- 掩码

论文“LongAlign: A Recipe for Long Context Alignment of Large Language Models”中讨论了部分 Sample Packing 相关问题。如下图左所示,Sequence 的长度各不相同,从 0 - 60K,如果采用 Naive Batching 方式,会导致明显的 Bubble 问题。为了解决效率和效果问题,作者提出了 3 种解决方案:Packing、Loss Weighting 和 Sorted Batching。

下图右侧就是我们之前介绍的 Sample Packing:将不同的 Sample 拼接在一个 Sequence 里,并且保证尽可能接近 Max Sequence Length,末尾的部分 Token 进行 Padding。然后通过 Block Diagonal Attention Mask 来区别不同的 Sample,以避免 Sample 之间的交叉污染,也就是 Document Level Attention。

探秘Transformer系列之(11)--- 掩码

策略

在论文“Enhancing Training Efficiency Using Packing with Flash Attention”中,作者总结了不同 Packing 策略、Mask 方式及与 FlashAttention 结合的优势。

如下图所示,作者分析了不同的 Packing 方案以及它们的影响,具体包含如下几种方式:

  • RandomSampling + Padding:最传统的随机采样,然后 Padding 的方式。存在冗余计算,并且占比很高。
  • GroupByLength+Padding:先排序,然后尽量保证每个 Batch 中的序列长度接近。可以减少 Padding 的占比。
  • RandomSampling + PosID:随机采样,但是不 Padding,而是通过 PosID 支持变长序列。几乎没有冗余计算,但可能存在明显的负载不均衡(计算量)。
  • FixedLengthPacking:随机采样,随机 Packing,并且最后一个 Sample 可能截断,保证填满 Max Sequence Length。没有区分不同 Sample,也就是 Causal Mask,没有冗余计算,并且负载很均衡。
  • FixedLengthPacking + PosID:相比 FixedLengthPacking 多了 PosID,也就是可以区分不同 Sample,对应 Block Diagonal Mask。但依然会存在末尾截断,并且可能负载不均衡。
  • MultiPack + PosID:使 Sequence 中的数据尽量接近 Batch 的 Max Sequence Length,降低 Sequence 中的长度不均衡,可以参考 GitHub - imoneoi/multipack_sampler: Multipack distributed sampler for fast padding-free training of LLMs。需要对数据进行排序。
  • SortedPacking + PosID:通过排序,使同一个 Batch 中的计算复杂度尽量接近。可以尽可能降低计算负载不均衡问题。
  • RandomPacking + PosID:与 FixedLengthPacking + PosID 相比主要的区别就是最后一个 Sample 不截断,可能存在部分 Bubble。

探秘Transformer系列之(11)--- 掩码

5.2 功用

有研究表明,纯自注意力机制在深度增加时会经历秩崩溃,限制了模型的表达能力和进一步利用模型深度的能力。然而,现有的关于秩崩溃的文献大多忽略了Transformer中的其他关键组件,这些组件可能缓解秩崩溃问题。论文“On the Role of Attention Masks and LayerNorm in Transformers”对自注意力机制下的秩崩溃进行了综合分析,考虑了注意力掩码和层归一化(LayerNorm)的影响。具体来说,作者发现尽管纯掩码注意力仍然会指数级地崩溃到一个秩为1的子空间,但稀疏或局部掩码注意力可以证明减缓崩溃速率。在LayerNorm的情况下,作者首先展示了对于某些类别的值矩阵,秩为1的子空间崩溃仍然会指数级发生。然而,通过构建非平凡的反例,作者证明了在适当选择值矩阵的情况下,一类通用的序列可能不会收敛到秩为1的子空间,并且带有LayerNorm的自注意力动态可以同时拥有从1到满秩的任意秩的平衡点。作者的结果反驳了之前关于LayerNorm在自注意力秩崩溃中不起作用的假设,并表明带有LayerNorm的自注意力构成了一个比最初认为的更具表达力和多功能的非线性动力系统。

创新点

注意力掩码对秩崩溃的影响分析:论文首次系统性地分析了注意力掩码对Transformer中秩崩溃现象的影响。通过引入图论方法,论文证明了在准强连通图的情况下,即使使用稀疏或局部注意力掩码,令牌的秩崩溃仍然会发生,但速率会减缓。这一发现为设计更高效的注意力机制提供了理论基础。

LayerNorm对秩崩溃的缓解作用:作者通过构建非平凡的反例,证明了LayerNorm在某些情况下可以有效缓解令牌的秩崩溃问题。在适当选择值矩阵的情况下,带有LayerNorm的自注意力动态可以同时拥有从1到满秩的任意秩的平衡点。

掩码注意力

作者首先分析不带LayerNorm的情况,并关注注意力掩码的影响。

探秘Transformer系列之(11)--- 掩码

上述结果表明,在纯自注意力下,只要序列中存在一个令牌,所有其他令牌都可以在固定层数内直接或间接参与,那么令牌的秩崩溃就会指数级发生。特别是,这个结论可以推广到更一般的注意力模式类别:注意力模式只需要是准强连通的,这意味着对于实践中使用的各种注意力掩码,包括GPT系列中使用的因果掩码,或许多高效Transformer模型中部署的稀疏注意力模式,都会存在秩崩溃现象。

作者讨论了以下几个有趣的含义。

  • 局部 vs. 全局注意力 指数速率((1-epsilon^r)^{1/r})在图半径r上是单调的。这意味着对于半径较大的图,秩崩溃应该较慢。这说明使用局部注意力模式不仅使注意力计算更高效,而且隐式地缓解了秩崩溃问题。
  • 聚焦 vs. 均匀注意力 此外,指数速率在(epsilon)上单调递减,这意味着(epsilon)越小,秩崩溃越慢。可以将(epsilon)解释为注意力在可达token之间的“聚焦”程度,因为(epsilon)在注意力均匀分布在可达token时达到最大值。除了应用注意力掩码和限制可达令牌的数量外,控制注意力聚焦程度的另一种方法是通过温度项(d_{QK})。较大的(d_{QK})值会使可达令牌之间的注意力分配更加均匀,从而使秩崩溃在各层之间更快发生。
  • 秩崩溃与通用逼近能力的权衡 最后,对于强连通图,上述结果还揭示了通用函数逼近能力与秩崩溃速率之间的权衡。有研究表明,带有强连通图掩码的Transformer是sequence-to-sequence函数通用逼近器,然而,对于掩码(g),它们需要至少(g)的直径那么多的层数才能实现完整的sequence-to-sequence的函数逼近属性。这意味着直径较小的掩码在函数逼近能力方面更高效,但它们更容易发生秩崩溃。

带LayerNorm的掩码注意力

我们接下来看看带LayerNorm的掩码注意力会有什么性质。

论文作者首先展示一个负面结果,表明对于某些类别的值矩阵,如果初始时所有token对的余弦相似度都是非负的,那么只要(g)是准强连通的,仍然会发生token以指数级的速度崩溃到一个共同向量,即秩崩溃。但如果(g)是因果图(causal graph),掩码将只有一个中心节点,上界会更宽松,这表明因果掩码在缓解秩崩溃速率方面相对于全掩码具有优势。

然后,作者展示了反例,对于一类通用的输入序列,当仅使用LayerNorm时,token会收敛到一个均衡状态,在该状态下不会发生秩崩溃。然后,作者展示了一个普适性的结果,表明在LayerNorm和适当选择值矩阵的情况下,自注意力动态可以拥有从1到满秩的任意秩的平衡点。

0xFF 参考

LLM 预训练语料、预处理和数据集索引、加载总结 AI闲谈

FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention by Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong

Sample Packing:长序列 LLM 训练的 Attention 问题及优化

https://blog.csdn.net/zhaohongfei_358/article/details/125858248

Transformer系列:图文详解Decoder解码器原理 xiaogp

LongAlign: A Recipe for Long Context Alignment of Large Language Models

NIPS 2024 | 注意力掩码和LayerNorm在Transformer中的作用 [CV技术指南]

On the Role of Attention Masks and LayerNorm in Transformers

transformer中三种mask的使用 初夏

【深度学习】Transformer中的mask机制超详细讲解 Articoder

Transformer Encoder/Decoder结构中的掩码Mask介绍? [AIGC小白入门记]

LongAlign

发表评论

评论已关闭。

相关文章