An Annotated Vision Transformer
PyTorch implementation of "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (2021)
In today's digital age, the adage "A picture is worth a thousand words" has taken on a new dimension with the advent of deep learning. The recent paper, "An Image is Worth 16x16 Words", introduces the Vision Transformer (ViT) - a paradigm-shifting approach that treats an image not as a grid of pixels, but as a sequence of tokens, similar to how natural language models treat words. This shift from the conventional convolutional perspective to a sequence-based representation challenges our established norms and offers exciting new avenues in the realm of computer vision.
In this blog post, we'll delve deep into the construction of the Vision Transformer, annotating each step with PyTorch code for a clearer, hands-on understanding. You can find a PyTorch notebook version here:
Let's start with the imports:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from torchsummary import summary
We also need a sample image to work with. I've been obsessed with my sister's Coton de Tuléar lately, so here he is:
!curl -o coton.jpg "https://www.mvtimes.com/mvt/uploads/2019/01/coton-de-tulear-chien-blanc-femelle-1496325326lJt.jpg?x76679"
img = Image.open('coton.jpg').convert('RGB')  # save under a simple name and make sure the image is 3-channel RGB
fig = plt.figure()
plt.imshow(img)
Okay, now let's dive into the paper.
The key takeaway from the paper's introduction is that, instead of processing 2D images with convolutions as a traditional Convolutional Neural Network (CNN) does, the ViT approach divides the image into fixed-size patches, flattens and linearly embeds them, and then processes this sequence of embedded patches with a standard Transformer. This "sequence of patches" is analogous to the "sequence of words" in the original Transformer model from the paper "Attention Is All You Need" (2017). Let's implement this in PyTorch. The first step is to preprocess our image.
# resize to imagenet size
transform = Compose([Resize((224,224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0) # add a batch dimension: [1, 3, 224, 224]
x.shape  # torch.Size([1, 3, 224, 224])
Our image x follows the paper's notation: height H, width W, and C channels (C = 3 for RGB color images). Note that PyTorch stores images channels-first, so after adding the batch dimension the tensor's shape is [batch, C, H, W].
The next step is to divide the image into multiple smaller patches, each of size P x P. Each of these patches is then flattened (or unrolled) into a single vector.
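Plugging our numbers into the paper's notation gives the sequence length N and the size of each flattened patch:

```latex
\begin{aligned}
\mathbf{x} \in \mathbb{R}^{H \times W \times C} &\;\longrightarrow\; \mathbf{x}_p \in \mathbb{R}^{N \times (P^2 \cdot C)}, \qquad N = \frac{HW}{P^2} \\
N &= \frac{224 \cdot 224}{16^2} = 196, \qquad P^2 \cdot C = 16^2 \cdot 3 = 768
\end{aligned}
```

So a 224x224 RGB image becomes a sequence of 196 patch vectors of length 768, i.e. a [1, 196, 768] tensor once the batch dimension is included.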
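P = 16
patches = x.reshape(1, 3, 224 // P, P, 224 // P, P)  # [batch size, channels, height in patches, height of a patch, width in patches, width of a patch]
patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(1, (224 // P) ** 2, P * P * 3)  # move channels last, flatten each patch: [batch, num patches, P*P*C]
patches.shape  # torch.Size([1, 196, 768])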
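As a sanity check, the sketch below patchifies a random stand-in image (not the dog photo) with reshape and permute, flattening each patch in (row, column, channel) order, and confirms that one entry of the resulting sequence is exactly the corresponding 16x16 crop of the image. The patch chosen (row 2, column 5) is arbitrary:

```python
import torch

P = 16
img = torch.rand(1, 3, 224, 224)  # random stand-in for the preprocessed image

# Split H and W into (num patches, P), move channels last, then flatten every patch
patches = (img.reshape(1, 3, 224 // P, P, 224 // P, P)
              .permute(0, 2, 4, 3, 5, 1)          # [1, rows, cols, P, P, 3]
              .reshape(1, (224 // P) ** 2, P * P * 3))

# Patches are numbered row by row, so index = row * 14 + col
row, col = 2, 5
idx = row * (224 // P) + col
crop = img[0, :, row * P:(row + 1) * P, col * P:(col + 1) * P]  # [3, P, P]
expected = crop.permute(1, 2, 0).reshape(-1)                    # flattened as (P, P, C)
print(torch.equal(patches[0, idx], expected))  # True
```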
Great! Let’s implement our Patch Embedding Layer next.
Now, we need to embed each flattened patch into a vector of fixed size D using a learned linear transformation; the resulting sequence of vectors (one per patch) is then processed by the subsequent Transformer layers. This is the same linear transformation an nn.Linear layer computes in PyTorch. Implementations commonly use an nn.Conv2d layer instead, with kernel size and stride both equal to the patch size: because the patches do not overlap, the convolution applies the same linear map to every patch and handles the splitting and projection in a single efficient operation. The number of output channels of the convolution sets the embedding size. For example, to embed each patch into a 768-dimensional vector (D in the Vision Transformer), you'd use a Conv2d layer with 3 input channels (for RGB images), 768 output channels, and kernel_size = stride = 16.
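The equivalence between the convolution and a linear layer can be checked directly. This sketch copies a randomly initialized Conv2d's weights into an nn.Linear and compares the two on a random image. Note that the manual patches must be flattened in (channel, height, width) order here to match the layout of Conv2d's weight tensor:

```python
import torch
from torch import nn

torch.manual_seed(0)
P, C, D = 16, 3, 768
img = torch.rand(1, C, 224, 224)  # random stand-in image

# Patch embedding as a convolution: kernel size = stride = P, D output channels
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# A linear layer reusing the conv weights: each row is one flattened (C, P, P) kernel
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)

# Conv path: [1, D, 14, 14] -> [1, 196, D]
conv_out = conv(img).flatten(2).transpose(1, 2)

# Linear path: cut the image into patches flattened in (C, P, P) order
patches = (img.reshape(1, C, 224 // P, P, 224 // P, P)
              .permute(0, 2, 4, 1, 3, 5)          # [1, rows, cols, C, P, P]
              .reshape(1, -1, C * P * P))         # [1, 196, C*P*P]
linear_out = linear(patches)

print(conv_out.shape)                                   # torch.Size([1, 196, 768])
print(torch.allclose(conv_out, linear_out, atol=1e-4))  # True (up to float32 rounding)
```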
Two more things before we begin implementing our Patch Embedding Layer.
First, a special learnable CLS token is added to the beginning of this sequence. This token does not represent any part of the image directly; instead, it aggregates global information about the image through attention, and its final state is what the classifier uses. The CLS token is a randomly initialized nn.Parameter; in the forward method, it is expanded b (batch size) times and prepended to the projected patches using torch.cat.
Second, positional embeddings are added to both the image patch embeddings and the CLS token to give the model information about the position of each patch in the image. The vanilla Transformer uses fixed sine and cosine functions to generate positional encodings, creating a smooth and continuous representation of position. ViT instead lets the model learn them: the positional embedding is simply a learnable tensor of shape (N_patches + 1, EMBED_SIZE), where the +1 accounts for the CLS token, that is added to the projected patches.
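For comparison, here is a minimal sketch of both options: the fixed sinusoidal encoding of the original Transformer (the helper name sinusoidal_encoding is ours, not a library function) and a ViT-style learned parameter. Both have one row per token and are broadcast-added across the batch:

```python
import math
import torch
from torch import nn

def sinusoidal_encoding(num_positions: int, dim: int) -> torch.Tensor:
    """Fixed sine/cosine encoding from the original Transformer (dim must be even)."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)   # [N, 1]
    # 1 / 10000^(2i / dim) for each pair of dimensions
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    pe = torch.zeros(num_positions, dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

num_tokens, emb_size = 196 + 1, 768  # 196 patches + 1 CLS token
fixed_pe = sinusoidal_encoding(num_tokens, emb_size)          # not trained
learned_pe = nn.Parameter(torch.randn(num_tokens, emb_size))  # ViT-style: trained with the model

tokens = torch.randn(2, num_tokens, emb_size)  # [batch, tokens, emb_size]
print((tokens + fixed_pe).shape)    # torch.Size([2, 197, 768])
print((tokens + learned_pe).shape)  # torch.Size([2, 197, 768])
```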
Let's implement this Patch Embedding class to embed our image patches
import torch
from torch import nn, Tensor

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        """
        Initialize the Patch Embedding layer.
        Args:
        - in_channels: Number of input channels.
        - patch_size: Size of each image patch.
        - emb_size: Embedding size for each patch.
        - img_size: Size of input image.
        """
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        # Projection layer
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        )
        # CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # Compute grid size and total patches
        grid_size = img_size // patch_size
        num_patches = grid_size * grid_size
        # Create learnable positional encoding
        self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))  # accounting for the additional CLS token

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for Patch Embedding.
        Args:
        - x: Input tensor of shape (batch_size, in_channels, img_size, img_size)
        Returns:
        - Tensor of shape (batch_size, num_patches + 1, emb_size)
        """
        b, _, _, _ = x.shape
        x = self.projection(x)
        # Reshape the tensor
        x = x.flatten(2).transpose(1, 2)
        # Expand the cls token tensor along the batch dimension
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # prepend the cls token
        x += self.positions  # add position embedding
        return x

PatchEmbedding()(x).shape  # torch.Size([1, 197, 768])
Now we need to implement the Transformer. ViT uses only the Transformer's encoder, whose architecture is illustrated in Figure 1 of the paper.
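For reference, the paper condenses the whole model into four equations, where E is the patch projection, E_pos the positional embedding, LN layer normalization, MSA multi-head self-attention, L the number of encoder blocks, and z_L^0 the CLS token's output from the last block:

```latex
\begin{aligned}
\mathbf{z}_0 &= [\mathbf{x}_\text{class};\, \mathbf{x}_p^1\mathbf{E};\, \mathbf{x}_p^2\mathbf{E};\, \cdots;\, \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}, && \mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D},\ \mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D} \\
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, && \ell = 1, \ldots, L \\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, && \ell = 1, \ldots, L \\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0)
\end{aligned}
```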
Let's break down the components
Norm: This refers to layer normalization (LayerNorm), which ViT applies before each sublayer (pre-norm).
Multi-Head Attention: This is the multi-head self-attention mechanism.
MLP: This is a feed-forward neural network consisting of two linear layers with a GELU (or sometimes ReLU) activation in between.
Residuals: The transformer block has residual connections to address the vanishing gradient problem.
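To make the attention step concrete, here is a minimal from-scratch sketch of multi-head self-attention (manual_mha is a hypothetical helper, not a PyTorch API). It reuses the weights of PyTorch's nn.MultiheadAttention, which packs the query, key, and value projections into a single in_proj_weight, so the two results can be compared:

```python
import torch
import torch.nn.functional as F
from torch import nn

def manual_mha(x: torch.Tensor, attn: nn.MultiheadAttention) -> torch.Tensor:
    """Self-attention computed by hand with the weights of a batch_first nn.MultiheadAttention."""
    B, N, E = x.shape
    h = attn.num_heads
    d = E // h                                                   # per-head dimension
    qkv = F.linear(x, attn.in_proj_weight, attn.in_proj_bias)   # [B, N, 3E], packed as q | k | v
    q, k, v = qkv.chunk(3, dim=-1)
    # Split the embedding into h heads: [B, N, E] -> [B, h, N, d]
    q, k, v = (t.reshape(B, N, h, d).transpose(1, 2) for t in (q, k, v))
    weights = (q @ k.transpose(-2, -1) / d ** 0.5).softmax(dim=-1)  # [B, h, N, N]
    out = (weights @ v).transpose(1, 2).reshape(B, N, E)        # concatenate the heads
    return attn.out_proj(out)                                    # final output projection

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True).eval()
x = torch.randn(2, 197, 768)  # [batch, 196 patches + CLS, emb_size]
reference, _ = attn(x, x, x, need_weights=False)
print(torch.allclose(manual_mha(x, attn), reference, atol=1e-4))  # True
```

Each of the 12 heads attends over all 197 tokens with its own 64-dimensional queries, keys, and values (768 / 12), and the heads are concatenated before the output projection.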
First, let's create a ResidualConnection module to keep things neat and organized. Instead of mapping input directly to output, a residual block includes a skip connection that adds the block's input to its output, which helps gradients flow through deep networks and mitigates the vanishing gradient problem.
This is a good time to mention dropout, a regularization technique used in deep learning to improve the generalization and robustness of neural networks. During training, it randomly deactivates (i.e., "drops out") a subset of neurons in each iteration to prevent overfitting; at inference time (after calling model.eval()), it is disabled.
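A quick demonstration of this behavior with nn.Dropout (p = 0.5 is chosen only to make the effect easy to see). In training mode, PyTorch also rescales the surviving values by 1 / (1 - p), so the expected activation stays the same:

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

dropout.train()    # training mode: zero random elements, scale survivors by 1 / (1 - p) = 2
print(dropout(x))  # a mix of 0.0 and 2.0 values

dropout.eval()     # evaluation mode: dropout is the identity
print(dropout(x))  # all ones
```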
class ResidualConnection(nn.Module):
    def __init__(self, d_model: int, dropout: float):
        """
        Initialize the Residual Connection layer
        Args:
        - d_model: Dimensions of the model
        - dropout: Dropout probability
        """
        super(ResidualConnection, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, sublayer):
        """
        Args:
        - x: Input tensor of shape (batch_size, seq_length, d_model)
        - sublayer: Sub-layer function (e.g., self-attention or feed-forward)
        Returns:
        - Tensor of shape (batch_size, seq_length, d_model)
        """
        # Pre-norm: normalize the input, then pass it through the sublayer
        output = sublayer(self.norm(x))
        # Apply dropout to the sublayer's output and add the original input (residual connection)
        return x + self.dropout(output)
class TransformerEncoderBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 nhead: int,
                 num_hidden_dim: int,
                 dropout: float = 0.1) -> None:
        """
        Initialize the Transformer Encoder Block.
        Args:
        - d_model: Dimension of the model.
        - nhead: Number of attention heads in the multi-head self-attention.
        - num_hidden_dim: Hidden dimension in the feed forward network.
        - dropout: Dropout probability.
        """
        super(TransformerEncoderBlock, self).__init__()
        # batch_first=True: inputs are (batch_size, seq_length, d_model), not (seq_length, batch_size, d_model)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, num_hidden_dim),
            nn.GELU(),
            nn.Linear(num_hidden_dim, d_model)
        )
        self.residual1 = ResidualConnection(d_model, dropout)
        self.residual2 = ResidualConnection(d_model, dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for the Transformer Encoder Block.
        Args:
        - x: Input tensor of shape (batch_size, seq_length, d_model).
        Returns:
        - Tensor of shape (batch_size, seq_length, d_model).
        """
        # Multi-Head Self Attention with residual connection:
        # queries, keys, and values all come from the same normalized input h
        x = self.residual1(x, lambda h: self.self_attn(h, h, h, need_weights=False)[0])
        # Feed-forward neural network with residual connection (also applied to the normalized input)
        x = self.residual2(x, self.feed_forward)
        return x
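As an alternative worth knowing about, recent PyTorch versions (1.10 and later, an assumption about your environment) ship nn.TransformerEncoderLayer, which can be configured as a pre-norm block with GELU and batch-first inputs. This sketch stacks 12 of them with ViT-Base sizes; it is not guaranteed to match the hand-written block detail for detail, and PyTorch may print a harmless warning about nested tensors when norm_first=True:

```python
import torch
from torch import nn

# A pre-norm encoder block with ViT-Base sizes, using PyTorch's built-in layer
layer = nn.TransformerEncoderLayer(
    d_model=768,
    nhead=12,
    dim_feedforward=3072,
    dropout=0.1,
    activation="gelu",  # GELU in the MLP, as in ViT
    batch_first=True,   # inputs are [batch, seq_len, d_model]
    norm_first=True,    # LayerNorm before each sublayer (pre-norm), as in ViT
)
encoder = nn.TransformerEncoder(layer, num_layers=12)  # 12 independent copies of the layer

tokens = torch.randn(2, 197, 768)  # [batch, 196 patches + CLS, emb_size]
out = encoder(tokens)
print(out.shape)  # torch.Size([2, 197, 768])
```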
The last layer is a classification head (a LayerNorm followed by a linear layer) that outputs one logit per class; a softmax turns these logits into class probabilities.
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes))
And now, putting it all together: we can compose PatchEmbedding, a stack of TransformerEncoderBlock layers, and ClassificationHead to create the final ViT architecture.
class ViT(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 nhead: int = 12,  # ViT-Base uses 12 attention heads
                 num_hidden_dim: int = 3072,
                 **kwargs):
        super(ViT, self).__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        self.encoder_blocks = nn.ModuleList([TransformerEncoderBlock(d_model=emb_size, nhead=nhead, num_hidden_dim=num_hidden_dim, **kwargs) for _ in range(depth)])
        self.classification_head = ClassificationHead(emb_size, n_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        for block in self.encoder_blocks:
            x = block(x)
        # Classify from the CLS token (position 0) only
        x = self.classification_head(x[:, 0])
        return x
Now, let's put it together!
# start with an image like before
transform = Compose([Resize((224,224)), ToTensor()])
x = transform(img)
x = x.unsqueeze(0) # add a batch dimension: [1, 3, 224, 224]
x.shape  # torch.Size([1, 3, 224, 224])
vit_model = ViT()
output = vit_model(x)
output.shape  # torch.Size([1, 1000]): one logit per ImageNet class
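The model returns raw logits rather than probabilities. This sketch uses random logits of the same [1, 1000] shape as a stand-in (an untrained ViT's predictions are meaningless anyway), applies a softmax, and picks the five most likely class indices:

```python
import torch

# Stand-in for the ViT output: one row of 1000 class logits (random, not real predictions)
torch.manual_seed(0)
logits = torch.randn(1, 1000)

probs = logits.softmax(dim=-1)  # probabilities over the 1000 classes
top5 = probs.topk(5, dim=-1)    # values and indices of the five most likely classes
print(probs.sum().item())       # ~1.0
print(top5.indices.shape)       # torch.Size([1, 5])
```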
Final Thoughts
In conclusion, Vision Transformers (ViT) have demonstrated a remarkable departure from the traditional Convolutional Neural Networks (CNNs) by embracing less inductive bias and relying on self-attention mechanisms. These architectural choices have led to compelling experimental results, as presented in the groundbreaking paper 'An Image is Worth 16x16 Words.'
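The experiments in this seminal work showed that, when pre-trained on large datasets (ImageNet-21k or JFT-300M) and transferred to image classification benchmarks such as ImageNet, CIFAR-100, and VTAB, ViT matches or surpasses state-of-the-art CNNs while requiring substantially fewer computational resources to pre-train. The paper evaluates image classification only; applying ViT to other tasks such as object detection and segmentation is named as a challenge for future work.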
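Scale is central to this success. Trained on a mid-sized dataset such as ImageNet alone, ViT falls a few percentage points short of ResNets of comparable size, since it lacks their built-in image-specific inductive biases such as locality and translation equivariance; with enough pre-training data, this gap disappears, and performance keeps improving with model size without apparent saturation in the range the authors tried. ViT can also be fine-tuned at a higher resolution than it was pre-trained at: the patch size stays fixed, which yields a longer sequence, and the pre-trained position embeddings are 2D-interpolated to match.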
The pivotal 'patch-based' approach introduced by ViT, where the input image is divided into fixed-size patches and processed through the self-attention mechanism, has proven to be highly effective. It allows ViT to capture long-range dependencies within the image and maintain spatial information without the need for convolutional operations.
These findings established Transformers as a serious alternative to CNNs in computer vision and laid a foundation for a large body of follow-up research in deep learning.












