This series aims to explain the mechanism of Vision Transformers (ViT) [2], which is a pure Transformer model used as a visual backbone in computer vision tasks. It also points out the limitations of ViT and provides a summary of its recent improvements.
1. What is a transformer?
Transformer networks [1] are sequence transduction models, i.e., models that transform input sequences into output sequences. Transformer networks, which follow an encoder-decoder architecture, are based solely on attention mechanisms. We will discuss attention mechanisms in more detail in the following sections. First, however, let's briefly go through some of the previous approaches.
Transformers were first introduced in [1] for machine translation, i.e., converting a text sequence in one language into another language. Before this breakthrough architecture, deep neural architectures such as recurrent neural networks (RNNs) and convolutional neural networks (CNNs) had been used extensively for this task. RNNs generate a sequence of hidden states, each computed from the previous hidden state and the current input. The longer the sentence, the less information about distant words is retained. In language, however, the relevance between words is determined by their meaning rather than by their proximity. For example, in the following sentence:
"Jane is a travel blogger and also a very talented guitarist."
the word "Jane" is more relevant to "guitarist" than the words "also" or "very" are. While LSTMs, a special kind of RNN, learn to retain important information and discard irrelevant information, they still struggle to capture long-range dependencies. Moreover, because each hidden state depends on the previous one, RNNs must be computed sequentially, which prevents parallelization across the sequence. Transformer models solve these problems by using the concept of attention.
As mentioned previously, the linguistic meaning and context of words are more relevant than their proximity. Moreover, every word in a sentence might be relevant to any other word in the sentence, and this needs to be taken into account. The attention module in transformer models aims to do just this. It takes query, key, and value vectors as input. The idea is to compute the dot product between a word (the query) and every word (the keys) in the sentence. These scores indicate how relevant each key is to the query. In [1], the scores are scaled by $1/\sqrt{d_k}$, where $d_k$ is the dimension of the key vectors, and normalized with a softmax so that they sum to one. The resulting weights are then applied to the corresponding word vectors (the values), and their weighted sum provides a representation of the query word that incorporates more context.
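The computation described above is scaled dot-product attention. Below is a minimal NumPy sketch, assuming the query, key, and value matrices are already given (in a real transformer they come from learned linear projections of the word embeddings); the toy sizes are arbitrary.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v) -> output (n_q, d_v), weights (n_q, n_k)."""
    d_k = Q.shape[-1]
    # Relevance of every key to every query, scaled by sqrt(d_k).
    scores = Q @ K.T / np.sqrt(d_k)
    # Each row becomes a probability distribution over the keys.
    weights = softmax(scores, axis=-1)
    # Weighted sum of the values: a context-aware representation per query.
    return weights @ V, weights

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))  # 5 tokens with hypothetical 8-dim embeddings
out, w = scaled_dot_product_attention(X, X, X)  # self-attention: Q = K = V = X
print(out.shape, w.shape)  # (5, 8) (5, 5)
```

Using the same matrix for queries, keys, and values, as in the usage line, is what makes this *self*-attention.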
It is worth noting that these operations are performed on vector representations of words: each word is represented by a meaningful vector, its d-dimensional word embedding. Because the self-attention operation is permutation-invariant and processes all words in parallel, it has no inherent notion of word order. Positional information is important in many scenarios; for example, correct word ordering gives meaning to a sentence. It is therefore introduced by adding a positional encoding vector to the input embedding. This vector captures the position of the word in the input sentence and helps to differentiate words that appear more than once. The positional embedding can either be learned or given by pre-defined sinusoidal functions of different frequencies. A detailed empirical study of position embeddings in NLP can be found in [10]. Similarly, vision transformers use positional embeddings to retain the position of each element (image patch) in the input sequence.
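As an illustration of the pre-defined option, here is a sketch of the sinusoidal encoding from [1], where even dimensions use $\sin(pos / 10000^{2i/d})$ and odd dimensions use the matching cosine. The embedding size and sequence length are toy values, and the sketch assumes an even embedding size.

```python
import numpy as np

def sinusoidal_positional_encoding(num_positions, d_model):
    """Return a (num_positions, d_model) matrix of sinusoidal position codes."""
    assert d_model % 2 == 0, "this sketch assumes an even embedding size"
    positions = np.arange(num_positions)[:, None]            # (L, 1)
    i = np.arange(d_model // 2)[None, :]                     # (1, d/2)
    angles = positions / np.power(10000.0, 2 * i / d_model)  # (L, d/2), lower frequency for larger i
    pe = np.zeros((num_positions, d_model))
    pe[:, 0::2] = np.sin(angles)  # even dimensions use sine
    pe[:, 1::2] = np.cos(angles)  # odd dimensions use cosine
    return pe

rng = np.random.default_rng(0)
word_embeddings = rng.normal(size=(6, 16))  # 6 hypothetical tokens, d_model = 16
# The position code is added element-wise to the embedding, as in [1].
x = word_embeddings + sinusoidal_positional_encoding(6, 16)
```

A learned positional embedding would simply replace the function call with a trainable (num_positions, d_model) parameter matrix.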
An overview of the transformer is shown in Fig. 1. An input sequence is fed into the transformer encoder (the left part of the figure), which consists of N encoder layers. Each encoder layer consists of two sublayers: 1) multi-head self-attention and 2) a position-wise feedforward network (PFFN). Each sublayer is wrapped with a residual connection followed by layer normalization. Multi-head attention aims to capture relationships between tokens in the input sequence in several different contexts (representation subspaces) at once. Each head computes attention by linearly projecting each token into query, key, and value vectors. The queries and keys are then used to compute attention weights, which are applied to the value vectors. The output of the multi-head attention sublayer (which has the same size as its input) is then fed into the PFFN, which further transforms the representation of each token independently. This process is repeated by the N encoder layers.
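A minimal NumPy sketch of one encoder layer, as described above, is given below. It assumes random matrices in place of trained weights, omits dropout and the learnable gain and bias of layer normalization, and uses a ReLU in the PFFN as in [1]; all sizes are toy values. Stacking N such layers gives the encoder.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Normalize each token vector (learnable gain/bias omitted for brevity).
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def multi_head_self_attention(X, Wq, Wk, Wv, Wo, num_heads):
    """X: (n, d). Wq, Wk, Wv, Wo: (d, d). Returns (n, d)."""
    n, d = X.shape
    d_head = d // num_heads

    def split(M):
        # (n, d) -> (num_heads, n, d_head): each head sees its own slice of features.
        return M.reshape(n, num_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ Wq), split(X @ Wk), split(X @ Wv)
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_head))  # (h, n, n)
    heads = weights @ V                                              # (h, n, d_head)
    # Concatenate the heads and mix them with the output projection.
    return heads.transpose(1, 0, 2).reshape(n, d) @ Wo

def position_wise_ffn(X, W1, b1, W2, b2):
    # The same two-layer MLP is applied to every token independently.
    return np.maximum(0, X @ W1 + b1) @ W2 + b2

def encoder_layer(X, p, num_heads):
    # Post-norm layout: LayerNorm(x + Sublayer(x)) for both sublayers.
    attn = multi_head_self_attention(X, p["Wq"], p["Wk"], p["Wv"], p["Wo"], num_heads)
    X = layer_norm(X + attn)
    return layer_norm(X + position_wise_ffn(X, p["W1"], p["b1"], p["W2"], p["b2"]))

rng = np.random.default_rng(0)
n, d, d_ff, h = 4, 8, 32, 2  # toy sizes, chosen arbitrarily
p = {k: rng.normal(scale=0.1, size=(d, d)) for k in ["Wq", "Wk", "Wv", "Wo"]}
p.update(W1=rng.normal(scale=0.1, size=(d, d_ff)), b1=np.zeros(d_ff),
         W2=rng.normal(scale=0.1, size=(d_ff, d)), b2=np.zeros(d))
Y = encoder_layer(rng.normal(size=(n, d)), p, num_heads=h)
print(Y.shape)  # (4, 8)
```

The output has the same shape as the input, which is what allows identical layers to be stacked.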
Figure 1. The architecture of the transformer model (image from [1])
The right part of the figure is the transformer decoder, which similarly consists of N decoder layers and is followed by a prediction head (a linear layer with softmax). Each decoder layer consists of three sublayers: 1) masked multi-head self-attention, 2) multi-head cross-attention, and 3) a PFFN. The first and third are similar to those of the encoder layers, except that the self-attention is masked so that each position can only attend to earlier positions. The second sublayer, multi-head cross-attention, computes the relationship between each token in the decoder's input sequence and each token in the output of the encoder. In particular, as shown in the figure, the transformer decoder receives two inputs: 1) a sequence of tokens fed into the bottom of the decoder and 2) the output of the transformer encoder. In the original transformer paper [1], which considered machine translation, the tokens previously generated by the prediction head are fed back into the decoder as input. Note that in other applications, e.g., computer vision, an extra fixed or learnable sequence can be used as input to the decoder.
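The two attention sublayers of the decoder differ mainly in where their inputs come from. The sketch below shows this with a single head and without learned projections (both simplifications are assumptions made for brevity): the masked self-attention uses a lower-triangular mask, and the cross-attention takes its queries from the decoder and its keys and values from the encoder output.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if mask is not None:
        # Disallowed positions get -inf, so they receive zero weight after softmax.
        scores = np.where(mask, scores, -np.inf)
    weights = softmax(scores)
    return weights @ V, weights

rng = np.random.default_rng(0)
d = 8
encoder_out = rng.normal(size=(6, d))  # 6 source tokens produced by the encoder
decoder_in = rng.normal(size=(3, d))   # 3 target tokens generated so far

# 1) Masked self-attention: token t may only attend to tokens 0..t.
causal_mask = np.tril(np.ones((3, 3), dtype=bool))
h, self_w = attention(decoder_in, decoder_in, decoder_in, mask=causal_mask)

# 2) Cross-attention: queries from the decoder, keys and values from the encoder.
ctx, cross_w = attention(h, encoder_out, encoder_out)
print(self_w.shape, cross_w.shape)  # (3, 3) (3, 6)
```

The mask keeps training consistent with generation, where later tokens do not yet exist when an earlier token is predicted.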
More information about the transformer model can be found in [16], [17]. This 4-video series on attention is also highly recommended [18].
2. Vision transformer (ViT)
The transformer model and its variants have been shown to match or even surpass the state of the art in several tasks, especially in NLP. This section briefly explores how the transformer model can be applied to computer vision tasks and then introduces the vision transformer (ViT), which has attracted considerable attention from researchers in computer vision.
Several attempts have been made to apply attention mechanisms or even the transformer model to computer vision. For example, in [11], a form of spatial attention, in which the relationship between pixels is computed, has been used as a building block in CNNs. This mechanism allows CNNs to capture long-range dependencies in an image and better understand the global context. In [12], a building block called squeeze-and-excitation (SE) block, which computes the attention in the channel dimension, was proposed to improve the representation power of CNNs.
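To make channel attention concrete, here is a minimal sketch of the squeeze-and-excitation idea from [12]: pool each channel to one number, pass the result through a small bottleneck MLP with a sigmoid, and rescale the channels. Biases are omitted, the weights are random stand-ins, and the channel count and reduction ratio are toy values.

```python
import numpy as np

def squeeze_excitation(feature_map, W1, W2):
    """feature_map: (C, H, W); W1: (C // r, C); W2: (C, C // r)."""
    # Squeeze: global average pooling gives one descriptor per channel.
    z = feature_map.mean(axis=(1, 2))                        # (C,)
    # Excitation: bottleneck MLP (ReLU, then sigmoid) yields channel weights in (0, 1).
    s = 1.0 / (1.0 + np.exp(-(W2 @ np.maximum(0, W1 @ z))))  # (C,)
    # Scale: reweight each channel of the input feature map.
    return feature_map * s[:, None, None]

rng = np.random.default_rng(0)
C, r = 16, 4  # channels and reduction ratio (toy values)
x = rng.normal(size=(C, 8, 8))
y = squeeze_excitation(x, rng.normal(size=(C // r, C)), rng.normal(size=(C, C // r)))
```

Unlike the spatial attention of [11], which relates pixels to each other, this block only decides how much each channel should contribute.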
On the other hand, combinations of transformers and CNNs have been proposed for computer vision tasks such as object detection and semantic segmentation. In the detection transformer (DETR) [13], a transformer processes the feature map generated by a CNN backbone to perform object detection. Using a transformer in DETR removes the need for hand-designed components such as non-maximum suppression and allows the model to be trained end-to-end. Similar ideas have been proposed in [14] and [15] for semantic segmentation.
Unlike those works, in which transformers or attention modules complement CNN models to solve vision tasks, the vision transformer (ViT) [2] is a convolution-free, pure transformer architecture proposed as an alternative visual backbone. The overall network architecture of ViT is shown in Fig. 2.
Figure 2. The architecture of ViT (image from [2])
A key question when applying a transformer to image data is how to convert an input image into the sequence of tokens that a transformer expects. In ViT, an input image of size H x W is divided into N non-overlapping patches of P x P pixels (e.g., 16 x 16), where N = (H x W) / (P x P). Each patch is flattened and converted into an embedding using a linear layer. These embeddings form a sequence of tokens, where each token represents a small part of the input image. An extra learnable token, the classification token, is prepended to the sequence. The transformer layers use it as a place to pool information from all other positions to produce the prediction output. Positional embeddings are added to this sequence of N + 1 tokens, which is then fed into a transformer encoder.
As shown in Figs. 1 and 2, the transformer encoder in ViT is similar to that of the original transformer by Vaswani et al. [1]. The main difference is that ViT applies layer normalization before the multi-head attention and MLP sublayers, whereas Vaswani's transformer applies it after them. This pre-norm design has been shown in [19], [20] to enable efficient training of deeper models.
The output of the transformer encoder is a sequence of tokens of the same size as the input, i.e., N + 1 tokens. However, only the first output token, which corresponds to the classification token, is fed into a prediction head, a multi-layer perceptron (MLP), to generate the predicted class label.
In [2], ViT was pre-trained on large-scale image datasets such as ImageNet-21k, which consists of 14M images of 21k classes, or JFT-300M, which consists of 303M high-resolution images of 18k classes, and then fine-tuned on several image classification benchmarks. Experimental results showed that, when pre-trained on a large amount of image data, ViT achieved performance competitive with state-of-the-art CNNs while requiring substantially fewer computational resources to train.
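The tokenization steps above can be sketched in a few lines of NumPy. This is an illustration, not the reference implementation of [2]: the projection matrix, classification token, and positional embeddings are random stand-ins for learned parameters, and the embedding size D = 64 is a toy value.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into (N, P*P*C) flattened non-overlapping patches."""
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "sketch assumes H and W are multiples of P"
    patches = image.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    return patches.reshape(-1, P * P * C)  # patches in row-major order over the grid

def vit_input_sequence(image, patch_size, W_embed, cls_token, pos_embed):
    tokens = patchify(image, patch_size) @ W_embed         # (N, D) linear patch embedding
    tokens = np.concatenate([cls_token[None, :], tokens])  # prepend class token -> (N + 1, D)
    return tokens + pos_embed                              # add positional embeddings

rng = np.random.default_rng(0)
H = W = 224; C = 3; P = 16; D = 64  # D is a toy embedding size
N = (H * W) // (P * P)              # 196 patches
seq = vit_input_sequence(
    rng.normal(size=(H, W, C)), P,
    W_embed=rng.normal(scale=0.02, size=(P * P * C, D)),
    cls_token=np.zeros(D),
    pos_embed=rng.normal(scale=0.02, size=(N + 1, D)),
)
print(N, seq.shape)  # 196 (197, 64)
```

After the encoder, a classifier would read only row 0 of its output, the position of the classification token.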
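The difference between the two normalization placements is easiest to see side by side. In this sketch the attention and MLP sublayers are replaced by hypothetical stand-in functions, since only the position of the normalization matters here; the LayerNorm has no learnable parameters.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def post_norm_block(x, attn, mlp):
    # Original transformer [1]: normalize *after* each residual addition.
    x = layer_norm(x + attn(x))
    return layer_norm(x + mlp(x))

def pre_norm_block(x, attn, mlp):
    # ViT [2]: normalize the *input* of each sublayer; the residual path is untouched.
    x = x + attn(layer_norm(x))
    return x + mlp(layer_norm(x))

rng = np.random.default_rng(0)
W_a, W_m = rng.normal(scale=0.1, size=(2, 8, 8))

def toy_attn(h):
    return h @ W_a  # hypothetical stand-in for multi-head attention

def toy_mlp(h):
    return np.tanh(h @ W_m)  # hypothetical stand-in for the MLP

x = rng.normal(size=(5, 8))
print(post_norm_block(x, toy_attn, toy_mlp).shape, pre_norm_block(x, toy_attn, toy_mlp).shape)
```

If both sublayers output zeros, the pre-norm block reduces exactly to the identity, while the post-norm block still renormalizes its input. This clean residual path is one commonly offered intuition for why pre-norm stacks are easier to train at depth.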
3. Problems of ViT and its improvements
Although ViT has attracted a lot of attention from researchers in the field, many studies have pointed out its weaknesses and proposed techniques to improve it. The following subsections describe the key problems of the original ViT and introduce recently published papers that aim to address them.
3.1 A requirement of a large amount of data for pre-training
Pre-training seems to be a key ingredient in several transformer-based networks. However, as shown in the original ViT paper [2] and in subsequent papers [3], [7], ViT requires a very large amount of image data for pre-training to achieve performance competitive with CNNs. As reported in [2], pre-training ViT on ImageNet-21k (21k classes and 14M images) or JFT-300M (18k classes and 303M high-resolution images) led to such performance, while pre-training on ImageNet-1k (1k classes and 1.3M images) did not. However, pre-training ViT on those large-scale datasets requires extremely long training times and substantial computing resources. Moreover, JFT-300M is an internal Google dataset that is not publicly available.
Several approaches have been proposed to handle this problem. For example, in [3], a knowledge distillation technique with a minimal modification of the ViT architecture was adopted in the training process; in [7], a more effective tokenization process for representing an input image was proposed; and in [4], several modifications to the ViT architecture were explored. These approaches are explained in the following subsections.
3.1.1 DeiT
One idea for improving the training process of ViT is to exploit knowledge distillation, as proposed in [3]. Knowledge distillation aims to transfer knowledge from a bigger model, the teacher network, to a smaller target model, the student network. In [3], the authors slightly modify the ViT architecture by appending another extra token, called a distillation token, as shown in Fig. 3. The modified ViT, named data-efficient image transformer or DeiT, generates two outputs: one at the position of the classification token, which is compared with the ground-truth label, and another at the position of the distillation token, which is compared with the output of the teacher network. The loss function is computed from both outputs, which allows the model to leverage the knowledge of the teacher while also learning from the ground truth. They also incorporate a bag of tricks, including data augmentation and regularization, to further improve performance. With this technique, they reported that DeiT could be trained on a single 8-GPU node in 3 days (53 hours for pre-training and 23 hours for optional fine-tuning), while the original ViT required 30 days to train on an 8-core TPUv3 machine.
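A sketch of how the two outputs can be combined into one loss is shown below. It follows the hard-label variant described in [3], where the teacher's predicted class serves as the target for the distillation token and both cross-entropy terms are weighted equally; [3] also discusses a soft variant based on the teacher's probabilities, which is not shown. The logits here are random placeholders for model outputs.

```python
import numpy as np

def log_softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def cross_entropy(logits, labels):
    # Mean negative log-likelihood of integer class labels.
    return -log_softmax(logits)[np.arange(len(labels)), labels].mean()

def deit_hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    """Hard-label distillation objective (sketch)."""
    teacher_labels = teacher_logits.argmax(axis=-1)  # teacher's hard decision
    # Class token learns from the ground truth, distillation token from the teacher.
    return 0.5 * cross_entropy(cls_logits, labels) + 0.5 * cross_entropy(dist_logits, teacher_labels)

rng = np.random.default_rng(0)
B, K = 4, 10  # batch size and number of classes (toy values)
loss = deit_hard_distillation_loss(rng.normal(size=(B, K)), rng.normal(size=(B, K)),
                                   rng.normal(size=(B, K)), rng.integers(0, K, size=B))
print(float(loss))
```

Because the teacher only supplies targets, its weights are not updated; gradients flow only into the student's two heads.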
Figure 3: Distillation process in DeiT (image from [3])
3.1.2 CaiT
Class-attention in image transformers (CaiT), a modified ViT proposed in [4], has been shown to train on the ImageNet-1k dataset while achieving competitive performance. CaiT differs from ViT in three respects. First, it uses a deeper transformer, which aims to improve the representational power of the features. Second, a technique called LayerScale is proposed to facilitate the convergence of the deeper transformer during training. LayerScale multiplies the output of each residual branch (both the attention and the feed-forward sublayers) by a learnable per-channel scaling factor, initialized to a small value, which stabilizes the training of deeper networks. This technique allows CaiT to benefit from the deeper transformer, whereas there is no evidence of improvement when increasing the depth of ViT or DeiT. Third, CaiT applies different types of attention at different stages of the network: normal self-attention (SA) in the early stage and class-attention (CA) in the later stage. The reason is to separate two tasks with contradictory objectives: associating the patch tokens with each other, and summarizing their information for classification. As shown in Fig. 4 (right), unlike in ViT, the class token is inserted only after the first stage. This allows the SA layers to focus on associating the tokens with each other, without needing to summarize the information for classification. Once the class token is inserted, the CA layers integrate all information into it to build a useful representation for the classification step.
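The two CaiT ingredients can be sketched as follows, as an illustration under simplifying assumptions: a single attention head, no MLP sublayers, random stand-in weights, np.tanh as a placeholder SA sublayer, and an initial LayerScale value of 1e-5 chosen here for illustration. In class attention only the class token forms a query, so patch tokens are not updated in that stage.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_scale_residual(x, sublayer, gamma):
    # LayerScale: multiply the sublayer output by a learnable per-channel vector
    # (a diagonal matrix) before adding it back to the residual stream.
    return x + gamma * sublayer(x)

def class_attention(cls_token, patch_tokens, Wq, Wk, Wv):
    """Only the class token issues a query; keys/values come from [class; patches]."""
    z = np.concatenate([cls_token[None, :], patch_tokens])  # (1 + N, D)
    q = cls_token[None, :] @ Wq                              # (1, D)
    k, v = z @ Wk, z @ Wv                                    # (1 + N, D)
    w = softmax(q @ k.T / np.sqrt(q.shape[-1]))              # (1, 1 + N)
    return (w @ v)[0]  # update for the class token only

rng = np.random.default_rng(0)
N, D = 9, 16  # toy sizes
patches0 = rng.normal(size=(N, D))
gamma = np.full(D, 1e-5)  # small initial value, chosen for illustration
# SA stage: a stand-in sublayer acting on the patch tokens only (no class token yet).
patches = layer_scale_residual(patches0, np.tanh, gamma)
# CA stage: the class token is inserted now and gathers information from the patches.
cls = np.zeros(D)
Wq, Wk, Wv = rng.normal(scale=0.1, size=(3, D, D))
cls = layer_scale_residual(cls, lambda c: class_attention(c, patches, Wq, Wk, Wv), gamma)
```

With a small gamma, each residual block starts close to the identity, which is the intuition behind LayerScale's stabilizing effect on deep stacks.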
Figure 4: Architecture comparison between ViT (left), a modified ViT in which the class token is inserted in a later stage (middle), and CaiT (right).
3.1.3 Tokens-to-Token ViT
The authors of [7] identified two key reasons why ViT requires pre-training on a large dataset. First, the simple tokenization process in ViT cannot capture important local structures in an input image well. Local structures such as edges or lines often span several neighboring patches rather than one; however, the tokenization process in ViT simply divides an image into non-overlapping patches and converts each into an embedding independently. Second, the transformer backbone used in the original ViT was not well designed and optimized, leading to redundancies in the feature maps.
To address the first problem, they proposed a tokenization method, the Tokens-to-Token (T2T) module, that iteratively aggregates neighboring tokens into one token through a procedure called the T2T process, as shown in Fig. 5. The T2T process works as follows:
1) A sequence of tokens is passed into a self-attention module to model the relations between tokens. The output of this step is another sequence of the same size as its input.
2) The output sequence from the previous step is reshaped back into a 2D array of tokens.
3) The 2D array of tokens is then divided into overlapping windows, and the neighboring tokens in each window are concatenated into one longer token. The result is a shorter 1D sequence of higher-dimensional tokens.
The T2T process can be repeated to further refine the representation of the input image. In [7], it was applied twice in the T2T module.
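The three steps of the T2T process can be sketched as below. This is a simplified illustration, not the implementation of [7]: the self-attention step is parameter-free (no learned projections), the extra projection layers of the T2T module are omitted, and the window settings k = 3, stride 2, padding 1 are assumptions for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def simple_self_attention(T):
    # Step 1 stand-in: parameter-free self-attention; output has the same length as input.
    return softmax(T @ T.T / np.sqrt(T.shape[-1])) @ T

def soft_split(grid, k, s, p):
    """Step 3: overlapping k x k windows (stride s, zero padding p).
    grid: (H, W, C) -> tokens (num_windows, k * k * C) and the new grid size."""
    H, W, C = grid.shape
    g = np.pad(grid, ((p, p), (p, p), (0, 0)))
    out_h = (H + 2 * p - k) // s + 1
    out_w = (W + 2 * p - k) // s + 1
    windows = [g[i * s:i * s + k, j * s:j * s + k].reshape(-1)
               for i in range(out_h) for j in range(out_w)]
    return np.stack(windows), (out_h, out_w)

def t2t_step(tokens, hw, k=3, s=2, p=1):
    T = simple_self_attention(tokens)   # 1) relate tokens to each other
    grid = T.reshape(hw[0], hw[1], -1)  # 2) reshape the sequence into a 2D array
    return soft_split(grid, k, s, p)    # 3) concatenate neighbors in overlapping windows

rng = np.random.default_rng(0)
tokens, hw = rng.normal(size=(16 * 16, 8)), (16, 16)  # toy 16 x 16 grid of 8-dim tokens
for _ in range(2):  # applied twice, as in the T2T module of [7]
    tokens, hw = t2t_step(tokens, hw)
print(tokens.shape, hw)  # (16, 648) (4, 4)
```

Each pass shortens the sequence (256, then 64, then 16 tokens) while the token dimension grows by a factor of k x k = 9, because overlapping windows share neighboring tokens.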
Figure 5: The Tokens-to-token process (image from [7])
Apart from using the proposed T2T module to improve the representation of an input image, they also explored various architecture designs used in CNNs and applied them to the transformer backbone. They found a deep-narrow structure, which exploits more transformer layers (deeper) to improve feature richness and reduces the embedding dimension (narrower) to maintain the computational cost, gave the best results among the compared architecture designs. As shown in Fig. 6, the sequence of tokens generated by the T2T module is prepended with a classification token, as in the original ViT, and is then fed into the deep-narrow transformer, which is named T2T-ViT backbone, to make a prediction.
It is shown in [7] that, when trained from scratch on the ImageNet-1k dataset, T2T-ViT outperforms the original ViT while reducing the model size and computational cost by half.
Figure 6: The overview architecture of T2T-ViT (image from [7])
3.2 High computational complexity, especially for dense prediction in high-resolution images
Besides the requirement of pre-training on a very large-scale dataset, the high computational complexity of ViT is another concern, since its input is an image, which contains a large amount of information. To fully exploit the attention mechanism on an image input, pixel-level tokenization, i.e., converting each pixel into a token, would seem ideal; however, the computational complexity of the attention module is quadratic in the number of tokens, and hence in the number of pixels, which makes computation time and memory usage intractable. Even in [2], where images of normal resolution were used, non-overlapping patches of 16x16 pixels were chosen, which reduces the number of tokens by a factor of 16 x 16 = 256 and thus the size of the attention matrix by a factor of 256 x 256 = 65,536 compared with pixel-level tokenization. This problem is worse when ViT is used as a visual backbone for dense prediction tasks such as object detection or semantic segmentation, since high-resolution input images are preferable for achieving performance competitive with the state of the art.
Several approaches to this problem mainly aim to improve the efficiency of the attention module. The following subsections describe two such approaches applied to ViT [5], [9].
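A quick calculation illustrates the gap, assuming a 224 x 224 input (a common ImageNet resolution) and counting the entries of a single head's N x N attention matrix.

```python
def attention_matrix_entries(height, width, patch_size):
    """Number of tokens N and entries N * N in one head's attention matrix."""
    n_tokens = (height // patch_size) * (width // patch_size)
    return n_tokens, n_tokens ** 2

for p in (1, 16):  # pixel-level tokens vs. 16 x 16 patches
    n, entries = attention_matrix_entries(224, 224, p)
    print(f"patch size {p}: {n} tokens, {entries} attention entries")
# patch size 1: 50176 tokens, 2517630976 attention entries
# patch size 16: 196 tokens, 38416 attention entries
```

Pixel-level tokens would need about 2.5 billion attention weights per head and per layer, versus 38,416 with 16 x 16 patches.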
3.2.1 Spatial-reduction attention (SRA)
Spatial-reduction attention (SRA) was proposed in [5] to speed up the computation of the pyramid vision transformer (PVT). As shown in Fig. 7, SRA reduces the spatial scale of the key (K) and value (V) matrices, i.e., their number of tokens, by a factor of $R_i^2$, where $i$ indicates the stage of the transformer model. The spatial reduction consists of two steps: 1) concatenating the neighboring tokens of dimension $C_i$ within each non-overlapping window of size $R_i \times R_i$ into a single token of dimension $R_i^2 C_i$, and 2) linearly projecting each concatenated token back to dimension $C_i$ and applying layer normalization. Both time and space complexity decrease because the number of key and value tokens is reduced by the spatial reduction.
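Here is a single-head NumPy sketch of SRA following the two steps above. It is an illustration under assumptions: random matrices stand in for learned weights, the LayerNorm has no learnable parameters, and the stage size and reduction ratio are toy values. The queries keep the full length H x W, while the keys and values are $R^2$ times shorter.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def spatial_reduction(tokens, hw, R, W_s):
    """Concatenate tokens in non-overlapping R x R windows, project back to C, normalize."""
    H, W = hw
    C = tokens.shape[-1]
    grid = tokens.reshape(H // R, R, W // R, R, C).transpose(0, 2, 1, 3, 4)
    concat = grid.reshape(-1, R * R * C)  # (HW / R^2, R^2 * C)
    return layer_norm(concat @ W_s)       # (HW / R^2, C)

def sra(tokens, hw, R, Wq, Wk, Wv, W_s):
    reduced = spatial_reduction(tokens, hw, R, W_s)
    Q = tokens @ Wq                              # queries keep the full length HW
    K, V = reduced @ Wk, reduced @ Wv            # keys/values are R^2 times shorter
    A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))  # (HW, HW / R^2) instead of (HW, HW)
    return A @ V, A

rng = np.random.default_rng(0)
H, W, C, R = 16, 16, 8, 4  # toy stage size and reduction ratio
x = rng.normal(size=(H * W, C))
Wq, Wk, Wv = rng.normal(scale=0.1, size=(3, C, C))
out, A = sra(x, (H, W), R, Wq, Wk, Wv, rng.normal(scale=0.1, size=(R * R * C, C)))
print(out.shape, A.shape)  # (256, 8) (256, 16)
```

The attention matrix shrinks from 256 x 256 to 256 x 16 here, while the output still has one token per input position.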
Figure 7: Comparison between the regular attention (left) and SRA (right) (Image from [5])
3.2.2 FAVOR+
FAVOR+, which stands for Fast Attention Via positive Orthogonal Random features, was proposed in [9] as a key module of a transformer architecture named Performer. FAVOR+ approximates regular (softmax) attention with linear time and space complexity. The "OR+" part of FAVOR+ refers to mapping the queries and keys through a random feature map built from positive, orthogonal random features, so that the dot product of a mapped query and a mapped key approximates the softmax attention kernel. The "FA" part refers to changing the order of computation in the attention module, as shown in Fig. 8. In a regular attention module, the query and key matrices are multiplied first (which requires quadratic time and space), and the result is then multiplied by the value matrix. FAVOR+, on the other hand, first multiplies the (mapped) key matrix with the value matrix and then left-multiplies the result by the (mapped) query matrix. Because matrix multiplication is associative, the full attention matrix is never formed, which results in time and space complexity linear in the sequence length, as shown in Fig. 8.
The Performer architecture, which uses FAVOR+, was explored in T2T-ViT [7] and was found to perform competitively with the original transformer while reducing the computational cost.
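The sketch below illustrates both parts of FAVOR+ under simplifying assumptions; it is not the Performer reference code. Positive features $\phi(x) = \exp(w^\top x - \lVert x \rVert^2/2)/\sqrt{m}$ give an unbiased estimate of $\exp(x^\top y)$; the random directions are orthogonalized in blocks with a QR decomposition and rescaled to Gaussian-like norms; and the product with V is taken before the product with the mapped queries. Feature redrawing and other refinements from [9] are omitted, and all sizes are toy values.

```python
import numpy as np

def orthogonal_gaussian(m, d, rng):
    """m random directions in R^d, orthogonal within blocks of d, with Gaussian-like norms."""
    blocks = []
    for _ in range(int(np.ceil(m / d))):
        q, _ = np.linalg.qr(rng.normal(size=(d, d)))  # q is an orthogonal matrix
        blocks.append(q.T)
    W = np.concatenate(blocks)[:m]
    # Rescale rows so their lengths follow the norms of i.i.d. Gaussian vectors.
    norms = np.linalg.norm(rng.normal(size=(m, d)), axis=1)
    return W * norms[:, None]

def positive_features(X, W):
    # phi(x) = exp(W x - |x|^2 / 2) / sqrt(m): positive, and E[phi(x) . phi(y)] = exp(x . y).
    m = W.shape[0]
    return np.exp(X @ W.T - 0.5 * np.sum(X ** 2, axis=-1, keepdims=True)) / np.sqrt(m)

def favor_attention(Q, K, V, W):
    scale = Q.shape[-1] ** -0.25  # so that exp(q'.k') = exp(q.k / sqrt(d)), the softmax kernel
    Qp, Kp = positive_features(Q * scale, W), positive_features(K * scale, W)  # (L, m)
    KV = Kp.T @ V                     # (m, d_v): computed first, linear in L
    normalizer = Qp @ Kp.sum(axis=0)  # (L,): row sums of the implicit attention matrix
    return (Qp @ KV) / normalizer[:, None]

def exact_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    A = np.exp(S - S.max(axis=1, keepdims=True))
    return (A / A.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
L, d, m = 6, 8, 512  # sequence length, head dimension, number of random features
Q, K = 0.1 * rng.normal(size=(L, d)), 0.1 * rng.normal(size=(L, d))
V = rng.normal(size=(L, d))
W = orthogonal_gaussian(m, d, rng)
approx, exact = favor_attention(Q, K, V, W), exact_attention(Q, K, V)
print(np.abs(approx - exact).max())  # small approximation error
```

In `favor_attention`, no L x L matrix is ever built: the cost is dominated by the (L, m) feature matrices, which grow linearly with the sequence length.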
Figure 8: The computation order in FAVOR+ (right), compared with that in a regular attention module (left) (Image from [9])
3.3 Incapability of generating multi-scale feature maps
By design, ViT, which simply uses the original transformer encoder to process image data, can only generate a feature map of a single scale. However, the importance of multi-scale feature maps has been demonstrated in several object detection and semantic segmentation approaches. Since objects of various scales (small, medium, or large) may appear in the same image in those tasks, a single-scale feature map may not be able to detect all of the objects effectively. Usually, large objects can be easily detected at a coarse scale of the image, while small objects are often detected at a finer scale.
Several papers proposed to modify the architecture of ViT to generate multi-scale feature maps and demonstrated their effectiveness in object detection and segmentation tasks [5], [8].
3.3.1 Pyramid vision transformer (PVT)
Pyramid Vision Transformer (PVT) [5] was proposed as a pure (convolution-free) transformer model that generates multi-scale feature maps for dense prediction tasks such as detection or segmentation. PVT divides the whole image into a sequence of small patches (4x4 pixels) and embeds them using a linear layer (the patch embedding module in Fig. 9). At this stage, the spatial size of the input is reduced. The embeddings are then fed into a series of transformer encoders to generate the first-level feature map. This process is then repeated on each resulting feature map to generate higher-level feature maps.
PVT has been shown to be superior to CNN backbones with a similar computational cost on classification, detection, and segmentation tasks. Compared with ViT, PVT is more memory-efficient and achieves higher performance on dense prediction tasks, which require higher-resolution images and smaller patch sizes.
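The way repeated patch embedding builds a pyramid can be sketched as follows. The sketch assumes a 4 x 4 patch size for the first stage and 2 x 2 for the later ones, uses toy channel widths and random projection weights, and leaves out the transformer encoders of each stage; it only tracks how the feature map shrinks.

```python
import numpy as np

def patch_embed(grid, P, W_e):
    """Non-overlapping P x P patch embedding on an (H, W, C) grid -> (H/P, W/P, C_out)."""
    H, W, C = grid.shape
    patches = grid.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    return patches.reshape(H // P, W // P, P * P * C) @ W_e

rng = np.random.default_rng(0)
x = rng.normal(size=(224, 224, 3))
# Hypothetical stage settings: (patch size, output channels); widths are toy values.
stages = [(4, 16), (2, 32), (2, 64), (2, 128)]
pyramid = []
for P, C_out in stages:
    W_e = rng.normal(scale=0.02, size=(P * P * x.shape[-1], C_out))
    x = patch_embed(x, P, W_e)
    # The transformer encoders of this stage would process x.reshape(-1, C_out) here.
    pyramid.append(x)
print([f.shape for f in pyramid])
# [(56, 56, 16), (28, 28, 32), (14, 14, 64), (7, 7, 128)]
```

The four outputs correspond to strides of 4, 8, 16, and 32 relative to the input, the kind of multi-scale pyramid that detection and segmentation heads expect.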
Figure 9: The overview architecture of PVT (Image from [5])
3.3.2 Swin transformer
Another approach to generating multi-scale feature maps, the Swin transformer, was proposed in [8]. As shown in Fig. 10, the Swin transformer progressively produces feature maps of lower resolution while increasing the number of channels. Note that the Swin transformer adopts a smaller patch size of 4 x 4 pixels, whereas the original ViT uses 16 x 16 pixels. The key module that changes the resolution of a feature map is the patch merging module at the beginning of each stage except stage 1. Let C denote the dimension of an embedding output by stage 1. The patch merging module simply concatenates the embeddings of each group of 2 x 2 neighboring patches, resulting in a 4C-dimensional embedding. A linear layer then reduces the dimension to 2C. The number of embeddings after patch merging is reduced by a factor of 4, the group size. The merged embeddings are then processed by a sequence of transformer blocks, named Swin transformer blocks. This process is repeated in the following stages to generate feature maps of decreasing resolution. The outputs from all stages form a pyramid of feature maps representing features at multiple scales. Note that the Swin transformer follows the architectural design of several CNNs, in which the resolution is halved on each side while the channel dimension is doubled when going deeper.
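The patch merging step can be sketched with strided slicing: the four slices pick the four members of every 2 x 2 group, which are then concatenated and projected from 4C to 2C. This is an illustration with a random projection matrix; the normalization layer used before the projection in practice is omitted, and the stage-1 size and C = 96 are example values.

```python
import numpy as np

def patch_merging(x, W_reduce):
    """x: (H, W, C) -> (H/2, W/2, 2C). Concatenate each 2 x 2 group, then project 4C -> 2C."""
    x0 = x[0::2, 0::2]  # top-left member of every 2 x 2 group
    x1 = x[1::2, 0::2]  # bottom-left
    x2 = x[0::2, 1::2]  # top-right
    x3 = x[1::2, 1::2]  # bottom-right
    merged = np.concatenate([x0, x1, x2, x3], axis=-1)  # (H/2, W/2, 4C)
    return merged @ W_reduce                            # (H/2, W/2, 2C)

rng = np.random.default_rng(0)
x = rng.normal(size=(56, 56, 96))  # stage-1 map of 56 x 56 patches with C = 96 (example values)
shapes = []
for _ in range(3):  # stages 2, 3, and 4 each start with patch merging
    C = x.shape[-1]
    x = patch_merging(x, rng.normal(scale=0.02, size=(4 * C, 2 * C)))
    # The Swin transformer blocks of this stage would process x here.
    shapes.append(x.shape)
print(shapes)  # [(28, 28, 192), (14, 14, 384), (7, 7, 768)]
```

The order of the four slices is arbitrary as long as it is used consistently, since the following linear layer learns how to combine them.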
Figure 10: The overview architecture of Swin transformer (Image from [8])
A Swin transformer block, shown in Fig. 10, consists of two transformer layers: the first with a window-based multi-head self-attention (W-MSA) module and the second with a shifted-window MSA (SW-MSA) module. Both W-MSA and SW-MSA compute self-attention locally within each non-overlapping window (i.e., a group of neighboring patches), as shown in Fig. 11. The difference between them is that in SW-MSA, the grid of windows is shifted by half the window size. On the one hand, limiting attention to within each window makes the computational complexity linear in the image size when the window size is fixed. On the other hand, it gives up a key property of attention in ViT, in which every patch can be associated with every other patch in a single attention operation. The Swin transformer addresses this problem by alternating between W-MSA and SW-MSA in consecutive layers, which allows information to propagate across window boundaries and over a larger area as the network gets deeper.
Experimental results in [8] showed that the Swin transformer outperformed ViT, DeiT, and ResNe(X)t on three main vision tasks: image classification, object detection, and semantic segmentation.
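The windowing logic can be sketched as follows. This is a simplified illustration, not the implementation of [8]: attention is single-head with no learned projections, and the attention mask that [8] applies after the cyclic shift (to stop tokens that were not neighbors before the shift from attending to each other) is omitted. The shifted variant is implemented as a cyclic shift with `np.roll`, followed by the same window partition and a reverse shift.

```python
import numpy as np

def window_partition(x, M):
    """(H, W, C) -> (num_windows, M*M, C): tokens grouped by non-overlapping M x M window."""
    H, W, C = x.shape
    w = x.reshape(H // M, M, W // M, M, C).transpose(0, 2, 1, 3, 4)
    return w.reshape(-1, M * M, C)

def windowed_self_attention(x, M, shift=0):
    """Parameter-free self-attention computed independently inside each window."""
    if shift:
        # SW-MSA: cyclically shift the map so the new windows straddle the old borders.
        x = np.roll(x, shift=(-shift, -shift), axis=(0, 1))
    win = window_partition(x, M)                                  # (nW, M*M, C)
    scores = win @ win.transpose(0, 2, 1) / np.sqrt(x.shape[-1])  # (nW, M*M, M*M)
    scores -= scores.max(axis=-1, keepdims=True)
    attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    out = attn @ win
    # Undo the partition (and the shift) to recover an (H, W, C) map.
    H, W, C = x.shape
    out = out.reshape(H // M, W // M, M, M, C).transpose(0, 2, 1, 3, 4).reshape(H, W, C)
    return np.roll(out, shift=(shift, shift), axis=(0, 1)) if shift else out

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8, 4))                     # 8 x 8 patches, 4 channels (toy values)
y_w = windowed_self_attention(x, M=4)              # W-MSA-like layer
y_sw = windowed_self_attention(y_w, M=4, shift=2)  # SW-MSA-like layer, shift = M // 2
```

Each window has an (M x M) x (M x M) attention matrix, and there are HW / M^2 windows, so the total cost grows linearly with the number of patches for a fixed M.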
Figure 11: An illustration of how W-MSA and SW-MSA compute self-attention locally (left) and the architecture of the Swin transformer block (right) (Images from [8])
3.3.3 Pooling-based vision transformer (PiT)
A design principle of CNNs, in which the spatial resolution decreases while the number of channels increases as the depth increases, has been widely used in several CNN models. The pooling-based vision transformer (PiT) proposed in [6] has shown that this design principle is also beneficial to ViT. As shown in Fig. 12, in the first stage of PiT, the input sequence of tokens is processed in the same way as in ViT. However, after each stage, the output sequence is reshaped into an image, whose spatial resolution is then reduced by a depthwise convolution layer with a stride of 2 and 2C filters, where C is the number of input channels. The output is then reshaped back into a sequence of tokens and passed to the following stage.
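The pooling step can be sketched as a reshape, a stride-2 depthwise convolution that gives each input channel two filters (so C channels become 2C), and a reshape back. The 3 x 3 kernel size, the padding, and the random weights are assumptions for this illustration, and the handling of the class token is omitted.

```python
import numpy as np

def depthwise_conv_stride2(x, kernels):
    """x: (H, W, C); kernels: (k, k, C, mult). Each input channel has its own `mult`
    filters (a depthwise/grouped convolution), so the output has C * mult channels."""
    H, W, C = x.shape
    k, _, _, mult = kernels.shape
    p = k // 2
    xp = np.pad(x, ((p, p), (p, p), (0, 0)))
    out_h, out_w = (H + 2 * p - k) // 2 + 1, (W + 2 * p - k) // 2 + 1
    out = np.zeros((out_h, out_w, C, mult))
    for i in range(out_h):
        for j in range(out_w):
            window = xp[2 * i:2 * i + k, 2 * j:2 * j + k]           # (k, k, C)
            out[i, j] = np.einsum("abc,abcm->cm", window, kernels)  # per-channel filtering
    return out.reshape(out_h, out_w, C * mult)

def pit_pooling(tokens, hw, kernels):
    grid = tokens.reshape(hw[0], hw[1], -1)         # sequence -> spatial map
    pooled = depthwise_conv_stride2(grid, kernels)  # halve resolution, double channels
    return pooled.reshape(-1, pooled.shape[-1]), pooled.shape[:2]  # map -> sequence

rng = np.random.default_rng(0)
C = 8
tokens = rng.normal(size=(14 * 14, C))              # toy 14 x 14 grid of 8-dim tokens
kernels = rng.normal(scale=0.1, size=(3, 3, C, 2))  # 3 x 3 kernel size is an assumption
new_tokens, new_hw = pit_pooling(tokens, (14, 14), kernels)
print(new_tokens.shape, new_hw)  # (49, 16) (7, 7)
```

The sequence becomes four times shorter while each token doubles in width, mirroring the resolution and channel schedule of CNNs.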
Figure 12: Comparison in the network architecture: ViT (left) and PiT (right). (Images from [6])
Experimental results showed that the proposed pooling layer significantly improved the performance of ViT on image classification and object detection tasks. Although the experiments in [6] did not explore the use of multi-scale feature maps in those tasks, the design of PiT makes it capable of constructing a multi-scale feature map, as in the other models explained in this section.
4. Summary
This page summarizes the key technical details of ViT, a pure transformer backbone that has attracted a lot of attention from researchers in the field. It also points out the key problems of ViT and introduces recent research papers that attempt to address them. Given their impressive performance, ViT and its variants can be considered promising visual backbones for several vision tasks, such as classification, object detection, and semantic segmentation, although there is still room for further improvement.
References
[1] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, et al., “Attention is all you need,” Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS 2017), 2017. (Original Transformer)
[8] Z. Liu, Y. Lin, Y. Cao, H. Hu, Y. Wei, Z. Zhang, et al., “Swin transformer: Hierarchical vision transformer using shifted windows,” arXiv Preprint, arXiv:2103.14030, 2021. (Swin Transformer)
[9] K. Choromanski, V. Likhosherstov, D. Dohan, X. Song, A. Gane, T. Sarlos, et al., “Rethinking attention with performers,” arXiv Preprint, arXiv:2009.14794, 2020. (Performer and FAVOR+)
[10] Y.-A. Wang and Y.-N. Chen, “What do position embeddings learn? An empirical study of pre-trained language model positional encoding,” arXiv Preprint, arXiv:2010.04903, 2020. (Positional Embedding study, NLP)
[11] X. Wang, R. Girshick, A. Gupta, and K. He, “Non-local neural networks,” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018. (Non-local NN)
[16] The Annotated Transformer
[17]
[18] Rasa Algorithm Whiteboard - Transformers & Attention
[19] Learning Deep Transformer Models for Machine Translation https://www.aclweb.org/anthology/P19-1176.pdf
[20] Adaptive Input Representations for Neural Language Modeling https://openreview.net/pdf?id=ByxZX20qFQ