Illustrated Vision Transformers

Introduction

Ever since the Transformer was introduced in 2017, it has been hugely successful in the field of Natural Language Processing (NLP), and almost all NLP tasks now use Transformers. The main reason for the Transformer's effectiveness is its ability to handle long-term dependencies better than RNNs and LSTMs. After its success in NLP, there have been various approaches to using it for Computer Vision tasks. The paper An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale by Dosovitskiy et al. proposes applying a standard Transformer directly to images and achieves strong results on image recognition benchmarks.

The Vision Transformer (ViT) relies on pre-training on extremely large datasets. When trained on datasets such as ImageNet (which the paper describes as mid-sized), the model's accuracy falls a few percentage points below that of comparably sized ResNets. This is because Transformers lack some of the inductive biases built into CNNs, such as translation equivariance and locality, and therefore do not generalize well when trained on insufficient data.

Overview of Vision Transformer

  • Split the image into fixed-size patches
  • Flatten each patch and linearly embed it, producing a sequence of patch embeddings as the input to the Transformer
    Here, image patches are treated the same way as tokens are in NLP tasks
  • Prepend a learnable class embedding (similar to BERT's [class] token) to the sequence and add positional embeddings to every embedding
  • Pass these (patch + positional + class) embeddings through the Transformer encoder and take the output at the class token position
  • Pass this class representation through an MLP head to get the final class predictions
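The steps above can be traced purely in terms of tensor shapes. The sketch below uses a hypothetical helper, `vit_shapes` (not from the paper or any library), and assumes the image height and width are divisible by the patch size; the 1000-class head in the usage line is an arbitrary example.

```python
def vit_shapes(height, width, channels, patch, dim, num_classes):
    """Trace tensor shapes through the ViT pipeline for a single image.

    Hypothetical helper: all arguments are illustrative hyperparameters,
    and height/width are assumed to be divisible by the patch size.
    """
    if height % patch or width % patch:
        raise ValueError("image size must be divisible by the patch size")
    num_patches = (height // patch) * (width // patch)
    patch_dim = patch * patch * channels
    return {
        "image": (height, width, channels),
        "flattened_patches": (num_patches, patch_dim),      # split + flatten
        "patch_embeddings": (num_patches, dim),             # linear projection
        "with_class_token": (num_patches + 1, dim),         # prepend [class]
        "with_position_embeddings": (num_patches + 1, dim), # element-wise add, same shape
        "encoder_output": (num_patches + 1, dim),           # encoder preserves shape
        "class_representation": (dim,),                     # output at the class position
        "logits": (num_classes,),                           # MLP head
    }

for name, shape in vit_shapes(256, 256, 3, 16, 768, 1000).items():
    print(f"{name:26s} {shape}")
```

Note that adding position embeddings and running the encoder never change the sequence shape; only the final head collapses it to one vector of class scores.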

Method

Source: Google AI Blog

The figure above depicts an overview of the Vision Transformer. Given an image, the image is split into patches. These patches are flattened, linearly projected, and passed to the Transformer encoder as a sequence of tokens. Position embeddings are added to the patch embeddings before they enter the encoder in order to retain positional information.

How is an image changed into a sequence of vectors to feed into the transformer encoder?


Let’s decode the above figure by taking an RGB image of size 256 * 256 * 3. The first step is to split the input image into patches of size 16 * 16. Since 256 / 16 = 16 patches fit along each side, we get 16 * 16 = 256 patches in total. After splitting the input image into patches, the next step is to arrange all the patches in a linear sequence: as seen in the figure, the first patch is placed on the far left and the last patch on the far right. Then, we flatten each patch (16 * 16 * 3 = 768 values) and linearly project it to get a 1 * 768 vector representation (768 is the hidden size D used by ViT-Base). These representations are known as patch embeddings. The patch embeddings have size 256 * 768, since we have 256 patches, each represented as a 1 * 768 vector.
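A minimal NumPy sketch of the patch-splitting and linear-projection steps, assuming a 256 * 256 * 3 image, 16 * 16 patches, row-by-row (raster) patch order, and D = 768. The function name `image_to_patches` is illustrative, and the projection matrix `E` is random here rather than learned.

```python
import numpy as np

def image_to_patches(image, patch):
    """Split an (H, W, C) image into non-overlapping patches, ordered row by row
    (top-left first), and flatten each one into a vector of length patch*patch*C."""
    h, w, c = image.shape
    gh, gw = h // patch, w // patch                  # patches along height / width
    x = image[: gh * patch, : gw * patch]            # drop any remainder pixels
    x = x.reshape(gh, patch, gw, patch, c)           # split both spatial axes
    x = x.transpose(0, 2, 1, 3, 4)                   # (gh, gw, patch, patch, c)
    return x.reshape(gh * gw, patch * patch * c)     # one row per patch

rng = np.random.default_rng(0)
image = rng.random((256, 256, 3)).astype(np.float32)  # stand-in for an RGB image
patches = image_to_patches(image, 16)                  # (256, 768)

# Projection E (random here, learned in a real model); D = 768 as in ViT-Base.
D = 768
E = rng.normal(0.0, 0.02, size=(patches.shape[1], D)).astype(np.float32)
patch_embeddings = patches @ E                         # (256, 768)
print(patches.shape, patch_embeddings.shape)
```

Reshaping and transposing avoids an explicit Python loop over patches. In practice, the split-and-project step is often implemented equivalently as a convolution whose kernel size and stride both equal the patch size.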

Next, we prepend a learnable class token embedding to the patch embeddings, making the size 257 * 768, and add position embeddings of the same size to retain positional information. Unlike a CNN, the Transformer has no built-in notion of the order of the patches, so after converting the image into a sequence of vectors we need to add information about the position of each patch explicitly.
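A sketch of prepending the class token and adding position embeddings, using the 256 * 768 shapes from the example above. The zero initialisation of the class token and the 0.02-scale random position embeddings are assumptions standing in for learned parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, D = 256, 768

# Stand-in for the (256, 768) patch embeddings described above.
patch_embeddings = rng.normal(size=(num_patches, D)).astype(np.float32)

# Learnable parameters (initialised arbitrarily here; trained in a real model).
class_token = np.zeros((1, D), dtype=np.float32)                       # x_class
pos_embedding = rng.normal(0.0, 0.02, size=(num_patches + 1, D)).astype(np.float32)

# Prepend the class token, then add one position embedding per position.
z0 = np.concatenate([class_token, patch_embeddings], axis=0) + pos_embedding
print(z0.shape)  # (257, 768)
```

Position embeddings are added element-wise rather than concatenated, so the sequence keeps the width D.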

Components of Vision Transformer

Since the Vision Transformer is based on the standard Transformer architecture, and the main difference is that it is applied to images rather than text, its components are almost the same. Here, we discuss the components used in the Vision Transformer along with their significance.

Side note: If you want to dive deeper into the Transformer, Jay Alammar's illustrated explanation of it is a good place to start.

Patch embeddings

As the title of the paper, “An Image is Worth 16×16 Words”, suggests, its main idea is to break an image into patches. An image $x \in \mathbb{R}^{H \times W \times C}$ is reshaped into a sequence of flattened 2D patches $x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$,
where,
(H, W) is the resolution of the original image, C is the number of channels, (P, P) is the resolution of each image patch, and $N = HW/P^2$ is the resulting number of patches, which is also the input sequence length of the Transformer.
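A quick check of the formula for a few configurations. `patch_grid` is a hypothetical helper; the 224 * 224 cases use a commonly used input resolution and show how the sequence length N and the flattened patch size $P^2 \cdot C$ change with P.

```python
def patch_grid(H, W, C, P):
    """Return (N, patch_dim) for x in R^{H x W x C} split into P x P patches:
    N = H*W / P^2 patches, each flattened to P^2 * C values."""
    if H % P or W % P:
        raise ValueError("H and W must be multiples of P")
    return (H * W) // (P * P), P * P * C

print(patch_grid(256, 256, 3, 16))  # (256, 768)
print(patch_grid(224, 224, 3, 16))  # (196, 768)
print(patch_grid(224, 224, 3, 14))  # (256, 588)
```

Smaller patches give longer sequences, and since self-attention compares every pair of tokens, this directly raises the cost of the encoder.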

Learnable embedding class

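A learnable embedding is prepended to the sequence of embedded patches ($z_0^0 = x_{class}$), in the same way as BERT's [class] token. The state of this embedding at the output of the Transformer encoder ($z_L^0$) serves as the image representation y. A classification head is attached to $z_L^0$ during both pre-training and fine-tuning; in the paper it is an MLP with one hidden layer at pre-training time and a single linear layer at fine-tuning time.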
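A sketch of reading out the class token, assuming the encoder output has already been computed (random values stand in for it here). Following the paper's formulation, the representation is y = LN($z_L^0$); the head shown is a single linear layer, as used at fine-tuning time, and the 10-class output size is an arbitrary example.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalise over the feature axis (learnable scale/shift omitted for brevity).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
D, num_classes = 768, 10                 # num_classes is an arbitrary example

z_L = rng.normal(size=(257, D))          # stand-in encoder output; class token at index 0
y = layer_norm(z_L[0])                   # image representation y = LN(z_L^0)

W_head = rng.normal(0.0, 0.02, size=(D, num_classes))
b_head = np.zeros(num_classes)
logits = y @ W_head + b_head             # linear classification head

probs = np.exp(logits - logits.max())    # numerically stable softmax
probs /= probs.sum()
print(logits.shape, probs.argmax())
```

Only row 0 of the encoder output reaches the head; the patch outputs influence the prediction solely through self-attention into the class token.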

Position Embeddings

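Position embeddings are added to the patch embeddings and the class token embedding, and the resulting sequence is fed into the Transformer encoder. The paper uses standard learnable 1D position embeddings, as it observed no significant gains from more advanced 2D-aware position embeddings.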
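To see why position embeddings are needed, here is a small NumPy sketch showing that self-attention on its own is permutation-equivariant: shuffling the input patches merely shuffles the outputs, so the model cannot tell which patch was where. The toy sizes, random weights, single attention head, and the shift permutation are all assumptions chosen for illustration.

```python
import numpy as np

def self_attention(x, Wq, Wk, Wv):
    # Single-head scaled dot-product self-attention over the rows of x.
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
n, d = 5, 8                                        # tiny toy sizes
x = rng.normal(size=(n, d))                        # toy "patch embeddings"
Wq, Wk, Wv = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
perm = np.roll(np.arange(n), 1)                    # a fixed, non-identity shuffle

# Without position embeddings, shuffling the patches just shuffles the outputs.
out = self_attention(x, Wq, Wk, Wv)
out_shuffled = self_attention(x[perm], Wq, Wk, Wv)
print(np.allclose(out_shuffled, out[perm]))        # True

# With position embeddings added after shuffling, the outputs genuinely change.
pos = rng.normal(size=(n, d))
out_pos = self_attention(x + pos, Wq, Wk, Wv)
out_pos_shuffled = self_attention(x[perm] + pos, Wq, Wk, Wv)
print(np.allclose(out_pos_shuffled, out_pos[perm]))  # False
```

The same argument extends to the rest of the encoder, because LayerNorm and the MLP act on each token independently.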

Transformer Encoder

The Transformer encoder is the standard encoder architecture presented in the original Transformer paper. It takes the embedded patches (patch embeddings plus position embeddings, with the class embedding prepended) as input. The encoder consists of alternating layers of multi-headed self-attention (MSA) and MLP blocks. Layer Normalization is applied before every block, and a residual connection is added after every block.
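A minimal NumPy sketch of one pre-norm encoder block, following the paper's equations z'_ℓ = MSA(LN(z_{ℓ-1})) + z_{ℓ-1} and z_ℓ = MLP(LN(z'_ℓ)) + z'_ℓ. Assumptions: toy sizes, random weights, LayerNorm without learnable scale and shift, no dropout, and the tanh approximation of GELU. `encoder_block` and `init_params` are illustrative names, not a library API.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)     # learnable scale/shift omitted

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def multi_head_self_attention(x, p, num_heads):
    n, d = x.shape
    hd = d // num_heads
    # Project, then split into heads: each of q, k, v is (num_heads, n, hd).
    q, k, v = [(x @ p[name]).reshape(n, num_heads, hd).transpose(1, 0, 2)
               for name in ("Wq", "Wk", "Wv")]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(hd)   # (num_heads, n, n)
    scores -= scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    out = (w @ v).transpose(1, 0, 2).reshape(n, d)    # concatenate the heads
    return out @ p["Wo"]

def encoder_block(z, p, num_heads):
    # z'_l = MSA(LN(z_{l-1})) + z_{l-1}
    z = z + multi_head_self_attention(layer_norm(z), p, num_heads)
    # z_l = MLP(LN(z'_l)) + z'_l, with a two-layer MLP and a GELU in between
    h = gelu(layer_norm(z) @ p["W1"] + p["b1"])
    return z + h @ p["W2"] + p["b2"]

def init_params(rng, d, mlp_dim, s=0.02):
    return {
        "Wq": rng.normal(0, s, (d, d)), "Wk": rng.normal(0, s, (d, d)),
        "Wv": rng.normal(0, s, (d, d)), "Wo": rng.normal(0, s, (d, d)),
        "W1": rng.normal(0, s, (d, mlp_dim)), "b1": np.zeros(mlp_dim),
        "W2": rng.normal(0, s, (mlp_dim, d)), "b2": np.zeros(d),
    }

rng = np.random.default_rng(0)
d, heads, mlp_dim, depth = 64, 4, 128, 2   # toy sizes (ViT-Base: 768, 12, 3072, 12)
z = rng.normal(size=(17, d))               # class token + 16 patch embeddings
for _ in range(depth):
    z = encoder_block(z, init_params(rng, d, mlp_dim), heads)
print(z.shape)  # (17, 64)
```

Because each sub-block only adds to its input, a block whose weights are all zero reduces to the identity, which is one reason residual connections make deep stacks easier to train.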

Using hybrid architecture

Instead of forming the input sequence from raw image patches, the input sequence can also be formed from the feature map of a CNN (Convolutional Neural Network). In this hybrid architecture, the patch embedding projection is applied to patches extracted from the CNN feature map.
From the paper:

As an alternative to raw image patches, the input sequence can be formed from feature maps of a CNN. In this hybrid model, the patch embedding projection E (Eq. 1) is applied to patches extracted from a CNN feature map. As a special case, the patches can have spatial size 1×1, which means that the input sequence is obtained by simply flattening the spatial dimensions of the feature map and projecting to the Transformer dimension.
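With 1×1 patches, the hybrid input reduces to flattening the feature map and projecting it, as the quoted passage describes. A sketch assuming a hypothetical 14 * 14 * 1024 feature map (random values standing in for a real CNN backbone output) and D = 768:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a CNN feature map of shape (H', W', C'); the sizes are hypothetical.
feature_map = rng.normal(size=(14, 14, 1024))
D = 768

# With 1x1 patches, "patch extraction" is just flattening the spatial grid...
tokens = feature_map.reshape(-1, feature_map.shape[-1])   # (196, 1024)

# ...followed by a linear projection E to the Transformer dimension D.
E = rng.normal(0.0, 0.02, size=(feature_map.shape[-1], D))
patch_embeddings = tokens @ E                              # (196, 768)
print(tokens.shape, patch_embeddings.shape)
```

From this point on, the class token, position embeddings, and encoder are used exactly as with raw image patches.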

prabinnepal