Implementing a Vision Transformer Classifier in PyTorch

Nathan Bailey
11 min read · Aug 28, 2024


Following on from my previous blogs on implementing transformers and GPT models, I decided to look at an extension of regular transformers: the vision transformer. As the name suggests, this type of transformer takes images as input instead of a sequence of words. This blog gives an overview of the vision transformer architecture and implements a vision transformer-based classifier on the CIFAR100 dataset.

The full code for this project can be found on my GitHub.

Vision Transformer Architecture

As the name suggests, vision transformers are transformers applied to image data. They are typically encoder-based, meaning that no self-attention mask is needed: each query vector can use every key vector to generate the attention weights. Unlike a sequence of text, an image is not naturally suited to being input into a transformer, so the main consideration is how to convert an input image into tokens. We could use each pixel of the image as a token, but the memory required by self-attention grows quadratically with the number of input tokens, which quickly becomes infeasible as the spatial size of the image increases. [1]
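To put rough numbers on this, the sketch below counts the entries in a single attention-weight matrix (one per head, per image) for the 72x72 images used later in this post, first with one token per pixel and then with one token per 6x6 patch plus the class token. It only counts matrix entries and ignores every other source of memory.

```python
IMAGE_SIZE = 72
PATCH_SIZE = 6


def attention_entries(num_tokens: int) -> int:
    """Entries in one (num_tokens x num_tokens) attention-weight matrix."""
    return num_tokens * num_tokens


# One token per pixel: 72 * 72 = 5,184 tokens.
pixel_tokens = IMAGE_SIZE * IMAGE_SIZE
# One token per 6x6 patch (144 patches) plus the class token.
patch_tokens = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1

print(f"pixel tokens: {pixel_tokens}, entries: {attention_entries(pixel_tokens):,}")
print(f"patch tokens: {patch_tokens}, entries: {attention_entries(patch_tokens):,}")
```

Even at this modest resolution, patch tokens shrink each attention matrix by a factor of more than 1,200.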

Instead, we split the input image into non-overlapping patches of size P×P×C. These patches are flattened and stacked into a matrix of size N×(P²·C), where N = (H·W)/P² is the number of patches for an image of size H×W. [1]
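As a minimal NumPy sketch of this split (the function name image_to_patches is illustrative, not part of the post's code), a channels-last image can be reshaped directly into the N×(P²·C) matrix; for a 72×72×3 image and P = 6 this gives N = 144 patches of 6·6·3 = 108 values each.

```python
import numpy as np


def image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into an (N, P*P*C) matrix of flattened patches.

    Assumes H and W are divisible by patch_size. Patches are returned in
    row-major order (left to right, top to bottom).
    """
    h, w, c = image.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be a multiple of the patch size"
    # (H, W, C) -> (H/P, P, W/P, P, C): split each spatial axis into (block, offset).
    blocks = image.reshape(h // p, p, w // p, p, c)
    # Bring the two block indices to the front: (H/P, W/P, P, P, C).
    blocks = blocks.transpose(0, 2, 1, 3, 4)
    # Flatten each P x P x C patch into one row: (N, P*P*C) with N = H*W / P^2.
    return blocks.reshape(-1, p * p * c)


image = np.random.rand(72, 72, 3)
patches = image_to_patches(image, patch_size=6)
print(patches.shape)  # (144, 108)
```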

Each patch is then linearly projected and fed into the transformer as a token, with a learned positional encoding added to each token. Since the images in a dataset typically share a fixed size, a vision transformer receives a fixed number of tokens, in contrast to text-based transformers. [1]

The general architecture of the vision transformer is shown below.

Vision Transformer Architecture [1]

One interesting aspect is the class token: a randomly initialised, learnable parameter that is added to the input sequence alongside the patch tokens. As it passes through the transformer layers of the network, the class token accumulates information from the other tokens in the sequence. It does this via the attention mechanism: acting as a query, it aggregates information from all the patches through the attention weights [2].

Classification is then performed by feeding the final layer's class token into a linear layer (in this post, a small MLP followed by a linear layer). If we instead concatenated all tokens from the final layer and fed them into a classification head, the head's input size, and therefore its parameter count, would grow with the number of tokens, which would be highly inefficient and quickly infeasible. Using a dedicated class token, rather than one of the patch tokens, encourages the vision transformer to learn a general representation of the entire sequence in that token and avoids biasing the final output towards a single token in the sequence. [3] [4] [5]
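A quick count with the dimensions used later in this post (64-dimensional tokens, 144 patches plus the class token, and a first head layer of 2048 units) shows the scale of the difference. Only the first layer of the head is counted; the 133,120 figure for the class-token version matches the first Linear layer of the head in the model summary further down.

```python
EMBED_DIM = 64           # PROJECTION_DIM used later in this post
NUM_TOKENS = 144 + 1     # 144 patches plus the class token
FIRST_HEAD_UNITS = 2048  # first entry of MLP_HEAD_UNITS


def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


# Head input = only the final class token (64 values).
class_token_head = linear_params(EMBED_DIM, FIRST_HEAD_UNITS)
# Head input = every final-layer token concatenated (145 * 64 = 9,280 values).
concat_head = linear_params(NUM_TOKENS * EMBED_DIM, FIRST_HEAD_UNITS)

print(f"class token head:  {class_token_head:,} parameters")
print(f"concatenated head: {concat_head:,} parameters")
```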

We can form a vision transformer-based network in a similar way to the GPT model, by stacking vision transformer layers together and adding a classification head at the end. As the vision transformer acts as an encoder, we do not need to worry about any attention masks in our model.
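To make the encoder/decoder distinction concrete, the sketch below runs torch.nn.MultiheadAttention on random tokens (145 tokens of dimension 64 with 4 heads, matching the sizes used later) once without a mask, as in a vision transformer, and once with a causal mask of the kind a GPT-style decoder uses. Without the mask every attention weight is positive, so each token draws on every other token; with it, all weights above the diagonal are zero.

```python
import torch

torch.manual_seed(0)
batch_size, num_tokens, embed_dim, num_heads = 2, 145, 64, 4
tokens = torch.randn(batch_size, num_tokens, embed_dim)
attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

with torch.no_grad():
    # Encoder-style (ViT): no mask, so every query can attend to every key.
    _, enc_weights = attn(tokens, tokens, tokens, need_weights=True)
    # Decoder-style (GPT): True entries above the diagonal block future keys.
    causal_mask = torch.triu(
        torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1
    )
    _, dec_weights = attn(
        tokens, tokens, tokens, attn_mask=causal_mask, need_weights=True
    )

# Weights are averaged over heads: shape (batch, queries, keys).
print(enc_weights.shape)
print(bool((enc_weights > 0).all()))
print(bool((dec_weights.triu(diagonal=1) == 0).all()))
```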

Vision Transformer Implementation

We can implement a simple vision transformer-based model in PyTorch to classify the images from the CIFAR100 dataset. This has been adapted from the excellent Keras guide on vision transformers. [6]

First, let us set up our initial global variables and load the dataset. We will resize the images to 72x72 and use patches of size 6x6. This means that we will have (72*72)/(6²) = 144 patches, each of which becomes a token.

import torch
import torchvision

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001
BATCH_SIZE = 32
NUM_EPOCHS = 100
IMAGE_SIZE = 72
PATCH_SIZE = 6
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # (72 // 6) ** 2 = 144 tokens per image
PROJECTION_DIM = 64  # embedding dimension of each token
NUM_HEADS = 4
TRANSFORMER_LAYERS = 8
MLP_HEAD_UNITS = [2048, 1024]  # hidden sizes of the classification head

# `mean` and `std` are per-channel normalisation statistics; they are not
# defined in this snippet.
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        # Light augmentation, applied to the training set only.
        torchvision.transforms.RandomRotation(degrees=7),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std),
    ]
)

test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std),
    ]
)
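Both transform pipelines normalise with mean and std. A common choice is the per-channel mean and standard deviation of the CIFAR-100 training images; a minimal NumPy sketch for computing them is below. The helper name channel_mean_std is hypothetical, not part of the post's code, and it works on the original 32x32 images, before any resizing.

```python
import numpy as np


def channel_mean_std(
    images: np.ndarray,
) -> tuple[tuple[float, ...], tuple[float, ...]]:
    """Per-channel mean and std of uint8 images shaped (N, H, W, C), in [0, 1] units."""
    scaled = images.astype(np.float64) / 255.0  # ToTensor() also maps uint8 to [0, 1]
    mean = scaled.mean(axis=(0, 1, 2))
    std = scaled.std(axis=(0, 1, 2))
    return tuple(mean.tolist()), tuple(std.tolist())


# Usage with the real data (assumes the dataset can be downloaded):
#   raw = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
#   mean, std = channel_mean_std(raw.data)  # raw.data is a uint8 (50000, 32, 32, 3) array
```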

train_dataset = torchvision.datasets.CIFAR100(
    root="./data", train=True, download=True, transform=train_transforms
)

valid_dataset = torchvision.datasets.CIFAR100(
    root="./data", train=False, download=True, transform=test_transforms
)

# Split the official CIFAR-100 test set 70/30 into validation and test sets.
valid_set, test_set = torch.utils.data.random_split(
    valid_dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(42)
)

trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
    drop_last=True,
)
validloader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
    num_workers=4,
    drop_last=True,
)
testloader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
    num_workers=4,
    drop_last=True,
)
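It is worth checking what the 70/30 split and drop_last=True do to the evaluation sets. The sketch below uses a stand-in dataset of 10,000 indices, the size of the CIFAR-100 test set, so it runs without downloading anything; fractional lengths in random_split need a reasonably recent PyTorch (1.13 or later).

```python
import torch

BATCH_SIZE = 32
# Stand-in for the 10,000-image CIFAR-100 test set (indices only, no images).
stand_in = torch.utils.data.TensorDataset(torch.arange(10_000))
valid_set, test_set = torch.utils.data.random_split(
    stand_in, [0.7, 0.3], generator=torch.Generator().manual_seed(42)
)
validloader = torch.utils.data.DataLoader(
    valid_set, batch_size=BATCH_SIZE, drop_last=True
)
testloader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, drop_last=True
)

for name, subset, loader in [
    ("valid", valid_set, validloader),
    ("test", test_set, testloader),
]:
    used = len(loader) * BATCH_SIZE
    print(f"{name}: {len(subset)} images, {len(loader)} batches, {len(subset) - used} dropped")
```

With drop_last=True, 24 images in each of the validation and test splits are never evaluated; setting it to False would score every image, provided the model accepts a smaller final batch.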

Model Building Blocks

We will first set up our patch creation layer. This uses PyTorch's Unfold layer, which extracts sliding blocks from the spatial dimensions of the image; with the stride set equal to the patch size, these blocks are non-overlapping patches. Unfold returns a tensor of shape (Batch Size, (P²)*C, Number of Patches), so we permute it into the form (Batch Size, Number of Patches, (P²)*C), where P is our patch size and C is the number of channels in the image. [7]

class CreatePatchesLayer(torch.nn.Module):
    """Custom PyTorch Layer to Extract Patches from Images."""

    def __init__(
        self,
        patch_size: int,
        strides: int,
    ) -> None:
        """Init Variables."""
        super().__init__()
        self.unfold_layer = torch.nn.Unfold(
            kernel_size=patch_size, stride=strides
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Forward Pass to Create Patches."""
        patched_images = self.unfold_layer(images)
        return patched_images.permute((0, 2, 1))
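The ordering inside each flattened patch matters later, when patches are reshaped back into images for plotting. The sketch below checks, on random tensors, that Unfold followed by the permute matches a plain reshape-and-permute, which shows that each flattened patch is laid out as (channel, row, column) and that patches run row by row across the image.

```python
import torch

B, C, H, W, P = 2, 3, 72, 72, 6
images = torch.randn(B, C, H, W)

# Unfold returns (B, C * P * P, L), where L is the number of patches.
unfolded = torch.nn.Unfold(kernel_size=P, stride=P)(images)
print(unfolded.shape)  # torch.Size([2, 108, 144])

# Permuting gives (B, L, C * P * P), as in CreatePatchesLayer.
patches = unfolded.permute(0, 2, 1)

# The same patches via reshape: split H and W into (block, offset), move the
# block indices forward, then flatten each (channel, row, column) patch.
manual = (
    images.reshape(B, C, H // P, P, W // P, P)
    .permute(0, 2, 4, 1, 3, 5)
    .reshape(B, (H // P) * (W // P), C * P * P)
)
print(torch.equal(patches, manual))  # True
```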

We can test this layer in the following code. As can be seen, given an image, our patch layer splits it into 144 individual patches.

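import matplotlib.pyplot as plt

# Take the first image of a training batch, keeping a batch dimension of 1.
batch_of_images = next(iter(trainloader))[0][0].unsqueeze(dim=0)

# Plot the original image (CHW -> HWC for matplotlib). The image is normalised,
# so matplotlib clips values outside [0, 1] when displaying it.
plt.figure(figsize=(4, 4))
image = torch.permute(batch_of_images[0], (1, 2, 0)).numpy()
plt.imshow(image)
plt.axis("off")
plt.savefig("img.png", bbox_inches="tight", pad_inches=0)
plt.clf()

patch_layer = CreatePatchesLayer(patch_size=PATCH_SIZE, strides=PATCH_SIZE)
patched_image = patch_layer(batch_of_images)
patched_image = patched_image.squeeze()  # (NUM_PATCHES, 3 * PATCH_SIZE**2)

# The patches form a square grid with IMAGE_SIZE // PATCH_SIZE patches per side
# (12 x 12 here), so the subplot grid is 12 x 12, not NUM_PATCHES x NUM_PATCHES.
patches_per_side = IMAGE_SIZE // PATCH_SIZE
plt.figure(figsize=(4, 4))
for idx, patch in enumerate(patched_image):
    ax = plt.subplot(patches_per_side, patches_per_side, idx + 1)
    # Each flattened patch is ordered (channel, row, column), so reshape to CHW.
    patch_img = torch.reshape(patch, (3, PATCH_SIZE, PATCH_SIZE))
    patch_img = torch.permute(patch_img, (1, 2, 0))
    plt.imshow(patch_img.numpy())
    plt.axis("off")
plt.savefig("patched_img.png", bbox_inches="tight", pad_inches=0)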
Image vs Patched Image
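Because the stride equals the patch size, the patches do not overlap and together cover every pixel exactly once, so patching loses no information. One way to confirm this is to reassemble the patches with torch.nn.Fold, which undoes Unfold when the patches do not overlap:

```python
import torch

P, H, W = 6, 72, 72
images = torch.randn(1, 3, H, W)
unfold = torch.nn.Unfold(kernel_size=P, stride=P)
fold = torch.nn.Fold(output_size=(H, W), kernel_size=P, stride=P)

patches = unfold(images)       # (1, 108, 144)
reconstructed = fold(patches)  # (1, 3, 72, 72)
print(torch.equal(reconstructed, images))  # True
```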

We can then create a patch embedding layer. This layer linearly projects each flattened patch to the embedding dimension, concatenates the randomly initialised class token onto the front of the projected patches, and adds learned positional embeddings, produced by PyTorch's Embedding layer, to every token.

class PatchEmbeddingLayer(torch.nn.Module):
    """Positional Embedding Layer for Images of Patches."""

    def __init__(
        self,
        num_patches: int,
        batch_size: int,
        patch_size: int,
        embed_dim: int,
        device: torch.device,
    ) -> None:
        """Init Function.

        batch_size and device are kept in the signature for compatibility;
        the class token does not depend on them.
        """
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size
        # One learned positional embedding per token: class token + patches.
        self.position_emb = torch.nn.Embedding(
            num_embeddings=num_patches + 1, embedding_dim=embed_dim
        )
        # Projects each flattened patch (P * P * 3 values) to embed_dim.
        self.projection_layer = torch.nn.Linear(
            patch_size * patch_size * 3, embed_dim
        )
        # A single class token shared by every image; forward() expands it to
        # the current batch size, so any batch size works.
        self.class_parameter = torch.nn.Parameter(
            torch.rand(1, 1, embed_dim),
            requires_grad=True,
        )
        self.device = device

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """Forward Pass."""
        # Positions 0..num_patches, created on the same device as the input.
        positions = torch.arange(
            self.num_patches + 1, device=patches.device
        ).unsqueeze(dim=0)
        patches = self.projection_layer(patches)
        class_token = self.class_parameter.expand(patches.shape[0], -1, -1)
        encoded_patches = torch.cat(
            (class_token, patches), dim=1
        ) + self.position_emb(positions)
        return encoded_patches
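The embedding step can be traced tensor by tensor. In this standalone sketch the class token is a single (1, 1, embed_dim) parameter expanded to the batch size, so every image starts from the same learned token and any batch size works; a batch of 5 is used to make that point. The other shapes follow the settings in this post (144 patches of 108 values, 64-dimensional embeddings).

```python
import torch

batch_size, num_patches, patch_dim, embed_dim = 5, 144, 108, 64
patches = torch.randn(batch_size, num_patches, patch_dim)

projection = torch.nn.Linear(patch_dim, embed_dim)
position_emb = torch.nn.Embedding(num_patches + 1, embed_dim)
# One learned class token shared by all images (shape 1 x 1 x embed_dim).
class_token = torch.nn.Parameter(torch.rand(1, 1, embed_dim))

tokens = projection(patches)                            # (5, 144, 64)
cls = class_token.expand(batch_size, -1, -1)            # (5, 1, 64), a view, no copy
tokens = torch.cat((cls, tokens), dim=1)                # (5, 145, 64)
positions = torch.arange(num_patches + 1).unsqueeze(0)  # (1, 145)
encoded = tokens + position_emb(positions)              # broadcast over the batch
print(encoded.shape)  # torch.Size([5, 145, 64])
```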

Finally, we need a transformer layer. More information on this can be found in the following two articles I wrote:

  1. Transformers Explained
  2. Implementing a GPT Model in PyTorch

Simply put, this layer consists of layer normalisation, a multi-head attention block and a final feed-forward network, with residual connections around the attention and feed-forward sub-layers.

class TransformerBlock(torch.nn.Module):
    """Transformer Block Layer."""

    def __init__(
        self,
        num_heads: int,
        key_dim: int,
        embed_dim: int,
        ff_dim: int,
        dropout_rate: float = 0.1,
    ) -> None:
        """Init variables and layers."""
        super().__init__()
        self.layer_norm_input = torch.nn.LayerNorm(
            normalized_shape=embed_dim, eps=1e-6
        )
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            kdim=key_dim,
            vdim=key_dim,
            batch_first=True,
        )

        self.dropout_1 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm_1 = torch.nn.LayerNorm(
            normalized_shape=embed_dim, eps=1e-6
        )
        self.layer_norm_2 = torch.nn.LayerNorm(
            normalized_shape=embed_dim, eps=1e-6
        )
        # Two-layer feed-forward network: embed_dim -> ff_dim -> embed_dim.
        self.ffn = create_mlp_block(
            input_features=embed_dim,
            output_features=[ff_dim, embed_dim],
            activation_function=torch.nn.GELU,
            dropout_rate=dropout_rate,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward Pass."""
        # Self-attention: queries, keys and values all come from the same
        # normalised tokens, and no attention mask is applied.
        layer_norm_inputs = self.layer_norm_input(inputs)
        attention_output, _ = self.attn(
            query=layer_norm_inputs,
            key=layer_norm_inputs,
            value=layer_norm_inputs,
        )
        attention_output = self.dropout_1(attention_output)
        # Residual connection around attention, then normalise.
        out1 = self.layer_norm_1(inputs + attention_output)
        # Residual connection around the feed-forward network, then normalise.
        ffn_output = self.ffn(out1)
        output = self.layer_norm_2(out1 + ffn_output)
        return output
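TransformerBlock, and later the classification head, call a helper named create_mlp_block whose definition is not shown in this post. The model summary further down suggests its structure (each entry in output_features becomes Linear, then GELU, then Dropout), so the following is a hypothetical reconstruction consistent with that summary rather than the author's exact code.

```python
import torch


def create_mlp_block(
    input_features: int,
    output_features: list[int],
    activation_function: type[torch.nn.Module],
    dropout_rate: float,
) -> torch.nn.Sequential:
    """Hypothetical reconstruction: Linear -> activation -> Dropout per entry."""
    layers: list[torch.nn.Module] = []
    in_features = input_features
    for out_features in output_features:
        layers.append(torch.nn.Linear(in_features, out_features))
        layers.append(activation_function())
        layers.append(torch.nn.Dropout(p=dropout_rate))
        in_features = out_features
    return torch.nn.Sequential(*layers)


# The feed-forward network of one TransformerBlock (64 -> 128 -> 64).
ffn = create_mlp_block(64, [128, 64], torch.nn.GELU, dropout_rate=0.1)
linear_sizes = [
    sum(p.numel() for p in m.parameters())
    for m in ffn
    if isinstance(m, torch.nn.Linear)
]
print(linear_sizes)  # [8320, 8256], matching Linear-7 and Linear-10 in the summary
```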

Network Implementation

We then stack these layers together to create our network. The network consists of the patch layer and the patch embedding layer, followed by a stack of transformer blocks. We then take the final embedding of the class token and pass it through a series of linear layers to produce the logits required for classification. The resulting network and its layer summary (for a batch size of 32) are shown below.

class ViTClassifierModel(torch.nn.Module):
    """ViT Model for Image Classification."""

    def __init__(
        self,
        num_transformer_layers: int,
        embed_dim: int,
        feed_forward_dim: int,
        num_heads: int,
        patch_size: int,
        num_patches: int,
        mlp_head_units: list[int],
        num_classes: int,
        batch_size: int,
        device: torch.device,
    ) -> None:
        """Init Function."""
        super().__init__()
        self.create_patch_layer = CreatePatchesLayer(patch_size, patch_size)
        self.patch_embedding_layer = PatchEmbeddingLayer(
            num_patches, batch_size, patch_size, embed_dim, device
        )
        self.transformer_layers = torch.nn.ModuleList()
        for _ in range(num_transformer_layers):
            self.transformer_layers.append(
                TransformerBlock(
                    num_heads, embed_dim, embed_dim, feed_forward_dim
                )
            )

        # Classification head applied to the final class token.
        self.mlp_block = create_mlp_block(
            input_features=embed_dim,
            output_features=mlp_head_units,
            activation_function=torch.nn.GELU,
            dropout_rate=0.5,
        )

        self.logits_layer = torch.nn.Linear(mlp_head_units[-1], num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward Pass."""
        x = self.create_patch_layer(x)  # (B, num_patches, P * P * C)
        x = self.patch_embedding_layer(x)  # (B, num_patches + 1, embed_dim)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x)
        x = x[:, 0]  # keep only the class token: (B, embed_dim)
        x = self.mlp_block(x)
        x = self.logits_layer(x)  # (B, num_classes)
        return x
----------------------------------------------------------------------------------------------------------------
Parent Layers Layer (type) Output Shape Param # Tr. Param #
================================================================================================================
ViTClassifierModel/CreatePatchesLayer Unfold-1 [32, 108, 144] 0 0
ViTClassifierModel/PatchEmbeddingLayer Linear-2 [32, 144, 64] 6,976 6,976
ViTClassifierModel/PatchEmbeddingLayer Embedding-3 [1, 145, 64] 9,280 9,280
ViTClassifierModel/TransformerBlock LayerNorm-4 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-5 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-6 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-7 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-8 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-9 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-10 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-11 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-12 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-13 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-14 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-15 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-16 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-17 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-18 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-19 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-20 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-21 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-22 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-23 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-24 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-25 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-26 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-27 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-28 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-29 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-30 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-31 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-32 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-33 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-34 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-35 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-36 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-37 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-38 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-39 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-40 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-41 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-42 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-43 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-44 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-45 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-46 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-47 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-48 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-49 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-50 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-51 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-52 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-53 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-54 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-55 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-56 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-57 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-58 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-59 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-60 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-61 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-62 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-63 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-64 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-65 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-66 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-67 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-68 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-69 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-70 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-71 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-72 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-73 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock LayerNorm-74 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Dropout-75 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-76 [32, 145, 64] 128 128
ViTClassifierModel/TransformerBlock Linear-77 [32, 145, 128] 8,320 8,320
ViTClassifierModel/TransformerBlock GELU-78 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Dropout-79 [32, 145, 128] 0 0
ViTClassifierModel/TransformerBlock Linear-80 [32, 145, 64] 8,256 8,256
ViTClassifierModel/TransformerBlock GELU-81 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock Dropout-82 [32, 145, 64] 0 0
ViTClassifierModel/TransformerBlock LayerNorm-83 [32, 145, 64] 128 128
ViTClassifierModel Linear-84 [32, 2048] 133,120 133,120
ViTClassifierModel GELU-85 [32, 2048] 0 0
ViTClassifierModel Dropout-86 [32, 2048] 0 0
ViTClassifierModel Linear-87 [32, 1024] 2,098,176 2,098,176
ViTClassifierModel GELU-88 [32, 1024] 0 0
ViTClassifierModel Dropout-89 [32, 1024] 0 0
ViTClassifierModel Linear-90 [32, 100] 102,500 102,500
================================================================================================================
Total params: 2,485,732
Trainable params: 2,485,732
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------
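The layer summary above has no rows for the torch.nn.MultiheadAttention modules or for the class-token parameter, and its "Total params" is exactly the sum of the rows it does list, so it understates the size of the model. The arithmetic below estimates the missing attention weights, assuming PyTorch's default bias=True and kdim = vdim = embed_dim, as in TransformerBlock.

```python
EMBED_DIM = 64
NUM_LAYERS = 8
REPORTED_TOTAL = 2_485_732  # "Total params" from the summary above


def mha_params(embed_dim: int) -> int:
    """Parameters of torch.nn.MultiheadAttention with kdim = vdim = embed_dim, bias=True."""
    in_proj = 3 * embed_dim * embed_dim + 3 * embed_dim  # packed Q, K, V projections
    out_proj = embed_dim * embed_dim + embed_dim         # output projection
    return in_proj + out_proj


per_block = mha_params(EMBED_DIM)
missing = NUM_LAYERS * per_block
print(per_block, missing, REPORTED_TOTAL + missing)
```

That puts the model at roughly 2.62 million parameters before the class token is added; calling sum(p.numel() for p in model.parameters()) on the instantiated model is the simplest way to get the exact figure.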

Classification and Results

Training is then performed in the same way as for a regular convolutional network: we input our data, compute the loss against the labels and then calculate our accuracy. The training code for this is shown below.

import sys
from typing import Callable

import numpy as np
import torch


def train_network(
    model: torch.nn.Module,
    num_epochs: int,
    optimizer: torch.optim.Optimizer,
    loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    trainloader: torch.utils.data.DataLoader,
    validloader: torch.utils.data.DataLoader,
    device: torch.device,
) -> None:
    """Train the Network."""
    print("Training Started")
    for epoch in range(1, num_epochs + 1):
        sys.stdout.flush()
        train_loss = []
        valid_loss = []
        num_examples_train = 0
        num_correct_train = 0
        num_examples_valid = 0
        num_correct_valid = 0
        num_correct_train_5 = 0
        num_correct_valid_5 = 0

        # Training pass: update the weights and track loss and accuracy.
        model.train()
        for batch in trainloader:
            optimizer.zero_grad()
            x = batch[0].to(device)
            y = batch[1].to(device)
            outputs = model(x)
            loss = loss_function(outputs, y)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            num_corr, num_ex = calculate_accuracy(outputs, y)
            num_corr_5, _ = calculate_accuracy_top_5(outputs, y)
            num_examples_train += num_ex
            num_correct_train += num_corr
            num_correct_train_5 += num_corr_5

        # Validation pass: no gradients, dropout disabled.
        model.eval()
        with torch.no_grad():
            for batch in validloader:
                images = batch[0].to(device)
                labels = batch[1].to(device)
                outputs = model(images)
                loss = loss_function(outputs, labels)
                valid_loss.append(loss.item())
                num_corr, num_ex = calculate_accuracy(outputs, labels)
                num_corr_5, _ = calculate_accuracy_top_5(outputs, labels)
                num_examples_valid += num_ex
                num_correct_valid += num_corr
                num_correct_valid_5 += num_corr_5

        print(
            f"Epoch: {epoch}, "
            f"Training Loss: {np.mean(train_loss):.4f}, "
            f"Validation Loss: {np.mean(valid_loss):.4f}, "
            f"Training Accuracy: {num_correct_train / num_examples_train:.4f}, "
            f"Validation Accuracy: {num_correct_valid / num_examples_valid:.4f}, "
            f"Training Accuracy Top-5: {num_correct_train_5 / num_examples_train:.4f}, "
            f"Validation Accuracy Top-5: {num_correct_valid_5 / num_examples_valid:.4f}"
        )
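train_network relies on two helpers, calculate_accuracy and calculate_accuracy_top_5, that are not defined in this post. Judging from how they are called, each returns the number of correct predictions and the number of examples in the batch; a minimal sketch under that assumption is:

```python
import torch


def calculate_accuracy(
    outputs: torch.Tensor, labels: torch.Tensor
) -> tuple[int, int]:
    """Return (correct top-1 predictions, number of examples)."""
    predictions = outputs.argmax(dim=1)
    return int((predictions == labels).sum().item()), labels.size(0)


def calculate_accuracy_top_5(
    outputs: torch.Tensor, labels: torch.Tensor
) -> tuple[int, int]:
    """Return (examples whose label is among the 5 highest logits, number of examples)."""
    top5 = outputs.topk(5, dim=1).indices                # (batch, 5)
    hits = (top5 == labels.unsqueeze(dim=1)).any(dim=1)  # (batch,)
    return int(hits.sum().item()), labels.size(0)
```

Returning raw counts rather than ratios lets train_network sum them across batches and divide once per epoch.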

We train for 100 epochs, using the AdamW optimizer with a learning rate of 0.001 and a weight decay of 0.0001.
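The optimiser set-up described above can be sketched as follows. The loss function is an assumption (cross-entropy is the usual choice for 100-way classification, but the post does not name it), and a single Linear layer stands in for the vision transformer so that the snippet runs without the earlier classes.

```python
import torch

LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001

torch.manual_seed(0)
# Stand-in model so the snippet runs on its own; the full script would pass the
# ViTClassifierModel instance's parameters to AdamW instead.
model = torch.nn.Linear(64, 100)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
loss_function = torch.nn.CrossEntropyLoss()  # assumed loss, see lead-in

features = torch.randn(32, 64)
labels = torch.randint(0, 100, (32,))
before = model.weight.detach().clone()

# One optimisation step, mirroring the inner loop of train_network.
optimizer.zero_grad()
loss = loss_function(model(features), labels)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}, weights changed: {not torch.equal(before, model.weight)}")
```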

As seen from the graphs below, we achieve around 50% top-1 validation accuracy and 77% top-5 validation accuracy.

Loss vs Epoch Graph
Accuracy vs Epoch Graph

Conclusions

I hope this blog has helped demystify the vision transformer architecture and its implementation. As seen, it presents an alternative to convolutional layers and can be applied in a relatively small network to successfully classify images.


Nathan Bailey

MSc AI and ML Student @ ICL. Ex ML Engineer @ Arm, Ex FPGA Engineer @ Arm + Intel, University of Warwick CSE Graduate, Climber. https://www.nathanbaileyw.com