2022-CVPR-Swin Transformer:Hierarchical Vision Transformer using Shifted Windows

1. 摘要

这篇文章[1]主要提出了一种用于 CV 任务的 Swin Transformer,它是一种使用了移动窗口的层级式 ViT。其主要思想就是借鉴于 CNN,作者想让 Transformer 能像 CNN 一样,通过层级式的特征提出从而使得提取出的特征有多尺度的概念。

作者一开始提到,Transformer 的确具有强大的能力,但直接将其用于 CV 任务存在两个问题。一个是图像中物体尺寸的问题,即使是同一个物体,在不同图像中由于拍摄距离的远近而导致物体的尺寸不同;另一个问题是图像分辨率过大时,直接将每个像素作为一个单词序列长度过大(这个问题也是 ViT 主要解决的)。因此,作者提出了层级式的 Swin Transformer,从而能学习到图像的多尺度信息,刷爆了各种 CV 任务的榜单。

2. 引言

过去 CV 领域一直是 CNN 主导,自从 NLP 领域中强大的 Transformer 出现以后,作者就想将其也用到 CV 领域。ViT 已经成功将其用到了 CV 分类任务上,但对其它任务则没有过多的探究。因此,作者在这篇论文中主要想讲的就是,Transformer 确实可以作为骨干网络,广泛用于各类 CV 任务。

3. 方法

作者首先简要介绍了一下 Swin Transformer 和 ViT 的不同,如下图所示:

  • ViT 的工作是将图像切成 16×1616 \times 16 的 patch,然后将每个 patch 当作一个单词,将整张图像展开成 patch 序列,对应于自然语言中的一句话。因此 ViT 中的自注意力是在所有 patch 间进行计算的。
  • Swin Transformer 的工作相比于 ViT 的不同之处在于,Swin Transformer 使用了层级式的多尺度结构,将图像划分成不同大小的窗口,并只在窗口中使用自注意力机制,大大减小了计算量。Swin Transformer 中的多尺度是通过将上一层的 patch 合并来实现的,作者称为 patch merging(类似于 CNN 中的池化操作)。

然后,作者介绍了 Swin Transformer 中的移动窗口的概念,如下图所示:

上图中,灰色的框是 4×44 \times 4 的 patch,红色的框是一个局部的窗口(作者在本文中设置的是一个窗口默认有 7×77 \times 7 个 patch)。而 Shift Window 的操作就是如上图所示,将图像往右下角移动了两个 patch(可以将上图红色的框想像成一个网络盖在图像上面)。通过 Shift Window,不同窗口内的 patch 也可以和其它周围窗口中的 patch 作自注意力操作。经过一层层的处理,最终每个 patch 也能和整个图像中的所有的 patch 产生交互,即感受野可以达到整张图像(类比于卷积操作)。

4. 模型

Swin Transformer 的总体架构图如下:

假设输入的图像大小的 224×224×3224 \times 224 \times 3,默认划分的 patch 大小为 4×44 \times 4,则:

  • 第一次划分后得到的 patch 序列长度为 H4×W4=56×56\frac{H}{4} \times \frac{W}{4} = 56 \times 56,经过 Linear Embedding 后通道数由 33 变为 CC,经过第一个 Swin Transformer Block 后序列数据为 H4×W4×2C\frac{H}{4} \times \frac{W}{4} \times 2C

  • 在进入第二个 Swin Transformer Block 前,先会将序列数据再拼接回 22-D 224×224224 \times 224,然后经过一个 Patch Merging 操作,它类似 Pixel Shuffle 的逆过程:

    不过不同之处在于,Pixel Shuffle 是针对像素的操作,即图中的小方块都是一个一个的像素;而 Patch Merging 是针对图像 patch 的操作,即图中的小方块都是一个一个的 patch。因此,在经过 Patch Merging 操作后,得到 44H2×W2\frac{H}{2} \times \frac{W}{2}22-D 数据,然后在通道维度上拼接在一起,故通道数由 CC 增加为 4C4C, 展开成序列数据后得到 H8×W8\frac{H}{8} \times \frac{W}{8} 长度的序列(因为 patch 大小为 4×44 \times 4H2×W2\frac{H}{2} \times \frac{W}{2} 是像素尺寸),为了和 CNN 中的池化操作保持一致,作者在 Patch Merging 后还接了一个 1×11 \times 1 的卷积操作,将 4C4C 的通道数降成 2C2C 的通道数;因此经过第二个 Swin Transformer Block 后得到的序列数据为 H8×W8×2C\frac{H}{8} \times \frac{W}{8} \times 2C

  • 经过后续其它的 Swin Transformer Block 类似。

Swin Transformer 和 ViT 不同,在将序列数据送入 Transformer 时,并没有使用 [cls] Token。

如上图右边所示,每个 Swin Transformer Block 由两个 Transformer 构成。第一个 Transformer 就是使用的是标准的多头自注意力,此时自注意力是在窗口内进行的;因此第二个 Transformer 就使用了基于移动窗口的多头自注意力(详细介绍见下文),从而实现不同窗口之间的交互。

4.1 窗口自注意力

对于模型中的 Swin Transformer Block,它其实是将标准的 Transformer 中的自注意力层替换成了一个基于自注意力的移动窗口模块。Swin Transformer Block 不再是对所的 patch 做自注意力操作,而是只针对窗口内的 patch(默认 7×77 \times 7 个 patch),这样就极大地减小了计算量。以标准的自注意力层为例,其乘法次数复杂度计算如下公式所示:

其中 hwhw 表示序列长度,CC 表示特征维度,则整个自注意力流程大致是这样的:

  • hw×chw \times c 的序列数据先分别和 33c×cc \times c 的系数矩阵做乘积,得到 33hw×chw \times cQ,K,VQ, K, V3hwC23hwC^2);
  • 然后 QQKK 在特征维度上做点积,得到 hw×hwhw \times hw 的特征数据((hw)2C(hw)^2 C);
  • hw×hwhw \times hw 的特片数据和 VV 做矩阵乘积,得到 hw×chw \times c 的输出特征((hw)2C(hw)^2C);
  • hw×chw \times c 的输出特征再过一个线性投射层(即乘一个 c×cc \times c 的系数矩阵),得到最终的输出 hw×chw \times chwC2hwC^2)。

因此,整个自注意层的乘法复杂度为:

\begin{align*} \Omega(\mathrm{SA})=4 h w C^{2}+2(h w)^{2} C \tag{1} \end{align*}

而对于小窗口的自注意力,假设小窗口内 patch 个数为 M×MM \times M,总共 hM×wM\frac{h}{M} \times \frac{w}{M} 个小窗口,将其代入到式 (1)(1) 中即得到:

Ω(WSA)=hM×wM×(4M2C2+2(M2)2C)=4hwC2+2M2hwC(2)\Omega(\mathrm{W}\mathrm{SA})=\frac{h}{M} \times \frac{w}{M} \times (4 M^2 C^{2}+2(M^2)^{2} C) = 4 h w C^{2}+2 M^{2} h w C \tag{2}

可以看到,式 (2)(2) 相比于式 (1)(1) 计算量大幅度降低,特别当 hwhw 比较大时。这便是窗口内自注意力的工作。

4.2 移动窗口

由 Figure 2 可以看到,移动窗口确实可以使不同窗口的图像 patch 产生交互,但存在一个技术上的问题。Figure 2 中窗口移动后,原来的 44 个窗口一下子变成为 99 个窗口,且大小不一。一种解决方法是对非方形的窗口将其填充补成方形,但即使是这样,窗口数量还是 99 个,计算复杂度增加了两倍。于是作者提出了使用 cyclic shift,同时使用 Masked 的 MSA(多头自注意力)来实现移动窗口:

可以看到,cycli shift 就是将左上角的 A,B,CA, B, C 窗口移到右下角,从而还是将整个图像补成 44 个大小一致的窗口。但这样存在一个问题,就是移下去地 A,B,CA, B, C 和右下角的窗口没有邻接关系,它们各自是一个窗口,因此不能做自注意力操作。所以作者就设计了掩码自注意力,从而在计算自注意力时将不同的窗口之间的自注意力丢掉,最后做完自注意力后,再将 cyclic shift 还原回去。Mask 机制如下:

图中的 Mask 是直接加到每个窗口做完自注意力的输出上的,由于自注意力后要经过一个 Softmax 层,因此加上 100-100 后 Softmax 基本就近似等于 00 了。

5. 实验

作者在文中对分类、目标检测等任务进行了实验,都取得到非常好的效果。

5.1 图像分类

5.2 目标检测

5.3 语义分割

6. 消融

作者还做了消融实验,来验证移动窗口和相对位置编码对实验性能的提升:

和 ViT 不一样,作者发现 Swin Transformer 使用相对位置编码整体效果要更好一些。

7. 结论

这篇文章主要提出了一种新的 Vision Transformer,它是一个层级式的 Transformer,计算复杂度是和图像大小成线性增长的。在性能方面,Swin Transformer 在各大 CV 任务都要远远好于其他方法。此外,作者还提到,这篇文章的重要贡献,即基于 Shift Window 的自注意力,在那些密集型的预测任务上表现出很好的性能。

不过,在模型大一统方面,Swin Transformer 由于用到了很多 CV 领域的先验知识,因此不利于和 NLP 直接统一。而 ViT 由于直接用的原版 Transformer,反而有利于模型的大一统。所以作者也提出,自己未来的工作是尝试将 Shift Window 用到 NLP 任务中,如果真的能做到且有效,那 Swin Transformer 的确也能实现模型的大一统。

附录

  1. Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., ... & Guo, B. (2021). Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 10012-10022).