The attention mechanism that makes transformers so powerful for text also works remarkably well for images. The Vision Transformer (ViT), introduced by Google in 2020, applies the same self-attention architecture to image classification by splitting an image into a grid of patches and treating each patch like a “word” in a sentence.
This notebook demonstrates how attention extends from text to images. We load a pre-trained ViT model, feed it an image, and then visualize the attention heatmap to see which parts of the image the model focused on when making its prediction. The heatmap reveals that the model learns to attend to semantically meaningful regions -- the animal’s face, distinctive features, or key objects -- rather than processing the entire image uniformly.
Understanding attention visualization is valuable for model interpretability: it lets you verify that the model is making predictions for the right reasons, not just exploiting spurious correlations in the background.
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import matplotlib.pyplot as plt
import requests
import numpy as np
Loading the Vision Transformer and Required Libraries
We import PyTorch, the Hugging Face ViT model, and visualization libraries. The Vision Transformer (ViT) processes images by dividing them into 16x16-pixel patches, embedding each patch into a vector, and then running standard transformer self-attention over these patch embeddings. For the 224x224 input used here, that yields a 14x14 grid of 196 patches. This is closely analogous to how a text transformer processes word tokens -- the key insight that bridges NLP and computer vision.
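To make the patch arithmetic concrete, here is a minimal NumPy sketch of the splitting step only. It assumes a 224x224 RGB input and uses a synthetic array in place of a real image. The actual model also applies a learned projection and position embeddings to each patch, which this sketch leaves out.

```python
import numpy as np

PATCH = 16   # patch side length in pixels (the "patch16" in the model name)
IMG = 224    # input resolution (the "224" in the model name)

# Synthetic stand-in for a preprocessed image: height x width x RGB channels.
image = np.arange(IMG * IMG * 3, dtype=np.float32).reshape(IMG, IMG, 3)

grid = IMG // PATCH  # 14 patches per side

# Split height and width into (grid, PATCH) blocks, then bring the two grid axes together.
patches = image.reshape(grid, PATCH, grid, PATCH, 3).transpose(0, 2, 1, 3, 4)

# Flatten each patch into one vector: one "token" per patch, in row-major grid order.
tokens = patches.reshape(grid * grid, PATCH * PATCH * 3)

print(tokens.shape)             # (196, 768)
seq_len = tokens.shape[0] + 1   # the model prepends one [CLS] token
print(seq_len)                  # 197
```

The 197-token sequence length is why the attention matrices extracted later have shape 197x197 per head.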
# 1. Load the Model and Processor
model_name = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name, output_attentions=True)
# 2. Load an Image (Using a clear cat image)
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg' # Two cats on a remote
#url = 'https://pawpals.ae/wp-content/uploads/2023/03/Shorthair-cat-1-1024x679.jpg'
#url = 'https://www.alleycat.org/wp-content/uploads/2018/08/SocialMediaLink_1200x628_GenericCat-Centered.png'
url = 'https://d3544la1u8djza.cloudfront.net/APHI/Blog/2023/September/small-breeds-hero.jpg'
# Convert to RGB so images with an alpha channel or palette still give 3 channels
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
# 3. Process and Predict
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():  # inference only, so skip gradient tracking
    outputs = model(**inputs)
# Find the label with the highest score
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print(f"Predicted Class: {model.config.id2label[predicted_class_idx]}")
# 4. Extract Attention for that specific prediction
# We take the last layer's attention (index 11, the 12th and final layer of ViT-Base)
attentions = outputs.attentions[-1]
# Average the attention weights across all 'heads'
# We focus on the [CLS] token (index 0), which is the 'Decision Maker'
# We keep the [CLS] token's attention toward the 196 image patches (indices 1 to 196); the slice 1: skips [CLS] itself
nh = attentions.shape[1] # Number of heads
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
val_attn = attentions.mean(dim=0) # Average across heads
# 5. Reshape into a Grid (the map is stretched to image size at plot time via `extent`)
grid_size = int(np.sqrt(val_attn.shape[0])) # 14, giving a 14x14 grid
heatmap = val_attn.reshape(grid_size, grid_size).detach().numpy()
# 6. Plotting
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(image)
ax[0].set_title("Input Image")
ax[0].axis('off')
# Overlay the heatmap
ax[1].imshow(image)
ax[1].imshow(heatmap, cmap='magma', alpha=0.7, extent=(0, image.size[0], image.size[1], 0))
ax[1].set_title(f"Attention Spotlight for: {model.config.id2label[predicted_class_idx]}")
ax[1].axis('off')
plt.show()
Processing an Image and Extracting Attention Maps
This cell does the heavy lifting: we load the pre-trained ViT model (with output_attentions=True to capture the attention weights), download a test image, and run inference. After getting the predicted class, we extract the attention weights from the last transformer layer and focus on the [CLS] token -- this is the special classification token whose attention pattern tells us which image patches were most important for the final prediction.
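The indexing used for the [CLS] attention is easier to follow on a small stand-alone example. The sketch below uses a random, softmax-normalized array in place of `outputs.attentions[-1]`, and the helper name `cls_attention_map` is made up for illustration. It assumes the layout (batch, heads, tokens, tokens) with [CLS] at token index 0.

```python
import numpy as np

def cls_attention_map(attn, grid_size=14):
    """Average the [CLS] row of one layer's attention over heads.

    attn: array of shape (batch, heads, tokens, tokens); token 0 is [CLS].
    Returns a (grid_size, grid_size) map for the first image in the batch.
    """
    cls_to_patches = attn[0, :, 0, 1:]             # (heads, patches); drops [CLS] -> [CLS]
    mean_over_heads = cls_to_patches.mean(axis=0)  # (patches,)
    return mean_over_heads.reshape(grid_size, grid_size)

# Synthetic stand-in for the last layer's attention: 12 heads, 197 tokens.
rng = np.random.default_rng(0)
scores = rng.normal(size=(1, 12, 197, 197))
attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # softmax over keys

heatmap = cls_attention_map(attn)
print(heatmap.shape)        # (14, 14)
print(heatmap.sum() < 1.0)  # True
```

The map sums to slightly less than 1 because each attention row is a probability distribution over all 197 tokens, and the share that [CLS] assigns to itself is dropped by the `1:` slice.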
The resulting 14x14 heatmap is overlaid on the original image, creating an intuitive visualization. Bright regions in the heatmap indicate patches that received high attention -- these are the areas the model considered most informative for its classification decision. This kind of visualization is a powerful tool for debugging and explaining model behavior to non-technical stakeholders.
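The notebook stretches the 14x14 map over the image through the `extent` argument of `imshow`. If an explicit pixel-resolution array is needed instead, for saving or thresholding, one option is nearest-neighbor upsampling. This is a sketch under assumptions: the helper name `upsample_heatmap` is hypothetical, the placeholder values are not real attention, and the output matches the model's 224x224 input rather than the original image size.

```python
import numpy as np

def upsample_heatmap(heatmap, patch_size=16):
    """Min-max normalize a patch-level map and expand it to pixel resolution."""
    lo, hi = heatmap.min(), heatmap.max()
    # Guard against a constant map, where hi == lo would divide by zero.
    norm = (heatmap - lo) / (hi - lo) if hi > lo else np.zeros_like(heatmap)
    # np.kron with a block of ones repeats each cell patch_size x patch_size times.
    return np.kron(norm, np.ones((patch_size, patch_size)))

demo = np.arange(196, dtype=np.float64).reshape(14, 14)  # placeholder values
pixel_map = upsample_heatmap(demo)
print(pixel_map.shape)                   # (224, 224)
print(pixel_map.min(), pixel_map.max())  # 0.0 1.0
```

Normalizing to the 0 to 1 range keeps the overlay comparable across images, at the cost of hiding how concentrated the raw attention was.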
Key takeaways
Vision Transformers apply the same self-attention mechanism used on text to images, by splitting an image into 16x16-pixel patches treated like word tokens.
Attention heatmaps reveal which image regions the model focused on, turning a black-box prediction into a visual explanation.
The [CLS] token aggregates information across patches, and its attention weights toward each patch indicate which regions it drew on most for the classification decision.
Pre-trained ViT from Hugging Face (google/vit-base-patch16-224) can be loaded with output_attentions=True to expose these weights for analysis.
Interpretability matters in production: visualizing attention helps verify a model is relying on meaningful features rather than spurious background cues.
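The last takeaway, checking that attention lands on meaningful features rather than the background, can be turned into a rough number. The sketch below is not part of the notebook. It assumes you have a bounding box for the object of interest in the model's 224x224 input coordinates, and both the box and the helper name `attention_inside_box` are hypothetical.

```python
import numpy as np

def attention_inside_box(heatmap, box, patch_size=16):
    """Fraction of patch attention that falls inside a pixel-space box.

    box: (left, top, right, bottom) in the model's input coordinates.
    A patch counts as inside when its center lies within the box.
    """
    grid = heatmap.shape[0]
    centers = (np.arange(grid) + 0.5) * patch_size  # patch centers in pixels
    left, top, right, bottom = box
    in_x = (centers >= left) & (centers < right)    # columns inside the box
    in_y = (centers >= top) & (centers < bottom)    # rows inside the box
    mask = in_y[:, None] & in_x[None, :]            # (grid, grid) boolean mask
    return float(heatmap[mask].sum() / heatmap.sum())

# Sanity check with uniform attention: the score equals the share of patches in the box.
uniform = np.ones((14, 14))
hypothetical_box = (0, 0, 112, 112)  # top-left quadrant, chosen only for illustration
print(attention_inside_box(uniform, hypothetical_box))  # 0.25
```

A score well above the box's share of the image area suggests the [CLS] token concentrates on the object. A score near or below that share would be a reason to look more closely at the background.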
Run the code
To run this notebook, copy the URL below into your browser’s address bar. The link opens the notebook directly in Google Colab. (If your PDF viewer makes the URL clickable and lands on a broken page, copy the full text manually -- the viewer may have truncated the link at a line break.)
Estimated run time: ~3 minutes on a T4 GPU
https://colab.research.google.com/github/KarAnalytics/code_demos/blob/main/Attention_in_Image_Analysis.ipynb