S2-MLP V1 & V2 of vision MLP: spatial-shift MLP

Time: 2022-04-19

Original document: https://www.yuque.com/lart/pa…


Two articles on S2-MLP are summarized here. The core idea of both is the same: replace the spatial MLP with a spatial-shift operation.

Understanding the papers through their abstracts

V1

Recently, visual Transformer (ViT) and its following works _abandon the convolution and exploit the self-attention operation_, attaining a comparable or even higher accuracy than CNNs. More recently, MLP-Mixer _abandons both the convolution and the self-attention operation_, proposing an architecture containing only MLP layers.
To achieve cross-patch communications, it devises an additional token-mixing MLP besides the channel-mixing MLP. It achieves promising results when training on an extremely large-scale dataset. _But it cannot achieve as outstanding performance as its CNN and ViT counterparts when training on medium-scale datasets such as ImageNet1K and ImageNet21K_. _The performance drop of MLP-Mixer motivates us to rethink the token-mixing MLP_.

This sentence points to the main content of the paper: an improved replacement for the spatial MLP.

We discover that the token-mixing MLP is a variant of the depthwise convolution with a global reception field and spatial-specific configuration. But _the global reception field and the spatial-specific property make token-mixing MLP prone to over-fitting_.

The problem of the spatial MLP is pointed out: its global receptive field and spatial-specific attributes make the model easy to overfit.

In this paper, we propose a novel pure MLP architecture, spatial-shift MLP (S2-MLP). Different from MLP-Mixer, our S2-MLP only contains channel-mixing MLP.

Only the channel MLP is mentioned here, which indicates that a new method is used to expand the receptive field of the channel MLP while keeping it a point-wise operation.

We utilize a _spatial-shift operation for communications between patches_. It has a local reception field and is spatial-agnostic. It is parameter-free and efficient for computation.

This leads to the core of the paper, the spatial-shift operation mentioned in the title. The operation has no parameters and is simply a way of rearranging the features.
For spatial-shift operations, see also: https://www.yuque.com/lart/architecture/conv#i8nnp

The proposed S2-MLP attains higher recognition accuracy than MLP-Mixer when training on ImageNet-1K dataset. Meanwhile, S2-MLP accomplishes as excellent performance as ViT on ImageNet-1K dataset with considerably _simpler architecture and fewer FLOPs and parameters_.

V2

Recently, MLP-based vision backbones emerge. MLP-based vision architectures with less inductive bias achieve competitive performance in image recognition compared with CNNs and vision Transformers. Among them, spatial-shift MLP (S2-MLP), adopting the straightforward spatial-shift operation, achieves better performance than the pioneering works including MLP-mixer and ResMLP. More recently, using smaller patches with a pyramid structure, Vision Permutator (ViP) and Global Filter Network (GFNet) achieve better performance than S2-MLP.

Here the pyramid structure appears; it seems V2 will adopt a similar structure.

In this paper, we improve the S2-MLP vision backbone. We expand the feature map along the channel dimension and split the expanded feature map into several parts. We conduct different spatial-shift operations on split parts.

The spatial-shift strategy is retained, but it is not yet clear from the abstract how it differs from the V1 version.

Meanwhile, we _exploit the split-attention operation to fuse these split parts_.

Split attention (from ResNeSt) is also introduced to fuse the groups. Does this mean parallel branches are used here?

Moreover, like the counterparts, we adopt _smaller-scale patches and use a pyramid structure for boosting the image recognition accuracy_.
We term the improved spatial-shift MLP vision backbone as S2-MLPv2. Using 55M parameters, our medium-scale model, S2-MLPv2-Medium achieves an 83.6% top-1 accuracy on the ImageNet-1K benchmark using 224×224 images without self-attention and external training data.

In my opinion, compared with V1, V2 mainly borrows some ideas from CycleFC and adapts them. The overall change has two aspects:

  1. The idea of multi-branch processing is introduced, and split attention is applied to fuse the branches.
  2. Inspired by existing work, smaller patches and a hierarchical pyramid structure are used.

Main content

Core structure comparison

In V1, the overall pipeline continues the idea of MLP-Mixer and still maintains a straight, non-pyramid structure.
(Figure: overall block structure of S2-MLP V1)

Structure diagram of MLP-Mixer:
(Figure: structure of MLP-Mixer)

As can be seen from the figure, unlike the pre-norm structure in MLP-Mixer, S2-MLP uses a post-norm structure.
In addition, the changes in S2-MLP mainly concern the spatial MLP: the Spatial-MLP (Linear -> GELU -> Linear) is converted into a Spatial-Shifted Channel-MLP (Linear -> GELU -> Spatial-Shift -> Linear).
The core pseudocode for the spatial shift is as follows:

(Figure: spatial-shift pseudocode from the paper)
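
The figure is not reproduced here; as a substitute, below is a minimal PyTorch sketch of the operation written from the description above (the (B, H, W, C) layout and the out-of-place clone are my own choices, not the paper's exact code):

import torch

def spatial_shift(x: torch.Tensor) -> torch.Tensor:
    # x: (B, H, W, C). The channels are split into four groups and each group
    # is shifted by one step along +W, -W, +H, -H respectively. The border
    # rows/columns keep their old values, so some values appear duplicated.
    b, h, w, c = x.size()
    out = x.clone()
    out[:, :, 1:, 0 * c // 4 : 1 * c // 4] = x[:, :, : w - 1, 0 * c // 4 : 1 * c // 4]  # shift right along W
    out[:, :, : w - 1, 1 * c // 4 : 2 * c // 4] = x[:, :, 1:, 1 * c // 4 : 2 * c // 4]  # shift left along W
    out[:, 1:, :, 2 * c // 4 : 3 * c // 4] = x[:, : h - 1, :, 2 * c // 4 : 3 * c // 4]  # shift down along H
    out[:, : h - 1, :, 3 * c // 4 :] = x[:, 1:, :, 3 * c // 4 :]                        # shift up along H
    return out

In the V1 block this operation sits between the two linear layers of the shifted channel MLP, i.e. Linear -> GELU -> spatial_shift -> Linear.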

As can be seen, the input is divided into four groups along the channel dimension, and each group is shifted one step along a different direction of the H and W axes; because of the way it is implemented, duplicated values appear at the boundary. The number of groups depends on the number of shift directions; by default 4 directions are used.
Although a single spatial-shift module only connects adjacent patches, the stacked structure as a whole realizes an approximate long-range interaction.


In V2, compared with V1, a multi-branch processing strategy is introduced, and the block uses the pre-norm form.

(Figure: overall block structure of S2-MLP V2)

The construction of the multi-branch structure is very similar to CycleFC: different branches use different processing strategies, and split attention is used to fuse them.

Split attention: Vision Permutator (Hou et al., 2021) adopts the split attention proposed in ResNeSt (Zhang et al., 2020) to enhance multiple feature maps produced by different operations. This paper borrows it to fuse the multiple branches.
The main steps are:

  1. Input $K$ feature maps (one per branch): $\mathbf{X} = \{X_k \in \mathbb{R}^{N \times C}\}_{k=1}^{K},\ N = HW$.
  2. Sum all feature maps over the token dimension and add them up: $a \in \mathbb{R}^{C} = \sum_{k=1}^{K}\sum_{n=1}^{N} X_{k}[n, :]$.
  3. Pass the result through stacked fully-connected layers to obtain channel-attention logits for the different feature maps: $\hat{a} \in \mathbb{R}^{KC} = \sigma(a W_1) W_2,\ W_1 \in \mathbb{R}^{C \times \bar{C}},\ W_2 \in \mathbb{R}^{\bar{C} \times KC}$.
  4. Reshape the attention vector: $\hat{a} \in \mathbb{R}^{KC} \rightarrow \hat{A} \in \mathbb{R}^{K \times C}$.
  5. Apply softmax along the index $k$ to obtain normalized attention weights for the different feature maps: $\bar{A}[:, c] \in \mathbb{R}^{K} = \text{softmax}(\hat{A}[:, c])$.
  6. The weighted sum of the $K$ input feature maps gives the output $Y$; one row can be written as: $Y[n, :] \in \mathbb{R}^{C} = \sum_{k=1}^{K} X_{k}[n, :] \odot \bar{A}[k, :]$.
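
For reference, here is a minimal PyTorch sketch of these steps (the class name, the GELU activation for $\sigma$, and the bottleneck width are assumptions for illustration, not the paper's exact implementation):

import torch
import torch.nn as nn

class SplitAttention(nn.Module):
    # Split attention over K branch outputs, following the steps listed above.
    # Input: x of shape (B, K, N, C) -- K feature maps with N = H*W tokens each.
    # Output: fused feature map of shape (B, N, C).

    def __init__(self, channels: int, k: int = 3, reduction: int = 4):
        super().__init__()
        hidden = max(channels // reduction, 8)  # assumed bottleneck width
        self.fc1 = nn.Linear(channels, hidden, bias=False)
        self.act = nn.GELU()                    # assumed choice for the activation sigma
        self.fc2 = nn.Linear(hidden, k * channels, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, k, n, c = x.shape
        a = x.sum(dim=2).sum(dim=1)              # (B, C): sum over tokens, then over branches
        a_hat = self.fc2(self.act(self.fc1(a)))  # (B, K*C): channel-attention logits
        a_hat = a_hat.reshape(b, k, c)           # (B, K, C)
        attn = a_hat.softmax(dim=1)              # normalize over the K branches
        out = (x * attn.unsqueeze(2)).sum(dim=1) # (B, N, C): weighted sum of the branches
        return out

A call like SplitAttention(channels=384, k=3)(torch.randn(2, 3, 196, 384)) returns a fused tensor of shape (2, 196, 384).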

Note, however, that the third branch here is an identity branch that directly takes a portion of the input channels, continuing the idea of GhostNet. This differs from CycleFC, which uses an independent channel MLP for that branch.

Core structure of GhostNet:
(Figure: core structure of GhostNet)

The core pseudocode of the multi-branch structure is as follows:
(Figure: multi-branch pseudocode from the paper)
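
The figure is likewise not reproduced; the following is a rough, hedged sketch of the three-branch unit as I read the description (it reuses the spatial_shift and SplitAttention sketches above; the second branch's shift pattern and the layer arrangement are my assumptions, not the paper's code):

import torch
import torch.nn as nn

def spatial_shift2(x: torch.Tensor) -> torch.Tensor:
    # Assumed second shift pattern: same idea as spatial_shift, but with the
    # roles of the H and W axes swapped for the four channel groups.
    return spatial_shift(x.transpose(1, 2)).transpose(1, 2)

class S2MLPv2Block(nn.Module):
    # Rough sketch of the three-branch spatial-shift unit in S2-MLPv2:
    # expand channels 3x, shift two parts differently, keep the third part as
    # an identity branch (GhostNet-style), fuse with split attention, project back.

    def __init__(self, channels: int):
        super().__init__()
        self.expand = nn.Linear(channels, 3 * channels)
        self.fuse = SplitAttention(channels, k=3)
        self.proj = nn.Linear(channels, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, h, w, c = x.shape                     # x: (B, H, W, C)
        x = self.expand(x)                       # (B, H, W, 3C)
        x1 = spatial_shift(x[..., :c])           # first shift pattern
        x2 = spatial_shift2(x[..., c : 2 * c])   # second shift pattern
        x3 = x[..., 2 * c :]                     # identity branch
        branches = torch.stack([x1, x2, x3], dim=1).reshape(b, 3, h * w, c)
        out = self.fuse(branches)                # (B, H*W, C)
        return self.proj(out.reshape(b, h, w, c))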

Other details

Relationship between spatial shift and depthwise convolution


In fact, the shifts in the four directions can be realized by constructing specific convolution kernels:

(Figure: depthwise convolution kernels corresponding to the four shift directions)

Therefore, the grouped spatial-shift operation can be implemented by assigning the corresponding kernels to the different groups of a depthwise convolution.

In fact, there are many ways to implement the shift. Besides the slice indexing and the depthwise convolution with hand-constructed kernels mentioned above, it can also be implemented with a grouped torch.roll or with deform_conv2d using custom offsets.

import torch
import torch.nn.functional as F
from torchvision.ops import deform_conv2d

xs = torch.meshgrid(torch.arange(5), torch.arange(5))
x = torch.stack(xs, dim=0)
x = x.unsqueeze(0).repeat(1, 4, 1, 1).float()

direct_shift = torch.clone(x)
direct_shift[:, 0:2, :, 1:] = torch.clone(direct_shift[:, 0:2, :, :4])
direct_shift[:, 2:4, :, :4] = torch.clone(direct_shift[:, 2:4, :, 1:])
direct_shift[:, 4:6, 1:, :] = torch.clone(direct_shift[:, 4:6, :4, :])
direct_shift[:, 6:8, :4, :] = torch.clone(direct_shift[:, 6:8, 1:, :])
print(direct_shift)

pad_x = F.pad(x, pad=[1, 1, 1, 1], mode="replicate")  # padding is needed so the boundary values are preserved

roll_shift = torch.cat(
    [
        torch.roll(pad_x[:, c * 2 : (c + 1) * 2, ...], shifts=(shift_h, shift_w), dims=(2, 3))
        for c, (shift_h, shift_w) in enumerate([(0, 1), (0, -1), (1, 0), (-1, 0)])
    ],
    dim=1,
)
roll_shift = roll_shift[..., 1:6, 1:6]
print(roll_shift)

k1 = torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k2 = torch.FloatTensor([[0, 0, 0], [0, 0, 1], [0, 0, 0]]).reshape(1, 1, 3, 3)
k3 = torch.FloatTensor([[0, 1, 0], [0, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k4 = torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 1, 0]]).reshape(1, 1, 3, 3)
weight = torch.cat([k1, k1, k2, k2, k3, k3, k4, k4], dim=0)  # each output channel corresponds to one input channel (depthwise)
conv_shift = F.conv2d(pad_x, weight=weight, groups=8)
print(conv_shift)

offset = torch.empty(1, 2 * 8 * 1 * 1, 1, 1)
for c, (rel_offset_h, rel_offset_w) in enumerate([(0, -1), (0, -1), (0, 1), (0, 1), (-1, 0), (-1, 0), (1, 0), (1, 0)]):
    offset[0, c * 2 + 0, 0, 0] = rel_offset_h
    offset[0, c * 2 + 1, 0, 0] = rel_offset_w
offset = offset.repeat(1, 1, 7, 7).float()
weight = torch.eye(8).reshape(8, 8, 1, 1).float()
deconv_shift = deform_conv2d(pad_x, offset=offset, weight=weight)
deconv_shift = deconv_shift[..., 1:6, 1:6]
print(deconv_shift)

"""
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
"""

Influence of offset direction

(Table: ablation on shift directions)

The experiment was run on a subset of ImageNet.

Ablation experiments on different shift directions were carried out in V1. In the model used here, channels are grouped according to the number of directions. As can be seen from the results:

  • Offsets do provide performance gains.
  • A and B: there is little difference between the four directions and the eight directions.
  • E and F: horizontal offset is better.
  • C and E / F: the offset of two axes is better than that of a single axis.

Influence of input size and patch size

(Tables: ablation on input size and patch size)

The experiment was run on a subset of ImageNet.

In V1, with the patch size fixed, performance also varies with the input size W×H. An overly large patch size does not perform well, because more detailed information is lost, but it does effectively improve inference speed.

Effectiveness of pyramid structure

(Table: small patches with pyramid structure vs. larger patches without pyramid)

In V2, two variants are constructed: one uses smaller patches with a pyramid structure, the other uses larger patches without a pyramid. The former obtains better performance, benefiting both from the detail preserved by the small patch size and from the better computational efficiency brought by the pyramid structure.

Effect of split attention

(Table: split attention vs. direct averaging of branches)

V2 compares split attention against directly averaging the branch features, and the former is better. However, the parameter counts differ; a more reasonable comparison would at least add a few parameterized layers to fuse the features of the three branches.

Effectiveness of the three-branch structure

(Table: ablation removing individual branches)

This section evaluates the influence of removing one of the three branches. However, the paper does not explain how the rest of the structure is adjusted after a specific branch is removed.

Experimental results

The experimental results can be seen directly in the tables of the V2 paper:

(Tables: comparison with other models on ImageNet-1K)

Links