Introduction to the MAE self-supervised algorithm and its reproduction with EasyCV

Time: 2022-7-1

Introduction

Self-supervised learning can exploit large amounts of unlabeled data for representation learning, after which the parameters are fine-tuned on specific downstream tasks. In this way, we can obtain better accuracy than supervised methods when only a small amount of labeled data is available. Self-supervised learning has attracted increasing attention in recent years; Yann LeCun, for example, stated at AAAI that self-supervised learning is the general trend of the future. A series of works has emerged in the CV field, such as SwAV, MoCo, DINO and MoBY. MAE is Kaiming He's next landmark work in self-supervised learning after MoCo. This article first interprets MAE, and then answers some questions based on the accuracy-reproduction process in the EasyCV library.

Summary

MAE's approach is very simple: randomly mask out some patches of the image, and then reconstruct the missing regions with the model. The design has two core elements: (1) an asymmetric encoder-decoder structure; (2) a high mask ratio (75%). With these two designs, MAE achieves more than a 3x speed-up in pre-training together with higher accuracy. For example, ViT-Huge reaches 87.8% accuracy on ImageNet-1K.

Model breakdown

MAE is an autoencoder consisting of an encoder and a decoder. Like a conventional autoencoder, MAE first maps image patches into a latent space through the encoder, and the latent features are then reconstructed back into image patches by the decoder. What distinguishes MAE from a conventional autoencoder is its asymmetric encoder-decoder structure, which is mainly reflected in the following two points (sketched in code after the list):

  1. A lightweight decoder architecture.
  2. At the encoder stage, only the unmasked image patches are used as input; at the decoder stage, the latent variables output by the encoder together with mask tokens are used as input to reconstruct the complete image.
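
To make this asymmetry concrete, the following minimal PyTorch sketch (illustrative module and hyperparameters only, not the EasyCV or official implementation) shows how the two stages see different numbers of tokens.

import torch
import torch.nn as nn

class TinyMAE(nn.Module):
    """Minimal sketch of MAE's asymmetric encode/decode flow (illustrative only)."""

    def __init__(self, enc_dim=768, dec_dim=512, patch_pixels=16 * 16 * 3):
        super().__init__()
        # Full-size encoder (12 blocks, like ViT-B) vs. lightweight decoder (8 blocks).
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(enc_dim, nhead=12, batch_first=True), num_layers=12)
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dec_dim, nhead=16, batch_first=True), num_layers=8)
        self.enc_to_dec = nn.Linear(enc_dim, dec_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_dim))
        self.pred = nn.Linear(dec_dim, patch_pixels)

    def forward(self, visible_tokens, num_masked):
        # The encoder only ever sees the ~25% visible patch tokens.
        latent = self.enc_to_dec(self.encoder(visible_tokens))        # (N, L_vis, dec_dim)
        # The decoder sees the visible latents plus one learnable mask token per masked position.
        mask_tokens = self.mask_token.expand(latent.shape[0], num_masked, -1)
        full = torch.cat([latent, mask_tokens], dim=1)                # (N, L_vis + L_mask, dec_dim)
        # Positional embeddings and token unshuffling are omitted here; see the walkthrough below.
        return self.pred(self.decoder(full))                          # (N, L, p*p*3)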

Masking strategy

First, the image is divided into non-overlapping patches in the same way as ViT (for example, ViT-B uses 16×16 patches). These patches are then sampled with a uniform random strategy, and the unselected ones are discarded. The masking strategy adopted by MAE has the following two characteristics:

  1. A masking ratio of 75% is used to discard image patches. The authors point out that a high masking ratio effectively reduces the redundancy of the input, so the reconstruction task cannot be solved by simply copying from adjacent patches. Experiments also support this view (see the worked example after this list).

    The masking-ratio experiment is the most striking part of MAE. As the mask ratio increases, both fine-tuning and linear-probing accuracy keep rising, with no drop even at 75%. This breaks with the practice of BERT (15%) and BEiT (40%) and transfers the success of masked pre-training in NLP to CV.
  2. The uniform sampling strategy effectively avoids a potential center bias (discarded patches clustering near the center of the image). The ablation experiments on the masking strategy are reported in the paper.
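
As a quick worked example of what a 75% masking ratio means for a standard ViT-B setup (assuming a 224×224 input and 16×16 patches):

# For a 224x224 image split into 16x16 patches (ViT-B defaults):
num_patches = (224 // 16) ** 2                        # 196 patches in total
mask_ratio = 0.75
num_visible = int(num_patches * (1 - mask_ratio))     # 49 patches kept for the encoder
num_masked = num_patches - num_visible                # 147 patches to be reconstructed
print(num_patches, num_visible, num_masked)           # 196 49 147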

Encoder

The MAE encoder uses the ViT structure. After the image patches are sampled, only the 25% of patches that are not masked are kept as input. They are embedded through a linear projection, positional embeddings are added, and the result is fed into a series of transformer blocks. Compared with BERT's practice of replacing masked regions with a mask token, the MAE encoder simply discards the masked part, which significantly reduces the computational resources and training time consumed during pre-training.

The authors also ran an ablation on whether the encoder should keep the mask tokens. Discarding the mask tokens at the encoder stage does not hurt the representation ability of the pre-trained model and markedly accelerates training.
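
The speed-up from discarding the masked tokens can be estimated with simple arithmetic: self-attention cost grows quadratically with sequence length and the MLP cost grows linearly, so keeping 49 of 196 tokens shrinks those per-block costs roughly as follows (a back-of-the-envelope sketch, not a measured benchmark):

full_len, visible_len = 196, 49
attention_ratio = (visible_len / full_len) ** 2   # ~0.0625: attention cost shrinks about 16x
mlp_ratio = visible_len / full_len                # 0.25: MLP cost shrinks about 4x
print(f"attention cost ratio: {attention_ratio:.4f}, MLP cost ratio: {mlp_ratio:.2f}")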

Decoder

The MAE decoder consists of a series of transformer blocks. Unlike the encoder, the input of the MAE decoder includes not only the encoded features of the unmasked image patches but also the masked positions. For each masked position, a shared, learnable mask token is used as input. In addition, to let different mask tokens distinguish different positions in the image, positional embeddings are added to the whole input before it is fed into the decoder.

In MAE, the decoder is only used for image reconstruction during pre-training. The paper adopts a lightweight decoder structure whose per-token computation is less than 10% of that of the encoder. Thanks to this design, even though the full set of tokens is used as input in the decoding phase, the consumption of computing resources does not increase significantly.
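
For reference, the paper's default decoder uses 8 blocks of width 512, while a ViT-L encoder uses 24 blocks of width 1024. A rough comparison counting only the dominant linear-layer weights per block (a simplification that ignores norms, biases and the attention sequence-length term) is consistent with the sub-10% per-token figure quoted above:

def approx_block_weights(dim, mlp_ratio=4):
    """Rough per-block weight count: QKV + output projection plus a 2-layer MLP."""
    attn = 4 * dim * dim                  # q, k, v and output projections
    mlp = 2 * mlp_ratio * dim * dim       # fc1 (dim -> 4*dim) and fc2 (4*dim -> dim)
    return attn + mlp

vit_l_encoder = 24 * approx_block_weights(1024)   # ViT-L encoder blocks
mae_decoder = 8 * approx_block_weights(512)       # MAE default decoder blocks
print(f"decoder / encoder weight ratio: {mae_decoder / vit_l_encoder:.1%}")   # ~8.3%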

The authors also run comparative experiments on the depth and width of the decoder, showing that a lightweight decoder is sufficient for the model to learn effective representations.

Reconstruction target

The goal of the MAE pre-training task is to reconstruct the masked-out pixel values. After the MAE decoder outputs a representation for each image patch, a linear projection layer maps it to a vector of dimension p×p×3, matching the number of pixels in a patch. The MSE between the predicted vector and the original pixel values is used as the loss, computed only on the masked patches.

Note that the authors use the per-patch normalized pixels as the reconstruction target; experiments show this improves the representation ability of the model.
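
The forward_loss code shown later in the walkthrough relies on a patchify helper that flattens each p×p×3 patch into a vector before normalization. A minimal standalone version of such a helper (written for illustration, not copied from the reference code) could look like this:

import torch

def patchify(imgs, p=16):
    """(N, 3, H, W) -> (N, L, p*p*3), where L = (H/p) * (W/p)."""
    N, C, H, W = imgs.shape
    assert H % p == 0 and W % p == 0
    h, w = H // p, W // p
    x = imgs.reshape(N, C, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)   # group pixels patch by patch
    return x.reshape(N, h * w, p * p * C)

# Example: two 224x224 RGB images become 196 patch vectors of length 768 each.
print(patchify(torch.randn(2, 3, 224, 224)).shape)   # torch.Size([2, 196, 768])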

Model evaluation

In addition to evaluating the representation ability of the model through linear probing and full fine-tuning, partial fine-tuning is also used for evaluation. Compared with linear probing, which was widely used before, it better reflects the pre-trained model's ability to represent non-linear features. As shown in the paper, fine-tuning only one transformer block of MAE already raises accuracy from 73.5% to 81%. Meanwhile, although MoCo v3 has higher linear-probing accuracy, MAE is more accurate than MoCo v3 under partial fine-tuning. So while MAE is weaker than MoCo v3 at representing linear features, it has a better ability to represent non-linear features.
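
Partial fine-tuning itself is simple: freeze the whole backbone and only train the last few transformer blocks together with the classification head. A minimal sketch of the idea (assuming a ViT-style model with a blocks list and a head attribute; the names are illustrative):

import torch.nn as nn

def partial_finetune(model: nn.Module, num_trainable_blocks: int = 1):
    """Freeze everything, then unfreeze the last few blocks and the classification head."""
    for p in model.parameters():
        p.requires_grad = False
    for blk in model.blocks[-num_trainable_blocks:]:
        for p in blk.parameters():
            p.requires_grad = True
    for p in model.head.parameters():
        p.requires_grad = True
    return model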

Introduction to EasyCV

EasyCV is an open-source all-in-one visual algorithm modeling tool built on PyTorch, with self-supervised learning and transformer technology at its core. At the data level, EasyCV abstracts different data_sources, supports a variety of open-source datasets such as CIFAR, ImageNet and COCO, and abstracts the various data preprocessing steps into independent pipelines, so the preprocessing flow can be configured flexibly through configuration files. At the API level, it provides unified interfaces for training, evaluation, model export and prediction. Therefore, based on EasyCV, we only need to implement the model code to easily reproduce MAE.

In addition, EasyCV supports easy deployment in Alibaba Cloud PAI products (such as PAI-DLC). Multi-machine and multi-group experiments can be run on DLC at the same time without extra modifications, which speeds up the reproduction work.

Reproduction process & pitfalls

Next, we describe how to reproduce the MAE algorithm within the EasyCV framework and summarize the pitfalls. First, the overall pre-training process:

  1. Divide the input image into patches, map each patch through a linear projection, and add positional embeddings to obtain the image tokens.
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
  2. Randomly mask the image tokens at a ratio of 75%. A randomly generated noise tensor is argsorted to perform the random masking of image patches. Note that the function returns two additional values, mask and ids_restore: mask records the positions of the masked patches in the original image for the later loss computation, and ids_restore records the positions of the image tokens fed to the encoder in the original image, which is used for the unshuffle operation before the decoder (a shape sanity check for this function appears after the walkthrough).
def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
  3. Feed the kept image tokens to the encoder to obtain the image embeddings.

         # apply Transformer blocks
         for blk in self.blocks:
             x = blk(x)
         x = self.norm(x)
  4. Unshuffle the image embeddings together with the mask tokens, add the positional embeddings, and feed the result to the decoder.
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
    x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
    x_,
    dim=1,
    index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1)

# add pos embed
x = x + self.decoder_pos_embed
  5. Compute the MSE loss between the output vectors and the normalized image patches, then back-propagate and update the gradients. Two points deserve attention when computing the loss: (1) the image patches used as targets must be normalized; (2) the loss is computed only over the masked patches.
    def forward_loss(self, imgs, pred, mask):
        """compute loss
        Args:
            imgs: (N, 3, H, W)
            pred: (N, L, p*p*3)
            mask: (N, L), 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target)**2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss
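
As a quick sanity check on the masking step (step 2), random_masking can be called as a standalone function since the self argument is unused; the shapes and mask counts then come out as expected:

import torch

N, L, D = 4, 196, 768                        # batch, patches per image, embedding dim
x = torch.randn(N, L, D)
x_masked, mask, ids_restore = random_masking(x, mask_ratio=0.75)   # self argument dropped

print(x_masked.shape)                        # torch.Size([4, 49, 768]) -> only visible tokens
print(mask.shape, int(mask.sum()))           # torch.Size([4, 196]) 588 (= 4 * 147 masked patches)
print(ids_restore.shape)                     # torch.Size([4, 196])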

Accuracy reproduction

Referring to https://github.com/facebookre…, we reproduced the fine-tuning accuracy of ViT-Base and ViT-Large on ImageNet-1K with a single machine with eight V100 GPUs. The results are shown in the following table.

Algorithm           ImageNet1K Top-1 (%)   Config
vit-b 400 epoch     83.13                  mae_vit_base_patch16_8xb64_100e_lrdecay075_fintune
vit-b 1600 epoch    83.55                  mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune
vit-l 1600 epoch    85.70                  mae_vit_large_patch16_8xb16_50e_lrdecay075_fintune

Below we share some problems and parameter choices encountered during reproduction; corrections are welcome.

  1. For fine-tuning, the MAE implementation uses mixup+cutmix data augmentation. Using only mixup lowers the accuracy.
  2. For fine-tuning, MAE uses the average of all token features as input to the classification head; accuracy drops when the CLS token is used instead (see the sketch after this list).
  3. During pre-training, make sure to use a large enough weight_decay (the official setting is 0.05), otherwise gradient explosion easily occurs when fine-tuning downstream tasks. Setting a smaller weight_decay for the downstream classification fine-tuning improves accuracy. (When reproducing ViT-L, setting weight_decay to 0.01 during pre-training led to gradient explosion during fine-tuning.)
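
For point 2, the difference lies only in how the token features are pooled before the classification head. A minimal sketch of the two options (illustrative code, not the EasyCV implementation; in practice a norm layer is usually applied to the pooled feature):

import torch

def classification_feature(tokens, use_global_pool=True):
    """tokens: (N, 1 + L, D) with the CLS token at index 0."""
    if use_global_pool:
        return tokens[:, 1:, :].mean(dim=1)   # average all patch tokens (the better choice for MAE)
    return tokens[:, 0]                       # use only the CLS token

# Example: the pooled feature of shape (N, D) is then fed to the linear classification head.
print(classification_feature(torch.randn(8, 197, 768)).shape)   # torch.Size([8, 768])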

The following table shows, for the ViT-B model, the accuracy improvement brought by each of the changes above during reproduction.

Parameter setting                               ImageNet1K Top-1 (%)
vit-b 1600 epoch (mixup, cls token)             83.21
vit-b 1600 epoch (mixup+cutmix, cls token)      83.36
vit-b 1600 epoch (mixup+cutmix, global_pool)    83.55

We have reproduced the MAE algorithm in the open-source framework EasyCV. For detailed parameter configurations and experiment logs, please refer to the self-supervised model zoo on GitHub.

Tutorial

Next, we use a practical example to show how to pre-train and fine-tune the MAE algorithm with EasyCV; detailed steps can also be found at the link.

1. Install dependencies

If you are running in a local development environment, you can refer to the link to set up the environment. If PAI-DSW is used for the experiments, there is no need to install the dependencies; the required environment is already built into the PAI-DSW image.

2. Data preparation

Self-supervised training only requires unlabeled images. You can download the ImageNet data or use your own images. You need to provide a folder path P containing the images and a file list, where each line of the file list is the path of an image relative to the folder P.

An example of the image folder structure is shown below; the folder path is ./images

images/
├── 0001.jpg
├── 0002.jpg
├── 0003.jpg
|...
└── 9999.jpg

The file list is as follows:

0001.jpg
0002.jpg
0003.jpg
...
9999.jpg
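
Such a file list can be generated with a short script, for example (a simple sketch; the list filename is just an example):

import os

image_dir = 'images/'
with open('image_list.txt', 'w') as f:
    for name in sorted(os.listdir(image_dir)):
        if name.lower().endswith(('.jpg', '.jpeg', '.png')):
            f.write(name + '\n')   # each line is a path relative to the image folder P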

To quickly walk through the process, we also provide a small sample dataset. Execute the following commands to download and extract it:

wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/imagenet_raw_demo/imagenet_raw_demo.tar.gz
tar -zxf imagenet_raw_demo.tar.gz
mv imagenet_raw_demo  imagenet_raw

3. Model pre-training

Take ViT-Base as an example. In EasyCV, configuration files are used to specify the model parameters, data input, augmentation methods and training strategy; an experiment can be configured for training simply by modifying the settings in the configuration file. You can directly download the sample configuration file:

rm -rf mae_vit_base_patch16_8xb64_1600e.py
wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/doc/easycv/configs/selfsup/mae/mae_vit_base_patch16_8xb64_1600e.py

View the EasyCV installation location:

import easycv
print(easycv.__file__)

Execute the training command:

python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \
/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mae_vit_base_patch16_8xb64_1600e.py --work_dir work_dir/selfsup/jpg/mae --launcher pytorch

4. Model fine-tuning

1. Modify the weight keys of the pre-trained model obtained in the previous step for the fine-tuning task.

import torch

# Strip the 'encoder.' prefix from the pre-trained weights so they can be
# loaded directly into the classification backbone for fine-tuning.
weight_path = 'work_dir/selfsup/jpg/mae/epoch_5.pth'
state_dict = torch.load(weight_path)['state_dict']
state_dict_out = {}
for key in state_dict:
    state_dict_out[key.replace('encoder.', '')] = state_dict[key]
torch.save(state_dict_out, weight_path)

2. Download the classification task sample configuration file

rm -rf mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py
wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/doc/easycv/configs/selfsup/mae/mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py

3. Execute the training command

python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \
/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py --work_dir work_dir/selfsup/jpg/mae --launcher pytorch

END

EasyCV will continue to publish a series of articles on reproducing SOTA papers. We welcome everyone to follow and use it, and we welcome feedback, improvement suggestions and technical discussions from all dimensions. We also warmly welcome and look forward to participation from anyone interested in building the open-source community.

Project open-source address: https://github.com/alibaba/Ea…
DingTalk Q&A group: 41783266