张量链式法则(下篇):揭秘Transpose、Summation等复杂算子反向传播,彻底掌握深度学习求导精髓!

本文首发于本人的微信公众号,链接:https://mp.weixin.qq.com/s/eEDo6WF0oJtRvLYTeUYxYg

摘要

本文紧接系列的上篇,介绍了 transpose,summation,broadcast_to 等更为复杂的深度学习算子的反向传播公式推导。

写在前面

本系列文章的上篇介绍了张量函数链式法则公式,并以几个简单的算子为例子展示了公式的使用方法。本文将继续以更复杂的算子为例子演示公式的使用方法,求解这些算子的反向传播公式也是我研究张量函数链式法则的目的:因为对于 transpose,broadcast_to 这类会根据传入的参数改变输出张量维度数量的算子,常规的矩阵链式法则公式已无法解决。

常见算子的反向传播推导(下半部分)

复习一下

张量函数链式法则的公式为:

[nabla_{lambda_1 lambda_2 cdots lambda_n} = sum_{substack{mu_1 in [1, e_1] \ mu_2 in [1, e_2] \ vdots \ mu_m in [1, e_m]}} g_{mu_1 mu_2 cdots mu_m} frac{partial}{partial x_{lambda_1 lambda_2 cdots lambda_n}} f_{mu_1 mu_2 cdots mu_m} ]

求解步骤为:我们首先需要确定各个张量的形状,然后把注意力集中到自变量里的某个元素,写出这个元素的导数表达式,然后再推广到整个导数张量。

接下来我们继续常见算子的反向传播公式推导。

Summation

这个算子是对输入张量沿着某些轴求和,这个算子有一个参数 axes,表示求和的规约轴,例如,对于一个四维张量 (X in mathbb{R}^{d_1 times d_2 times d_3 times d_4}),如果 axes=(2, 3)(F(X) in mathbb{R}^{d_1 times d_4}),是一个二维张量,且 (f_{ij} = sum_{k=1}^{d_2} sum_{l=1}^{d_3} X_{iklj})

由此可见,对于多个轴的 summation 操作其实可以拆分为多次的对于一个轴的 summation,所以我们仅讨论 axes 只有一个轴的公式,对于有多个轴的场景可以将其视为复合函数,通过反复使用该公式来进行扩展。

单轴 Summation 问题描述

所以我们要解决的问题就变成了:函数 (F) 会对张量 (X) 的第 (a) 个维度进行求和,求该函数的反向传播公式。

(注:本文统一以 1 为起始下标,实际编程时 axes 是以 0 为起始下标,这个差异需要注意)

首先确定各个张量的形状,如果自变量 (X in mathbb{R}^{d_1 times d_2 times cdots times d_n}) 是一个 (n) 维张量,那么 (F(X) in mathbb{R}^{d_1 times d_2 times cdots times d_{a-1} times d_{a+1} times cdots times d_n})(n-1) 维张量。

单轴 Summation 问题求解

接下来可以写出每个自变量的导数的表达式:

[nabla_{lambda_1 lambda_2 cdots lambda_n} = sum_{substack{mu_1 in [1, e_1] \ mu_2 in [1, e_2] \ vdots \ mu_{a-1} in [1, e_{a-1}] \ mu_{a+1} in [1, e_{a+1}] \ vdots \ mu_m in [1, e_m]}} g_{mu_1 mu_2 cdots mu_m} cdot frac{partial f_{mu_1 mu_2 cdots mu_m}}{partial x_{lambda_1 lambda_2 cdots lambda_n}} ]

[= sum_{substack{mu_1 in [1, e_1] \ mu_2 in [1, e_2] \ vdots \ mu_{a-1} in [1, e_{a-1}] \ mu_{a+1} in [1, e_{a+1}] \ vdots \ mu_m in [1, e_m]}} g_{mu_1 mu_2 cdots mu_n} cdot frac{partial}{partial x_{lambda_1 lambda_2 cdots lambda_n}} sum_{i=1}^{e_a} x_{mu_1 mu_2 cdots mu_{a-1} i mu_{a+1} cdots mu_n} ]

注意到,当且仅当 (mu_1 = lambda_1, mu_2 = lambda_2, ldots, mu_n = lambda_n) 时,这个表达式值不为 0,且满足上述条件时,只有当 (i = lambda_a) 时,求和表达式值为 1,(i) 为其他值时都为 0,所以这一项的最终结果是 (g_{lambda_1 lambda_2 cdots lambda_n})

所以最终的 (nabla = text{broadcast}(G, a)),即把张量 (G) 在第 (a) 个轴做 broadcast_tobroadcast_to 操作的定义见下文)。

当然,这里实际操作时首先要对 (G)reshape,把因为求和丢掉的轴 unsqueeze 回来,然后再通过 broadcast_to 操作广播到 (X) 的形状,具体可以参考下面的具体代码实现:

a = node.inputs[0] target_dim_num = len(a.shape) grad_new_shape = [] for i in range(target_dim_num):     if i in self.axes:         grad_new_shape.append(1)     else:         grad_new_shape.append(a.shape[i]) return broadcast_to(reshape(out_grad, grad_new_shape), a.shape) 

多轴 Summation 问题求解

接下来讨论 axes 有多个的情形,通过上面的讨论,容易想到:只需要把求和规约掉的多个轴通过 reshape 进行 unsqueeze,然后再进行 broadcast 就行了。

实际情况正是如此,以两个轴为例,这种情况可以认为是两个单轴 summation 操作的复合,在实际进行反向传播时,会先传播到第一个单轴 summation,此时会进行一次 broadcast_to,然后这个结果会作为 (G) 继续传播到第二个单轴 summation,此时又会进行一次 broadcast_to,最终结果等价于把这两次 broadcast_to 放到一起完成。

严格的数学推导这里就不展开了,留作习题自证不难。

所以对于 Summation,最终的导数结果为:

[nabla = text{broadcast_to}(G, X.text{shape}) ]

BroadcastTo

这个算子是对一个张量进行广播操作,也就是把张量的元素在若干个轴上进行“复制”的操作,形成一个更“充实”的张量。numpy,pytorch 等框架在处理形状不同的张量时会自动进行广播操作。例如,(A) 的形状是 ((6, 6, 5, 4))(B) 的形状是 ((6, 5, 4)),在执行 (A odot B) 时,框架会自动在 (B) 的左边补上维度 1,变成 ((1, 6, 5, 4)),然后再执行广播变成 ((6, 6, 5, 4)),然后再做哈达马积。

这里我们同样先讨论只针对一个轴进行 broadcast_to 的情形,多轴的情形同样可以视为多个单轴 broadcast_to 的嵌套。

(注:以下讨论涉及到的参数和实际编程中的参数有差异,实际编程中是直接传入 broadcast_to 之后的形状作为参数)

单轴 BroadcastTo 算子定义

单轴 broadcast_to 算子有两个参数:

  • 参数 a,表示在哪一个轴进行广播,该算子要求自变量在这一维度的大小为 1
  • 参数 b,表示要将这一维度广播到多大

这一算子的形式化的定义为:

  • (X in mathbb{R}^{d_1 times d_2 times cdots times d_n}),是 (n) 维张量,其中 (d_a = 1)(F(X) = text{broadcast_to}(X, a))
  • (F(X) in mathbb{R}^{d_1 times d_2 times cdots times d_{a-1} times b times d_{a+1} times cdots times d_n}),其中 (f_{lambda_1 lambda_2 cdots lambda_n} = x_{lambda_1 lambda_2 cdots lambda_{a-1} 1 lambda_{a+1} cdots lambda_n})

直观上来看就是把 (X) 在第 (a) 维的元素复制了 (b) 份。

单轴 BroadcastTo 问题求解

首先可以确认,(X)(nabla) 形状相同,为 (mathbb{R}^{d_1 times d_2 times cdots times d_{a-1} times 1 times d_{a+1} times cdots times d_n})(G)(F(X)) 的形状相同,为 (mathbb{R}^{d_1 times d_2 times cdots times d_{a-1} times b times d_{a+1} times cdots times d_n})

写出 (nabla) 的表达式可得:

[nabla_{lambda_1 lambda_2 cdots lambda_n} = sum_{substack{ mu_1 in [1, e_1] \ mu_2 in [1, e_2] \ vdots \ mu_a in [1, b] \ vdots \ mu_n in [1, e_n]}} g_{mu_1 mu_2 cdots mu_n} cdot frac{partial f_{mu_1 mu_2 cdots mu_n}}{partial x_{lambda_1 lambda_2 cdots lambda_n}} ]

(F) 的定义式代入,原式子可写作:

[sum_{substack{ mu_1 in [1, e_1] \ mu_2 in [1, e_2] \ vdots \ mu_a in [1, b] \ vdots \ mu_n in [1, e_n]}} g_{mu_1 mu_2 cdots mu_n} cdot frac{partial x_{mu_1 mu_2 cdots mu_{a-1} 1 mu_{a+1} cdots mu_n}}{partial x_{lambda_1 lambda_2 cdots lambda_n}} ]

注意到,只有当 (mu_1 = lambda_1, mu_2 = lambda_2, ldots, mu_{a-1} = lambda_{a-1}, mu_{a+1} = lambda_{a+1}, ldots, mu_n = lambda_n) 时,求和式不为 0,所以这个式子可以进一步化简为:

[sum_{mu_a in [1, b]} g_{lambda_1 lambda_2 cdots lambda_{a-1} mu_a lambda_{a+1} cdots lambda_n} ]

这个表达式的值恰好就等于张量 (G)(a) 轴做 Summation,所以有:

[nabla = text{Summation}(G, a) ]

多轴 BroadcastTo 问题求解

和 Summation 类似,多轴情形下只需要对所有广播过的轴做 Summation 即可,由此可得,多轴情形下:

[nabla = text{Summation}(G, (a_1, a_2, ldots, a_m)) ]

其中 (a_1, a_2, ldots, a_m) 是所有经过广播的轴的编号,具体可以参考以下代码实现:

old_shape = node.inputs[0].shape new_shape = self.shape sum_axes = [] for i in range(len(new_shape)):     if i >= len(old_shape) or (old_shape[i] == 1 and new_shape[i] != 1):         sum_axes.append(i)  return reshape(summation(out_grad, tuple(sum_axes)), old_shape) 

Reshape

顾名思义,这个算子的作用就是改变张量的形状。numpy 对于这个操作的描述是:在不改变数组内容的情况下为数组赋予新的形状。可以认为 numpy 存储的多维张量本质上是一个连续的一维数组,形状只是我们看这个数组的一个视角,以二维张量为例,假设这个一维数组是 ([1,2,3,4,5,6]),如果以 (2 times 3) 矩阵的视角去看,那就会是:

[begin{bmatrix} 1 & 2 & 3 \ 4 & 5 & 6 end{bmatrix} ]

如果以 (6 times 1) 的矩阵视角去看,那就会是:

[begin{bmatrix} 1 \ 2 \ 3 \ 4 \ 5 \ 6 end{bmatrix} ]

Reshape 问题求解

这里我们可以猜一下,以三维张量为例,(nabla, X in mathbb{R}^{e_1 times e_2 times e_3})(G, F(X) in mathbb{R}^{d_1 times d_2 times d_3}),其中 (e_1 times e_2 times e_3 = d_1 times d_2 times d_3)

注意到 (nabla)(G) 的元素数量相同,只是形状不同,那只需要进行一次 reshape 即可。

事实正是如此,对于 (F(X) = text{reshape}(X, text{new_shape})),其反向传播导数:

[nabla = text{reshape}(G, X.text{shape}) ]

这里具体的数学推导就不再赘述了,留作习题供读者练习。

(提示:可以考虑定义一个辅助函数,将原来轴的参数映射到新的轴上的参数)

Transpose

这一算子的定义是做转置,二维矩阵的转置很显然,就是行列互换。推广到 (n) 维张量,就是选择两个轴,然后在这两个轴上做互换。

(注:这里的 transpose 是 CMU Homework1 里面定义的,而非 numpy 里的定义,这里只会转置两个轴,但是这里推导得到的结果可以轻易推广到多轴的情形)

Transpose 形式化定义

  • 这一算子有 2 个参数 ab,表示需要转置的两个轴
  • (X in mathbb{R}^{d_1 times d_2 times cdots times d_n}),是 (n) 维张量
  • (F(X) in mathbb{R}^{d_1 times d_2 times cdots times d_{a-1} times d_b times d_{a+1} times cdots times d_{b-1} times d_a times d_{b+1} times cdots times d_n}) 也是 (n) 维张量,只是第 (a) 维和第 (b) 维的大小互换了
  • 且其中:
[f_{lambda_1 lambda_2 cdots lambda_{a-1} lambda_b lambda_{a+1} cdots lambda_{b-1} lambda_a lambda_{b+1} cdots lambda_n} = f_{lambda_1 lambda_2 cdots lambda_{a-1} lambda_a lambda_{a+1} cdots lambda_{b-1} lambda_b lambda_{b+1} cdots lambda_n} ]

Transpose 问题求解

这里也很容易才到,对 (G) 做同样的转置即可得到,这里同样不展开赘述了,留作习题供读者练习。

(提示:同样可以考虑定义映射轴的辅助函数来解决)

MatMul

这一算子是矩阵乘法,二维矩阵的公式已经在上一篇文章里给出,这里主要补充一下 batch 模式下的矩阵乘法。根据 numpy 里的定义,进行 MatMul 的两个张量 (X)(Y) 可以是两个高维的张量,例如,当 (X) 的形状为 ((6, 6, 5, 3))(Y) 的形状为 ((6, 6, 3, 4)) 时,会把 (X) 视为是 36 个 (5 times 3) 矩阵按照 (6 times 6) 的格式排列,然后把 (Y) 视为 36 个 (3 times 4) 的矩阵按照 (6 times 6) 排列,最后将两个大矩阵中对应位置的两个小矩阵做矩阵乘法,最终会得到 36 个 (5 times 4) 的小矩阵,组成一个形状为 ((6, 6, 5, 4)) 的张量。

这一操作同样支持广播,即:如果 (X) 形状为 ((6, 6, 5, 3))(Y) 的形状为 ((3,4)),那么最终结果会是形状为 ((6, 6, 5, 4)) 的张量,即 (X) 的 36 个小矩阵每一个都和 (Y) 做矩阵乘法。

这种情形下,如果记单个矩阵乘法的函数为 MatMul,批量矩阵乘法函数为 MatMul_Batch,那么此时 MatMul_Batch 实际上是 MatMul(X, broadcast_to(Y, X.shape)),所以在处理 MatMul_Batch(Y) 求导时,需要考虑到这里实际上是嵌套了一层广播的,而广播的反向传播是做 Summation,所以在套用单矩阵 MatMul 的反向传播公式之后还需要做一个 Summation 将形状变回和 (Y) 相同的形状,具体过程可以参考如下的代码实现:

(注:理论上是需要先做 Summation 再做 Matmul 的反向传播,但是先做 Summation 和后做是等价的,为了代码实现方便就统一放到后面来做了)

a, b = node.inputs a_grad, b_grad = matmul(out_grad, transpose(b)), matmul(transpose(a), out_grad)  if len(a_grad.shape) > len(a.shape):     sum_axes = tuple((i for i in range(len(a_grad.shape) - len(a.shape))))     a_grad = summation(a_grad, sum_axes)  if len(b_grad.shape) > len(b.shape):     sum_axes = tuple((i for i in range(len(b_grad.shape) - len(b.shape))))     b_grad = summation(b_grad, sum_axes)  return a_grad, b_grad 

一些剩下的简单算子

接下来放一些简单算子的反向传播公式,这里就只给出结果而省略推导过程了。

Negate

这个算子是把张量中所有元素取相反数,很显然:

[nabla = -G ]

Log

这个算子是对张量中所有元素取自然对数,很显然:

[nabla = frac{G}{X} ]

Exp

这个算子是对张量中所有元素过一次指数函数 (y = e^x),很显然:

[nabla = G odot exp(X) ]

EWisePow

这个算子接收 2 个相同形状的自变量 (X)(Y)(如果形状不同会进行广播到相同形状),对于 (X) 里的每一个 (x),取 (Y) 对应位置上的元素 (y),做 (x^y)

很显然:

[nabla^X = G odot Y odot text{EWisePow}(X, Y - 1) ]

[nabla^Y = G odot text{EWisePow}(X, Y) odot log(X) ]

发表评论

评论已关闭。

相关文章

当前内容话题