A Multi-Axis Approach for Vision Transformer and MLP Models

Convolutional neural networks have been the dominant machine learning architecture for computer vision since the introduction of AlexNet in 2012. Recently, inspired by the evolution of Transformers in natural language processing, attention mechanisms have been prominently incorporated into vision models. These attention mechanisms amplify some parts of the input data while suppressing others, so that the network can focus on small but important regions of the data. The Vision Transformer (ViT) has created a new landscape of model designs for computer vision that is completely free of convolution. ViT treats image patches as a sequence of words and applies a Transformer encoder on top. When trained on sufficiently large datasets, ViT demonstrates compelling performance on image recognition.

While convolutions and attention are both sufficient for good performance, neither of them is necessary. For example, MLP-Mixer adopts a simple multi-layer perceptron (MLP) to mix image patches across all the spatial locations, resulting in an all-MLP architecture. It is a competitive alternative to existing state-of-the-art vision models in terms of the trade-off between accuracy and the computation required for training and inference. However, both ViT and the MLP models struggle to scale to higher input resolutions because the computational complexity increases quadratically with respect to the image size.
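To make that scaling problem concrete, here is a quick back-of-the-envelope Python sketch; the patch size and resolutions below are illustrative assumptions, not measurements from ViT or MLP-Mixer.

```python
# With 16x16 patches, full self-attention compares every token with every
# other token, so its cost grows with the square of the token count.
for side in (224, 448, 896):              # illustrative input resolutions
    tokens = (side // 16) ** 2            # number of 16x16 patches
    pairs = tokens ** 2                   # pairwise attention interactions
    print(f"{side}x{side}: {tokens:5d} tokens, {pairs:12,d} token pairs")
# Doubling the resolution quadruples the token count and multiplies the
# attention cost by roughly 16x.
```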

Today we present a new multi-axis approach that is simple and effective, improves on the original ViT and MLP models, can better adapt to high-resolution, dense prediction tasks, and can naturally handle different input sizes with high flexibility and low complexity. Based on this approach, we have built two backbone models for high-level and low-level vision tasks. We describe the first in “MaxViT: Multi-Axis Vision Transformer”, to be presented at ECCV 2022, and show that it significantly improves the state of the art for high-level tasks such as image classification, object detection, segmentation, quality assessment, and generation. The second, presented in “MAXIM: Multi-Axis MLP for Image Processing” at CVPR 2022, is based on a UNet-like architecture and achieves competitive performance on low-level imaging tasks including denoising, deblurring, dehazing, deraining, and low-light enhancement. To facilitate further research on efficient Transformer and MLP models, we have open-sourced the code and models for both MaxViT and MAXIM.

A demo of image deblurring using MAXIM frame by frame.

Overview
Our new approach is based on multi-axis attention, which decomposes the full-size attention (each pixel attends to all the pixels) used in ViT into two sparse forms: local and (sparse) global. As shown in the figure below, multi-axis attention contains a sequential stack of block attention and grid attention. Block attention works within non-overlapping windows (small patches in intermediate feature maps) to capture local patterns, while grid attention works on a sparsely sampled uniform grid for long-range (global) interactions. The window sizes of the grid and block attention can be fully controlled as hyperparameters to ensure computational complexity that is linear in the input size.

The proposed multi-axis attention conducts blocked local attention and dilated global attention sequentially, followed by an FFN, with only linear complexity. Pixels of the same color attend to one another.
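The following is a minimal NumPy sketch, not the released implementation, of the two partitioning schemes described above. block_partition groups pixels into non-overlapping p×p windows for local (block) attention, while grid_partition groups pixels strided across the whole feature map into a dilated g×g pattern for global (grid) attention. Because attention is then computed only within each group, the total cost stays linear in the number of pixels for fixed p and g.

```python
# A minimal sketch of the two partitioning schemes used by multi-axis attention.
import numpy as np

def block_partition(x, p):
    """(H, W, C) -> (num_windows, p*p, C): each window is a local p x p patch."""
    h, w, c = x.shape
    x = x.reshape(h // p, p, w // p, p, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, p * p, c)

def grid_partition(x, g):
    """(H, W, C) -> (num_groups, g*g, C): each group is spread over the image
    with stride (H/g, W/g), giving sparse, global interactions."""
    h, w, c = x.shape
    x = x.reshape(g, h // g, g, w // g, c)
    return x.transpose(1, 3, 0, 2, 4).reshape(-1, g * g, c)

x = np.random.randn(16, 16, 32)         # toy feature map
print(block_partition(x, 4).shape)      # (16, 16, 32): 16 windows of 16 local tokens
print(grid_partition(x, 4).shape)       # (16, 16, 32): 16 groups of 16 dilated tokens
```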

Such low-complexity attention significantly broadens applicability to many vision tasks, especially high-resolution visual prediction, demonstrating greater generality than the original attention used in ViT. We build two backbone instantiations out of this multi-axis attention approach, MaxViT and MAXIM, for high-level and low-level tasks, respectively.

MaxViT
In MaxViT, we first build a single MaxViT block (shown below) by concatenating MBConv (proposed by EfficientNet and EfficientNetV2) with the multi-axis attention. This single block can encode local and global visual information regardless of input resolution. We then simply stack repeated blocks composed of attention and convolutions in a hierarchical architecture (similar to ResNet and CoAtNet), yielding our homogeneous MaxViT architecture. Notably, MaxViT is distinguished from previous hierarchical approaches in that it can “see” globally throughout the entire network, even in the earlier, high-resolution stages, demonstrating stronger model capacity on various tasks.

The meta-architecture of MaxViT.
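As a rough illustration of the hierarchical stacking, here is a hypothetical configuration sketch in Python; the block counts, widths, and strides are illustrative rather than a published MaxViT variant. What it captures is that every stage, including the early high-resolution ones, repeats the same MBConv, block-attention, grid-attention unit, so global interactions are available throughout the network.

```python
# Hypothetical sketch of MaxViT's hierarchical stacking (illustrative numbers).
from dataclasses import dataclass
from typing import List

@dataclass
class Stage:
    num_blocks: int   # how many MaxViT blocks are repeated in this stage
    channels: int     # feature width of the stage
    stride: int       # downsampling applied by the first MBConv of the stage

def describe(stages: List[Stage], input_size: int = 224) -> None:
    size = input_size // 2  # assume a convolutional stem downsamples by 2x first
    for i, s in enumerate(stages, start=1):
        size //= s.stride
        print(f"Stage {i}: {s.num_blocks} x [MBConv -> block attn -> grid attn], "
              f"{s.channels} ch, {size}x{size} features (global attention still applies)")

# Illustrative 4-stage hierarchy; resolution halves at each stage.
describe([Stage(2, 64, 2), Stage(2, 128, 2), Stage(5, 256, 2), Stage(2, 512, 2)])
```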

MAXIM
Our second backbone, MAXIM, is a generic UNet-like architecture tailored for low-level image-to-image prediction tasks. MAXIM explores parallel designs of the local and global approaches using the gated multi-layer perceptron (gMLP) network (a patch-mixing MLP with a gating mechanism). Another contribution of MAXIM is the cross-gating block, which can be used to model interactions between two different input signals. This block can serve as an efficient alternative to the cross-attention module, as it only employs cheap gated MLP operators to interact with various inputs, without relying on computationally heavy cross-attention. Moreover, all the proposed components in MAXIM, including the gated MLP and cross-gating blocks, enjoy linear complexity with respect to image size, making it even more efficient when processing high-resolution images.
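Below is a minimal NumPy sketch of the gating ideas behind MAXIM, under simplifying assumptions (random weights, no normalization, a single head, and gating over the full token axis rather than the multi-axis blocked and gridded tokens used in the paper): a gMLP-style spatial gating unit and a cross-gating step in which two feature streams modulate each other.

```python
# A simplified sketch of gated spatial projection and cross-gating.
import numpy as np

rng = np.random.default_rng(0)

def spatial_gate(u, w):
    """gMLP-style gating: split channels, gate one half with a spatial
    (token-mixing) projection of the other half."""
    u1, u2 = np.split(u, 2, axis=-1)     # (n, c/2) each
    return u1 * (w @ u2)                 # elementwise gate

def cross_gate(x, y, wx, wy):
    """Cross-gating: each stream is modulated by a projection of the other."""
    x_out = x * (wy @ y)                 # x gated by features derived from y
    y_out = y * (wx @ x)                 # y gated by features derived from x
    return x_out, y_out

n, c = 64, 32                            # 64 tokens, 32 channels
x, y = rng.normal(size=(n, c)), rng.normal(size=(n, c))
w = rng.normal(size=(n, n)) * 0.1        # spatial (token-mixing) projection
print(spatial_gate(x, w).shape)          # (64, 16)
x2, y2 = cross_gate(x, y, w, w)
print(x2.shape, y2.shape)                # (64, 32) (64, 32)
```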

Results
We demonstrate the effectiveness of MaxViT on a broad range of vision tasks. On image classification, MaxViT achieves state-of-the-art results under various settings: with only ImageNet-1K training, MaxViT attains 86.5% top-1 accuracy; with ImageNet-21K (14M images, 21k classes) pre-training, MaxViT achieves 88.7% top-1 accuracy; and with JFT (300M images, 18k classes) pre-training, our largest model MaxViT-XL achieves a high accuracy of 89.5% with 475M parameters.

Performance comparison of MaxViT with state-of-the-art models on ImageNet-1K. Top: Accuracy vs. FLOPs performance scaling with 224×224 image resolution. Bottom: Accuracy vs. parameters scaling curve under ImageNet-1K fine-tuning setting.

As a backbone, MaxViT also delivers favorable performance on a broad spectrum of downstream tasks. For object detection and segmentation on the COCO dataset, the MaxViT backbone achieves 53.4 AP, outperforming other base-level models while requiring only about 60% of their computational cost. For image aesthetics assessment, the MaxViT model advances the state-of-the-art MUSIQ model by 3.5% in terms of linear correlation with human opinion scores. The standalone MaxViT building block also demonstrates effective performance on image generation, achieving better FID and IS scores on the ImageNet-1K unconditional generation task with significantly fewer parameters than the state-of-the-art model, HiT.

The UNet-like MAXIM backbone, customized for image processing tasks, has also demonstrated state-of-the-art results on 15 out of 20 tested datasets, spanning denoising, deblurring, deraining, dehazing, and low-light enhancement, while requiring fewer or a comparable number of parameters and FLOPs than competing models. Images restored by MAXIM show more recovered details with fewer visual artifacts.

Visual results of MAXIM for image deblurring, deraining, and low-light enhancement.

Summary
Work over the last two or so years has shown that ConvNets and Vision Transformers can achieve similar performance. Our work presents a unified design that takes advantage of the best of both worlds, efficient convolution and sparse attention, and demonstrates that a model built on top of it, namely MaxViT, can achieve state-of-the-art performance on a variety of vision tasks. More importantly, MaxViT scales well to very large data sizes. We also show that an alternative multi-axis design using MLP operators, MAXIM, achieves state-of-the-art performance on a broad range of low-level vision tasks.

Even though we present our models in the context of vision tasks, the proposed multi-axis approach can easily extend to language modeling to capture both local and global dependencies in linear time. Motivated by the work here, we believe it is worthwhile to study other forms of sparse attention in higher-dimensional or multimodal signals such as videos, point clouds, and vision-language models.

We have open-sourced the code and models of MAXIM and MaxViT to facilitate future research on efficient attention and MLP models.

Acknowledgments
We would like to thank our co-authors: Hossein Talebi, Han Zhang, Feng Yang, Peyman Milanfar, and Alan Bovik. We would also like to acknowledge the valuable discussion and support from Xianzhi Du, Long Zhao, Wuyang Chen, Hanxiao Liu, Zihang Dai, Anurag Arnab, Sungjoon Choi, Junjie Ke, Mauricio Delbracio, Irene Zhu, Innfarn Yoo, Huiwen Chang, and Ce Liu.
