numpy.einsum 如何理解和运用?
date
May 2, 2022
Last edited time
Mar 27, 2023 09:01 AM
status
Published
slug
numpy.einsum如何理解和运用?
tags
Algorithm
summary
type
Post
origin
Field
Plat
简介
einsum(爱因斯坦求和)是 pytorch、numpy 中一个十分优雅的方法,如果利用得当,可完全代替所有其他的矩阵计算方法,不过这需要一定的学习成本。本文旨在详细解读 einsum 方法的原理,并给出一些基本示例。
问题引入
在线性代数中,我们最多涉及的是二阶及以下的张量.在这种情况下,纸面上可以很方便地写出低阶张量的矩阵形式,高阶的张量,它们的坐标就没法用矩阵表示.我们当然可以把矩阵拓展为立体阵等概念,但随着阶数上升,这种表示法的复杂程度几何级增加;我们也可以使用张量词条中所提过的向量矩阵的方法,比起立体阵要清楚一些,但套娃式的表达方式也对理解一个张量的性质造成了障碍.
爱因斯坦求和约定正是为了简洁地表达高阶张量的坐标运算而存在的.
一、矩阵乘法
假设 矩阵大小分别是 和 ,矩阵乘法的定义如下:
其中,
python
用循环实现:结果为:
二、爱因斯坦求和法
爱因斯坦求和是一种对求和公式简洁高效的记法,其原则是当变量下标重复出现时,即可省略繁琐的求和符号。
比如求和公式:
其中变量 和变量 的下标重复出现,则可将其表示为:
由此我们可以将上述矩阵运算化简为:
进一步地,我们可以得到矩阵乘法的一个抽象
einsum 的原理
一、具体原理
einsum 方法正是利用了爱因斯坦求和简洁高效的表示方法,从而可以驾驭任何复杂的矩阵计算操作。基本的框架如下:
上述操作表示矩阵 与矩阵 的点积。输入的参数分为两部分:
- 前面表示计算操作的指令串,
- 后面是以逗号隔开的操作对象(数量需与前面对应)。
其中在计算操作表示中,
- “
->
” 左边是以逗号隔开的下标索引,重复出现的索引即是需要爱因斯坦求和的;
- “
->
” 右边是最后输出的结果形式。
以上式为例,其计算公式为:, 其等价于矩阵 与 的点积。
在矩阵之间的运算中,下标可以分为两类:
- 自由标 (Free index),也就是在输入和输出端都出现的下标
- 哑标 (Summation index),在输入端出现但输出端没有出现的下标
矩阵运算中所有参与运算的下标都被包含在次定义中。
以上述矩阵 的乘法过程为例:
可以看出,这与上述通过循环方式得出的结果一致。在
ij,jk -> ik
的例子中, i,j
是自由标,k
是哑标。二、计算准则
- 两个不同矩阵相乘,哑标维度需要逐元素相乘并求和,自由标保留
- 自由标可在输出中以任意顺序出现,但只能出现一次
这是两条基本准则,具体的计算场景可以参考下文实例。
典型计算场景
利用 einsum 求解张量运算主要分为单操作数和多操作数的情况,我们分别讨论,并力图转化为循环形式便于明晰求解过程。
1. 单操作数
1.1 矩阵的迹:
迹(trace)指的是方针的对角线元素。 einsum 表示为:
结果:
1.2 矩阵转置
矩阵的转置(transpose)指矩阵行列互换。 einsum 表示为:
结果:
1.3 矩阵求和
按行还是列求和,取决于最终保留的下标:
结果:
2. 多操作数
2.1 向量内 / 外积
结果:
2.2 矩阵乘法
矩阵乘法最典型的形式为:
它的循环形式可以展开为:
当 k 也作为自由标被保留下来的时候,情况稍有不同:
此时,上式对应的循环形式应该为:
此时,k 不在作为哑标被求和,在输出中也会保留该维度,并且按照 i,j,k 的顺序排列输出维度。
多个矩阵的连乘可以按照同样的方式进行:
3. 广播乘法
广播方式比较复杂,这里仅举一个常见例子:
在 Transformer 的 self-attention 机制中,对与子矩阵 QKV 需要进行 Multi-Head 操作, 这里假设:
转化为多头后,维度变为:
,可以得到 矩阵的张量表示:
通过这种方法,可以轻松完成多头下的自注意力乘积操作。
实际上,上述操作与下面的过程也是等价的:
另外,广播乘法有一个更简洁的形式:
‘…’ 指代任意多个维度,这在处理 batch 和图像中的多通道时尤为有效。