swin-Transformer论文详解

作者: pdnbplus | 发布时间: 2024/06/18 | 阅读量: 359

swin-Transformer论文详解 -- 潘登同学的深度学习笔记

前言

swin-Transformer作为CVPR 21的最佳论文,在几乎所有下游任务都表现地很出色;swin-Transformer的全称是Hierarchical Vision Transformer using Shifted Windows; 核心是怎么采用Hierarchical的方式来应用Transformer,从而是的CV与NLP更好的融合;

在VIT那篇论文中,研究团队只是做了VIT用于图像分类的任务,并没有把所有任务都刷过,所以大家还是对Transformer用于CV领域有所担心,而swin-Transformer就是解决这种担心的;

将transformer用于CV领域还是有两大方面的难题的

  • 语义尺度问题(对于空间上靠前的物体与靠后的物体的尺度是不同的,但是可能表示同样的语义信息)
  • 图片大小问题(因为VIT做了patch,得到的token长度都是定死的)

目标检测与语义分割的经典架构

  • 目标检测要框不同尺寸的物体,常用架构就是FPN
  • 语义分割不仅要找到物体,还要画出来,常用架构就是U-Net

VIT的问题

  • 在做patch的时候,始终是在整张图上做patch,学的始终是整张图的信息
  • 一个patch中的序列长度,是随着整张图的长度变化而平方级别地变大
  • 没有CNN的先验知识

为什么把NLP的模型用于CV领域会表现的很好

  • 因为Transformer具有很强地连接上下文的能力,在一张图片里面,所有物体不是孤立的而是有关系的

在这里插入图片描述

网络架构

在这里插入图片描述

以ImageNet的图片为例,输入图片大小为224x224x3;

  1. Patch projection层: 每个patch的大小是4x4,那么就有56x56个patch,将patch展平得到4x4x3=48长度的token,所以输入就是[56x56]x48;
  2. Linear Embedding层: 接一层Embedding,得到[56x56]x96;
  3. Swin transformer Block:因为本质是transformer,所以输入输出维度一致,先不管;
  4. Patch Merging:是一种下采样的方式,这里是下采样两倍,就是把原特征图以2x2为一块框住,标上序号,将所有序号为1的挑出来作为一个张量,以此类推,得到$\frac{W}{2}\frac{H}{2}C$的四个张量,将其拼接得到$\frac{W}{2}\frac{H}{2}4C$; 但是为了与CNN的思想一致,就是一次改变通道数只是想让其翻倍,后面接了1x1的卷积,得到[28x28]x192;
  5. 以恰当的方式重复4次以上过程
  6. 接global average pooling将最后[7x7]x768的输出变为[1x1]x768;
    1. 如果是做分类的话,接FC得到[1x1]x1000;

Swin transformer Block

Swin transformer Block是Swin transformer的主要贡献,核心是只对每个窗口做自注意力,但是一个self-attention只能做一个序列,这里一个窗口就是一个序列,要做一个图的所有窗口,就要用多头注意力(这里的多头与之前的不一样,之前主要是为了并行计算的,现在真就是多头)

  • 最小的计算单元是一个patch: 长度为96的一个向量(作为一个token);
  • 中型计算单元为一个窗口: 总共有7x7个patch,也就对应序列长度为49;
  • (以第一层举例)总共有8x8(56/7)个窗口,那么就有64个多头;

为了说明相较于全局的patch(VIT中),基于滑动窗口的自注意力能减少计算量,作者粗略计算了一下:

  • 一个图片有hxw个patch,每个patch的通道数是C
  • 对于VIT来说,一个token进入QKV就是$3C^2$的计算复杂度,Q与所有K之间内积(向量点乘)是$hwC$, 内积结果与所有V相乘就是$hwC$,最后所有一个QK对应的所有V都做一个FC层就是$C^2$
  • 那么总共有$h*w$个token,所以总的计算复杂度就是 $$ \Omega(MSA) = hw(4C^2 + 2hwC) $$
  • 对于Swin transformer来说,可以套用上面的结果,一个窗口有MxM个patch,总共有$\frac{hw}{M^2}$个窗口,所以计算复杂度就是 $$ \Omega(W-MSA) = \frac{hw}{M^2} M^2(4C^2 + 2M^2C) = hw(4C^2 + 2M^2C) $$

  • 对于VIT来说hw就是196,而现在降为了49;

为了做滑动窗口之间的交互,在进入下一个transformer中需要对滑动窗口做shift,论文中shift的尺寸是每个窗口的一半,这里做了向下取整,使得所有滑动窗口向右、向下shift3个patch

  • 所以我们在网络架构图中看到的block单元中的transfromer个数一定是偶数个,因为一层普通的W-MSA完事之后一定要shift接一层SW-MSA

如图所示,但这样做会导致窗口的数量增加,而且窗口的大小也不相同;

在这里插入图片描述

有这样一些解决思路

  1. 直接不管,一个窗口的patch数量不一致就等于序列长度不一,直接mask掉就行;
  2. 给那些窗口大小比正常49个patch少的窗口补一圈零;

但上述的解决方案都会导致窗口数量增加,导致计算复杂度增加;

巧妙的Mask

解决方案

  • 论文给出了割补法的解决方案如图所示,就是将左上角那些残缺不全的窗口补到右下角去;

在这里插入图片描述

但随之而来的会产生一个问题: 在做自注意力的时候,可能会导致不合适的自注意力,因为图片上方可能是天空,而下方是地面,一作自注意力后就会导致原本无关的东西变得相关,显然不合理;

  • 所以论文采用了几个mask模板来解决自注意力的问题;如下图所示

在这里插入图片描述

对于上图,其中的序号表示shift完之后的窗口;

  • 0是一个完整的窗口,所以做self-attention直接做就行
  • 3和6分别是两个窗口,其中窗口3的大小是$47$个patch,而6的窗口的大小则是$37$的;将3窗口的28个向量与6窗口的21个向量拼接得到矩阵,矩阵与自己的转置相乘(第一行乘每一列放到第一行上去(经典线代)),那么得到的新矩阵就如上图的window2所示,其中黄色的部分就是3与3、6与6矩阵相乘得到的数,紫色部分就是其3与6、6与3矩阵相乘的结果;
  • 将一个事先定义好的mask模板,黄色部分是0,紫色部分是一个负很大的数与自注意力得到的结果相加,做softmax,自然就将不合适的部分mask掉了;

对于window1,window3的操作也类似,只不过在拼接矩阵的时候,对窗口1和窗口2的拼接是交错的(因为patch的序号是从左到右从上到下),所以mask模板要做对应调整,window3就是融合了两者了;

实验

实验部分就跳过,nb就完事了,在21年在ADE20k数据集上全面领跑目标检测与语义分割等下游任务;