CV+Transformer 之 Swin Transformer

date
Apr 26, 2022
Last edited time
Mar 27, 2023 08:59 AM
status
Published
slug
CV+Transformer之Swin_Transformer
tags
DL
CV
summary
转载
type
Post
Field
Plat

前言

自从 Transformer[1] 在 NLP 任务上取得突破性的进展之后,业内一直尝试着把 Transformer 用于在 CV 领域。之前的若干尝试,例如 iGPT[2],ViT[3] 都是将 Transformer 用在了图像分类领域,目前这些方法都有两个非常严峻的问题
  1. 受限于图像的矩阵性质,一个能表达信息的图片往往至少需要几百个像素点,而建模这种几百个长序列的数据恰恰是 Transformer 的天生缺陷;
  1. 目前的基于 Transformer 框架更多的是用来进行图像分类,理论上来讲解决检测问题应该也比较容易,但是对实例分割这种密集预测的场景 Transformer 并不擅长解决。
本文提出的 Swin Transformer [4] 解决了这两个问题,并且在分类,检测,分割任务上都取得了 SOTA 的效果。Swin Transformer 的最大贡献是提出了一个可以广泛应用到所有计算机视觉领域的 backbone,并且大多数在 CNN 网络中常见的超参数在 Swin Transformer 中也是可以人工调整的,例如可以调整的网络块数,每一块的层数,输入图像的大小等等。该网络架构的设计非常巧妙,是一个非常精彩的将 Transformer 应用到图像领域的结构,值得每个 AI 领域的人前去学习。
 
在 Swin Transformer 之前的 ViT 和 iGPT,它们都使用了小尺寸的图像作为输入,这种直接 resize 的策略无疑会损失很多信息。与它们不同的是,Swin Transformer 的输入是图像的原始尺寸,例如 ImageNet 的 224*224。另外 Swin Transformer 使用的是 CNN 中最常用的层次的网络结构,在 CNN 中一个特别重要的一点是随着网络层次的加深,节点的感受野也在不断扩大,这个特征在 Swin Transformer 中也是满足的。Swin Transformer 的这种层次结构,也赋予了它可以像 FPN[6],U-Net[7] 等结构实现可以进行分割或者检测的任务。Swin Transformer 和 ViT 的对比如图 1。
图1
图1
本文将结合它的 pytorch 源码对这篇论文的算法细节以及代码实现展开详细介绍,并对论文中解释模糊的地方具体分析。读完此文,你将完全了解清楚 Swin Transfomer 的结构细节以及设计动机,现在我们开始吧。
swin-transformer-pytorch
berniwalUpdated May 9, 2022

算法详解

1.1 网络框架

Swin Transformer 共提出了 4 个网络框架,它们从小到大依次是 Swin-T,Swin-S, Swin-B 和 Swin-L,为了绘图简单,本文以最简单的 Swin-T 作为示例来讲解,Swin-T 的结构如图 2 所示。Swin Transformer 最核心的部分便是 4 个 Stage 中的 Swin Transformer Block,它的具体在如图 3 所示。
从源码中我们可以看出 Swin Transformer 的网络结构非常简单,由 4 个 stage 和一个输出头组成,非常容易扩展。Swin Transformer 的 4 个 Stage 的网络框架的是一样的,每个 Stage 仅有几个基本的超参来调整,包括隐层节点个数,网络层数,多头自注意的头数,降采样的尺度等,这些超参的在源码的具体值如下面片段,本文也会以这组参数对网络结构进行详细讲解。
图2
图2

1.2 Patch Partition/Patch Merging

在图 2 中,输入图像之后是一个 Patch Partition,再之后是一个 Linear Embedding 层,这两个加在一起其实就是一个 Patch Merging 层(至少上面的源码中是这么实现的)。这一部分的源码如下:
Patch Merging模块主要在每个Stage一开始降低图片分辨率。类似于 CNN 中 Pooling 层。Patch Merging 是主要是通过 nn.Unfold函数实现降采样的,nn.Unfold的功能是对图像进行滑窗,相当于卷积操作的第一步,因此它的参数包括窗口的大小和滑窗的步长。根据源码中给出的超参我们知道第一步降采样的比例是 (后面步骤的降采样比例为 ),因此经过 nn.Unfold 之后会得到 个长度为 的特征向量,其中 是输入到这个 stage 的 Feature Map 的通道数,即 RGB 图像的通道数。PatchMerging 可以表示为下式。
接着的 view permute是将得到的向量序列还原到 的二维矩阵,linear是将长度是 的特征向量映射到 out_channels=96的长度,因此 stage-1 的 Patch Merging 的输出向量维度是 ,对比源码的注释,这里省略了第一个 batch的维度。即将原始图像 降采样到维度
可以看出 Patch Partition/Patch Merging 起到的作用像是 CNN 中通过带有步长的滑窗来降低分辨率,再通过 卷积来调整通道数。不同的是在 CNN 中最常使用的降采样的最大池化或者平均池化往往会丢弃一些信息,例如最大池化会丢弃一个窗口内的响应值,而 Patch Merging 的策略并不会丢弃其它响应,但它的缺点是带来运算量的增加。在一些需要提升模型容量的场景中,我们其实可以考虑使用 Patch Merging 来替代 CNN 中的池化。

1.3 Swin Transformer 的 Stage

如我们上面分析的,图 2 中的 Patch Partition+Linaer Embedding 就是一个 Patch Merging,因此 Swin Transformer 的一个 stage 便可以看做由 Patch MergingSwin Transformer Block 组成,源码如下。

1.4 Swin Transformer Block

Swin Transformer Block 是该算法的核心点,它由窗口多头自注意层(window multi-head self-attention, W-MSA)和移位窗口多头自注意层(shifted-window multi-head self-attention, SW-MSA)组成,如右图所示。由于这个原因,Swin Transformer 的层数要为 2 的整数倍,一层提供给 W-MSA,一层提供给 SW-MSA
从右图中我们可以看出输入到该 stage 的特征 先经过 LN 进行归一化,再经过 W-MSA 进行特征的学习,接着的是一个残差操作得到 。接着是一个 LN,一个 MLP 以及一个残差,得到这一层的输出特征 SW-MSA 层的结构和 W-MSA 层类似,不同的是计算特征部分分别使用了 SW-MSA 和 W-MSA,可以从上面的源码中看出它们除了 shifted 的这个 bool 值不同之外,其它的值是保持完全一致的。这一部分可以表示为下式。
notion image
一个 Swin Block 的源码如下所示。

1.5 W-MSA

窗口多头自注意力(Window Multi-head Self AttentionW-MSA),顾名思义,就是个在窗口的尺寸上进行 Self-Attention 计算,与 SW-MSA 不同的是,它不会进行窗口移位,它们的源码如下。我们这里先忽略shiftedTrue的情况,这一部分会放在 1.6 节去讲。
forward函数中首先计算的是 Transformer 中介绍的 三个特征。所以 to_qkv()函数就是一个线性变换,这里使用了一个实现小技巧,即只使用了个隐层节点数为 inner_dim*3的线性变换,然后再使用 chunk(3)操作将它们切开。因此 qkv是一个长度为 的 Tensor,每个 Tensor 的维度是 。其中 为特征图的长、宽, 包含了多个头输出维度。
之后的 map 函数是实现 W-MSA 中的 W 最核心的代码,它是通过 einopsrearrange实现的。einops 是一个可读性非常高的实现常见矩阵操作的 python 包,例如矩阵转置,矩阵复制,矩阵 reshape 等操作。最终通过这个操作得到了 3 个独立的窗口的权值矩阵,它们的维度是 ,这 4 个值的意思分别是:
  • : 多头自注意力的头的个数:
  • : 窗口的个数,首先通过 Patch Merging 将图像的尺寸降到 ,因为窗口的大小为 ,所以总共剩下 个窗口;
  • :窗口的像素的个数;
  • :隐层节点的个数。
Swin Transformer 将计算区域控制在了以窗口为单位的策略极大减轻了网络的计算量,将复杂度降低到了图像尺寸的线性比例。传统的 MSAW-MSA 的复杂度分别是:
上式的计算忽略了 softmax 的占用的计算量,这里以 为例,它的具体构成如下:
  1. 代码中的to_qkv()函数,即用于生成 三个特征向量:其中 的维度是 的维度是 ,那么这三项的复杂度是
  1. 计算 的维度均是 ,因此它的复杂度是
  1. softmax 之后乘 得到 :因为 的维度是 ,所以它的复杂度是 ;
  1. 矩阵得到最终输出,对应代码中的to_out()函数:它的复杂度是
通过 Transformer 的计算公式 ,我们可以有更直观一点的理解,在 Transformer 一文中我们介绍过 Self-Attention 是通过点乘的方式得到 Query 矩阵和 Key 矩阵的相似度,即 。然后再通过这个相似度匹配 Value。因此这个相似度的计算时通过逐个元素进行点乘计算得到的。如果比较的范围是一个图像,那么计算的瓶颈就是整个图的逐像素比较,因此复杂度是 。而 W-MSA 是在窗口内的逐像素比较,因此复杂度是 ,其中 W-MSA 的窗口的大小。
回到代码,接着的 dots变量便是我们刚刚介绍的 操作。接着是加入相对位置编码,我们放到最后介绍。接着的 attn 以及 einsum便是完成了式上式的整个流程。然后再次使用 rearrange将维度再调整回 。最后通过 to_out 将维度调整为超参设置的输出维度的值。这里我们介绍一下 W-MSA 的相对位置编码,首先这个位置编码是加在乘以完归一化尺度之后的 dots变量上的,因此 的计算方式变为右式。因为 W-MSA 是以窗口为单位进行特征匹配的,因此相对位置编码的范围也应该是以窗口为单位,它的具体实现见下面代码。相对位置编码的具体思想参考 UniLMv2[8]。
单独的使用 W-MSA 得到的网络的建模能力是非常差的,因为它将每个窗口当做一个独立区域计算而忽略了窗口之间交互的必要性,基于这个动机,Swin Transformer 提出了 SW-MSA

相对位置编码

notion image
相对位置编码插入的位置与绝对位置编码不同。首先创建一个可以学习的矩阵变量 ,通过 relative_indices 进行索引,得到偏置 。假设需要添加位置编码的窗口大小为 ,那么需要生成注意力的像素个数有最后两维为 。我们通过规则生成的 relative_indices 维度为 ,其中 代表从二维矩阵 索引的两个坐标。由规则生成的索引矩阵由于有正负值,因此我们需要将其加上矩阵的最小值 ,使矩阵的所有元素变为正值,因此矩阵的最大值为 ,也就是说,需要索引的矩阵 的最大 index 为 ,因此我们只需要一个 大小的矩阵 即可。

1.6 SW-MSA

notion image
图4
图4
SW-MSA 的的位置是接在 W-MSA 层之后的,因此只要我们提供一种和 W-MSA 不同的窗口切分方式便可以实现跨窗口的通信。SW-MSA 的实现方式如图 4 所示。我们上面说过,输入到 Stage-1 的图像尺寸是 的(图 4.(a)),那么 W-MSA 的窗口切分的结果如图 4.(b) 所示。那么我们如何得到和 W-MSA 不同的切分方式呢?SW-MSA 的思想很简单,将图像各循环上移和循环左移半个窗口的大小,那么图 4.(c) 的蓝色和红色区域将分别被移动到图像的下侧和右侧,如图 4.(d)。那么在移位的基础上再按照 W-MSA 切分窗口,就会得到和 W-MSA 不同的窗口切分方式,如图 4.(d) 中红色和蓝色分别是 W-MSASW-MSA 的切分窗口的结果。这一部分可以通过 pytorch 的 roll函数实现,源码中是 CyclicShift函数。
其中displacement的值是窗口值除 2。
这种窗口切分方式引入了一个新的问题,即在移位图像的最后一行和最后一列各引入了一块移位过来的区域,如图 4.(d)。根据上面我们介绍的 用于逐像素计算相似度的自注意力机制,图像两侧的像素互相计算相似度是没有任何作用的,即只需要对比图 4.(d)中的一个窗口中相同颜色的区域,我们以图 4.(d)左下角的区域 (1) 为例来说明 SW-MSA 是怎么实现这个功能的。
区域 (1) 的计算如图所示。首先一个 大小的窗口通过线性运算得到 三个权值,如我们介绍的,它的维度是 。在这个 中,前 个是按照滑窗的方式遍历区域 (1) 中的前 个像素得到的,后 个则是遍历区域 (1) 的下半部分得到的,此时他们对应的位置关系依旧保持上黄下蓝的性质。接着便是计算,在图中相同颜色区域的相互计算后会依旧保持颜色,而黄色和蓝色区域计算后会变成绿色,而绿色的部分便是无意义的相似度。在论文中使用了 upper_lower_mask将其掩码掉,upper_lower_mask是由 和无穷大组成的二值矩阵,最后通过单位加之后得到最终的 dots变量。upper_lower_mask的计算方式如下。
notion image
区域 (2) 的计算方式和区域 (1) 类似,不同的是区域 (2) 是循环左移之后的结果,如图 6 所示。因为 (2) 是左右排列的,因此它得到的 是条纹状的,即先逐行遍历,在这 7 行中,都会先遍历到 4 个黄的,然后再遍历到 3 个红的。两个条纹状的矩阵相乘后,得到的相似度矩阵是网络状的,其中橙色表示无效区域,因此需要网格状的掩码 left_right_mask来进行覆盖。
notion image
left_right_mask的生成方式如下面代码。关于这两个掩码的值,你可以自己代入一些值来验证,你可以设置一下window_size的值,然后displacement的值设为window_size的一半即可。
这一部分操作中,窗口移位和 mask 的计算是在WindowAttention类中的第一个if shifted = True中实现的。掩码的相加是在第二个 if 中实现的,最后一个 if 则是将图像再复原回原来的位置。
截止到这,我们从头到尾对 Swin-T 的 stage-1 进行了完成的梳理,后面 3 个 stage 除了几个超参以及图像的尺寸和 stage-1 不同之外,其它的结构均保持一致,这里不再赘述。

1.7 输出层

最后我们介绍一下 Swin Transformer 的输出层,在 stage-4 完成计算后,特征的维度是 。Swin Transformer 先通过一个 Global Average Pooling 得到长度为 的特征向量,再通过一个 LN 和一个全连接得到最终的预测结果,如下式。

Swin Transformer 家族

Swin Transformer 共提出了 4 个不同尺寸的模型,它们的区别在于隐层节点的长度,每个 stage 的层数,多头自注意力机制的头的个数,具体值见下面代码。
因为 Swin Transformer 是一个多阶段的网络框架,而且每一个阶段的输出也是一组 Feature Map,因此可以非常方便的将其迁移到几乎所有 CV 任务中。作者的实验结果也表明 Swin Transformer 在检测和分割领域也达到了 state-of-the-art 的水平。

总结

Swin Transformer 是近年来为数不多的读起来让人兴奋的算法,它让人兴奋的点有三:
  1. 解决了长期困扰业界的 Transformer 应用到 CV 领域的速度慢的问题;
  1. Swin Transformer 的设计非常巧妙,具有创新又紧扣 CNN 的优点,充分考虑的 CNN 的位移不变性,尺寸不变性,感受野与层次的关系,分阶段降低分辨率增加通道数等特点,没了这些特点 Swin Transformer 是没有勇气称自己一个 backbone 的;
  1. 其在诸多 CV 领域的 STOA 的表现。
当然我们对 Swin Transformer 还是要站在一个客观的角度来评价的,虽然论文中说 Swin Transformer 是一个 backbone,但是这个评价还为时尚早,因为
  1. Swin Transformer 并没有提供一个像反卷积那样的上采样的算法,因此对于这类需求的 backbone Swin Transformer 并不能直接替换,也许可以采用双线性差值来实现,但效果如何还需要评估。
  1. 从 W-MSA 一节中我们可以看出每个窗口都有一组独立的 ,因此 Swin Transformer 并不具有 CNN 一个特别重要的特性:权值共享。这也造成了 Swin Transformer 在速度上和还和同级别的 CNN 仍有不小的差距。所以就目前来看,在嵌入式平台上 CNN 还有着不可撼动的地位。
  1. Swin Transformer 在诸多的 CNN 已经取得非常好的效果的领域还未得到充分验证,如果只会掀起了一股使用 Swin Transformer 或其衍生算法在 CV 领域灌水风,那时候我们就可以说:Swin Transformer 的时代到来了。

Reference

[1] Vaswani, Ashish, et al. “Attention is all you need.” arXiv preprint arXiv:1706.03762 (2017).
[2] Dosovitskiy, Alexey, et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).
[3] Chen, Mark, et al. “Generative pretraining from pixels.” International Conference on Machine Learning. PMLR, 2020.
[4] Liu, Ze, et al. “Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.” arXiv preprint arXiv:2103.14030 (2021).
[5] Ba J L, Kiros J R, Hinton G E. Layer normalization[J]. arXiv preprint arXiv:1607.06450, 2016.
[6] T.-Y. Lin, P. Dollar, R. Girshick, K. He, B. Hariharan, and ´ S. Belongie. Feature pyramid networks for object detection. In CVPR, 2017. 2, 4, 5, 7
[7] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation[C]//International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015: 234-241.
[8] Bao, Hangbo, et al. “Unilmv2: Pseudo-masked language models for unified language model pre-training.” International Conference on Machine Learning. PMLR, 2020. >
 

© Lazurite 2021 - 2024