Imports¶

In [ ]:
import torch
import cv2
import json
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from torch.utils import data
from torchvision.models import vgg19
from torchvision import transforms
from torchvision import datasets
from PIL import Image
In [ ]:
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_classes.json
!wget https://miro.medium.com/v2/resize:fit:640/format:webp/1*kc-k_j53HOJH_sifhg4lHg.jpeg -O elephant.jpg
!wget https://miro.medium.com/v2/resize:fit:640/format:webp/1*XbnzdczNru6HsX6qPZaXLg.jpeg -O shark.jpg
!wget https://miro.medium.com/v2/resize:fit:640/format:webp/1*oRpjlGC3sUy5yQJtpwclwg.jpeg -O iguana.jpg
--2024-11-13 16:09:57--  https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.9.45, 52.217.81.230, 52.217.120.144, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.9.45|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 35363 (35K) [application/octet-stream]
Saving to: ‘imagenet_classes.json’

imagenet_classes.js 100%[===================>]  34.53K  --.-KB/s    in 0.07s   

2024-11-13 16:09:57 (496 KB/s) - ‘imagenet_classes.json’ saved [35363/35363]

Warning: wildcards not supported in HTTP.
--2024-11-13 16:09:58--  https://miro.medium.com/v2/resize:fit:640/format:webp/1*kc-k_j53HOJH_sifhg4lHg.jpeg
Resolving miro.medium.com (miro.medium.com)... 162.159.153.4, 162.159.152.4, 2606:4700:7::a29f:9904, ...
Connecting to miro.medium.com (miro.medium.com)|162.159.153.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 20322 (20K) [image/webp]
Saving to: ‘elephant.jpg’

elephant.jpg        100%[===================>]  19.85K  --.-KB/s    in 0s      

2024-11-13 16:09:58 (104 MB/s) - ‘elephant.jpg’ saved [20322/20322]

Warning: wildcards not supported in HTTP.
--2024-11-13 16:09:58--  https://miro.medium.com/v2/resize:fit:640/format:webp/1*XbnzdczNru6HsX6qPZaXLg.jpeg
Resolving miro.medium.com (miro.medium.com)... 162.159.153.4, 162.159.152.4, 2606:4700:7::a29f:9904, ...
Connecting to miro.medium.com (miro.medium.com)|162.159.153.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10886 (11K) [image/webp]
Saving to: ‘shark.jpg’

shark.jpg           100%[===================>]  10.63K  --.-KB/s    in 0s      

2024-11-13 16:09:58 (79.8 MB/s) - ‘shark.jpg’ saved [10886/10886]

Warning: wildcards not supported in HTTP.
--2024-11-13 16:09:58--  https://miro.medium.com/v2/resize:fit:640/format:webp/1*oRpjlGC3sUy5yQJtpwclwg.jpeg
Resolving miro.medium.com (miro.medium.com)... 162.159.153.4, 162.159.152.4, 2606:4700:7::a29f:9904, ...
Connecting to miro.medium.com (miro.medium.com)|162.159.153.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 64572 (63K) [image/webp]
Saving to: ‘iguana.jpg’

iguana.jpg          100%[===================>]  63.06K  --.-KB/s    in 0.04s   

2024-11-13 16:09:58 (1.74 MB/s) - ‘iguana.jpg’ saved [64572/64572]

GradCAM¶

Grad-CAM, or Gradient-weighted Class Activation Mapping, is a visualization technique that highlights the regions in an image that a CNN-based model considers important for predicting a certain class. The intuition behind Grad-CAM is that the gradients of the model's prediction with respect to the final convolutional layer's activations help in understanding which features influenced the model's decision.

In [ ]:
# Loading the pretrained VGG19 model
model = vgg19(pretrained=True)
model.eval()
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:06<00:00, 88.9MB/s]
Out[ ]:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace=True)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Preprocessing the Input Image¶

Since the VGG19 model was trained on ImageNet, we need to preprocess the input image accordingly, resizing, normalizing, and converting it into a tensor format.

In [ ]:
# Function to preprocess the image
def preprocess_image(img_path):
    # Load the image
    img = Image.open(img_path).convert('RGB')

    # Plot the image
    plt.imshow(img)
    plt.axis('off')
    plt.show()

    # Define preprocessing transformations
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Apply transformations and add batch dimension
    img_tensor = preprocess(img).unsqueeze(0)
    return img_tensor

# Load and preprocess a sample image
img_path = 'elephant.jpg'
input_image = preprocess_image(img_path)
No description has been provided for this image

Forward Hooking for Activations¶

Why are gradients important? In Grad-CAM, the gradients help identify which parts of an image contribute most to a specific decision made by the model. By calling backward() on the most probable output (logit), we propagate this decision backward through the network to understand which features were most significant.

Why do we need a "hook"? When backpropagation runs, PyTorch computes gradients only for "leaf nodes"—these are tensors that hold parameters like weights and biases. However, intermediate activations, such as those from hidden layers, do not retain their gradients after computation. Since Grad-CAM requires gradients with respect to intermediate activations (specifically the last convolutional layer), we need a way to "catch" these gradients before they’re discarded.

How do hooks work in PyTorch? PyTorch’s hook functions allow us to register callbacks that can modify or store information during forward or backward passes. A backward hook specifically catches gradients as they pass through a layer during backpropagation. For Grad-CAM, we attach a backward hook to the last convolutional layer, allowing us to capture its gradients with respect to the input image before they’re lost.

Setting up a hook on the last convolutional layer: By registering a backward hook to the final convolutional layer, we ensure that we can access and manipulate these specific gradients. The captured gradients are then averaged (pooled) across spatial locations to create "weights" for each feature map, emphasizing which aspects of the image were most influential in the model’s decision.

image.png

Using hooks is essential here because they give us an insight into the network’s interpretability layer-by-layer, allowing for real-time monitoring of gradients without modifying the core model structure.

To access the intermediate activations of the last convolutional layer, we set up a "hook" in PyTorch. Hooks let us register functions to capture gradients or activations at specific layers.

In [ ]:
# Registering hooks to capture the output and gradients
activations = None
gradients = None

def save_activation_grad(module, input, output):
    global activations
    activations = output  # Save the activations (output of the layer)

def save_gradient(module, grad_in, grad_out):
    global gradients
    gradients = grad_out[0]  # Save the gradient w.r.t the output

# Hook the last convolutional layer
layer = model.features[35]
layer.register_forward_hook(save_activation_grad)
layer.register_backward_hook(save_gradient)
Out[ ]:
<torch.utils.hooks.RemovableHandle at 0x7ea96214b610>

Forward Pass and Prediction¶

Perform a forward pass through the network to get the class predictions. This will help us identify the predicted label, which is required for calculating the gradients.

In [ ]:
# Perform a forward pass and get the top 3 predictions
with torch.no_grad():
    output = model(input_image)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top3_prob, top3_catid = torch.topk(probabilities, 3)

# Load ImageNet class labels
with open("imagenet_classes.json") as f:
    imagenet_classes = json.load(f)

# Display results
results = [(imagenet_classes[str(catid.item())], top3_prob[idx].item()) for idx, catid in enumerate(top3_catid)]
for class_id, (label, probability) in zip(top3_catid, results):
    print(f"{class_id}: {label} ({probability*100:.2f})")
386: ['n02504458', 'African_elephant'] (94.35)
101: ['n01871265', 'tusker'] (5.25)
385: ['n02504013', 'Indian_elephant'] (0.28)

Backward Pass for Gradients¶

We calculate the gradients of the output logit (associated with the predicted class) with respect to the activations of the chosen convolutional layer. These gradients highlight the importance of each activation channel.

In [ ]:
# Forward pass again to get the top pred class
output = model(input_image)

# Get the predicted class
predicted_class = output.argmax().item()

# Zero out any previous gradients
model.zero_grad()

# Perform backward pass on the output of the predicted class
output[0, predicted_class].backward()
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1827: FutureWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)

Generating the Heatmap¶

To generate the heatmap, we take the average of the gradients over the spatial dimensions and weight each activation map by this average gradient. This weighted sum of activations is our Grad-CAM heatmap.

In [ ]:
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# Weight the channels by corresponding gradients
for i in range(activations.shape[1]):
    activations[:, i, :, :] *= pooled_gradients[i]

# Generate the heatmap by averaging across channels and applying ReLU
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max()  # Normalize heatmap

plt.imshow(heatmap, cmap='cool')
plt.axis('off')
plt.show()
No description has been provided for this image

Displaying the Heatmap on the Original Image¶

The heatmap is then resized to the original image dimensions and superimposed, allowing us to visualize the regions that influenced the model’s decision.

In [ ]:
# Load the image and convert to NumPy format
original_img = cv2.imread(img_path)

# Resize the heatmap to the original image size
resized_heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)

# Convert non-resized and resized heatmaps to color maps
non_resized_heatmap = np.uint8(255 * heatmap)  # Non-resized heatmap
non_resized_heatmap_colormap = cv2.applyColorMap(non_resized_heatmap, cv2.COLORMAP_JET)

resized_heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)

# Superimpose resized heatmap on original image
superimposed_img = cv2.addWeighted(original_img, 0.6, resized_heatmap_colormap, 0.4, 0)

# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# Display non-resized heatmap in Row 1, Col 1
axs[0, 0].imshow(cv2.cvtColor(non_resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 0].set_title("Non-resized Heatmap")
axs[0, 0].axis('off')

# Display resized heatmap in Row 1, Col 2
axs[0, 1].imshow(cv2.cvtColor(resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 1].set_title("Resized Heatmap")
axs[0, 1].axis('off')

# Display original image in Row 2, Col 1
axs[1, 0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[1, 0].set_title("Original Image")
axs[1, 0].axis('off')

# Display image with heatmap overlay in Row 2, Col 2
axs[1, 1].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
axs[1, 1].set_title("Image with Heatmap Overlay")
axs[1, 1].axis('off')

# Adjust layout and show the plot
plt.tight_layout()
plt.show()
No description has been provided for this image

Testing GradCAM with VGG16 on other images¶

Shark¶

In [ ]:
# Load and preprocess a sample image
img_path = 'shark.jpg'
input_image = preprocess_image(img_path)
No description has been provided for this image
In [ ]:
# Perform a forward pass and get the top 3 predictions
with torch.no_grad():
    output = model(input_image)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top3_prob, top3_catid = torch.topk(probabilities, 3)

# Load ImageNet class labels
with open("imagenet_classes.json") as f:
    imagenet_classes = json.load(f)

# Display results
results = [(imagenet_classes[str(catid.item())], top3_prob[idx].item()) for idx, catid in enumerate(top3_catid)]
for class_id, (label, probability) in zip(top3_catid, results):
    print(f"{class_id}: {label} ({probability*100:.2f})")
2: ['n01484850', 'great_white_shark'] (96.19)
3: ['n01491361', 'tiger_shark'] (3.49)
4: ['n01494475', 'hammerhead'] (0.12)
In [ ]:
# Forward pass again to get the top pred class
output = model(input_image)

# Get the predicted class
predicted_class = output.argmax().item()

# Zero out any previous gradients
model.zero_grad()

# Perform backward pass on the output of the predicted class
output[0, predicted_class].backward()
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1827: FutureWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
In [ ]:
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# Weight the channels by corresponding gradients
for i in range(activations.shape[1]):
    activations[:, i, :, :] *= pooled_gradients[i]

# Generate the heatmap by averaging across channels and applying ReLU
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max()  # Normalize heatmap

plt.imshow(heatmap, cmap='cool')
plt.axis('off')
plt.show()
No description has been provided for this image
In [ ]:
# Load the image and convert to NumPy format
original_img = cv2.imread(img_path)

# Resize the heatmap to the original image size
resized_heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)

# Convert non-resized and resized heatmaps to color maps
non_resized_heatmap = np.uint8(255 * heatmap)  # Non-resized heatmap
non_resized_heatmap_colormap = cv2.applyColorMap(non_resized_heatmap, cv2.COLORMAP_JET)

resized_heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)

# Superimpose resized heatmap on original image
superimposed_img = cv2.addWeighted(original_img, 0.6, resized_heatmap_colormap, 0.4, 0)

# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# Display non-resized heatmap in Row 1, Col 1
axs[0, 0].imshow(cv2.cvtColor(non_resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 0].set_title("Non-resized Heatmap")
axs[0, 0].axis('off')

# Display resized heatmap in Row 1, Col 2
axs[0, 1].imshow(cv2.cvtColor(resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 1].set_title("Resized Heatmap")
axs[0, 1].axis('off')

# Display original image in Row 2, Col 1
axs[1, 0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[1, 0].set_title("Original Image")
axs[1, 0].axis('off')

# Display image with heatmap overlay in Row 2, Col 2
axs[1, 1].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
axs[1, 1].set_title("Image with Heatmap Overlay")
axs[1, 1].axis('off')

# Adjust layout and show the plot
plt.tight_layout()
plt.show()
No description has been provided for this image

Iguana¶

In [ ]:
# Load and preprocess a sample image
img_path = 'iguana.jpg'
input_image = preprocess_image(img_path)
No description has been provided for this image
In [ ]:
# Perform a forward pass and get the top 3 predictions
with torch.no_grad():
    output = model(input_image)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top3_prob, top3_catid = torch.topk(probabilities, 3)

# Load ImageNet class labels
with open("imagenet_classes.json") as f:
    imagenet_classes = json.load(f)

# Display results
results = [(imagenet_classes[str(catid.item())], top3_prob[idx].item()) for idx, catid in enumerate(top3_catid)]
for class_id, (label, probability) in zip(top3_catid, results):
    print(f"{class_id}: {label} ({probability*100:.2f})")
768: ['n04118538', 'rugby_ball'] (17.07)
981: ['n09835506', 'ballplayer'] (12.61)
805: ['n04254680', 'soccer_ball'] (11.12)
In [ ]:
# Forward pass again to get the top pred class
output = model(input_image)

# Get the predicted class
predicted_class = output.argmax().item()

# Zero out any previous gradients
model.zero_grad()

# Perform backward pass on the output of the predicted class
output[0, predicted_class].backward()
In [ ]:
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# Weight the channels by corresponding gradients
for i in range(activations.shape[1]):
    activations[:, i, :, :] *= pooled_gradients[i]

# Generate the heatmap by averaging across channels and applying ReLU
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max()  # Normalize heatmap

plt.imshow(heatmap, cmap='cool')
plt.axis('off')
plt.show()
No description has been provided for this image
In [ ]:
# Load the image and convert to NumPy format
original_img = cv2.imread(img_path)

# Resize the heatmap to the original image size
resized_heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)

# Convert non-resized and resized heatmaps to color maps
non_resized_heatmap = np.uint8(255 * heatmap)  # Non-resized heatmap
non_resized_heatmap_colormap = cv2.applyColorMap(non_resized_heatmap, cv2.COLORMAP_JET)

resized_heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)

# Superimpose resized heatmap on original image
superimposed_img = cv2.addWeighted(original_img, 0.6, resized_heatmap_colormap, 0.4, 0)

# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# Display non-resized heatmap in Row 1, Col 1
axs[0, 0].imshow(cv2.cvtColor(non_resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 0].set_title("Non-resized Heatmap")
axs[0, 0].axis('off')

# Display resized heatmap in Row 1, Col 2
axs[0, 1].imshow(cv2.cvtColor(resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 1].set_title("Resized Heatmap")
axs[0, 1].axis('off')

# Display original image in Row 2, Col 1
axs[1, 0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[1, 0].set_title("Original Image")
axs[1, 0].axis('off')

# Display image with heatmap overlay in Row 2, Col 2
axs[1, 1].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
axs[1, 1].set_title("Image with Heatmap Overlay")
axs[1, 1].axis('off')

# Adjust layout and show the plot
plt.tight_layout()
plt.show()
No description has been provided for this image

Iguana (cropped)¶

In [ ]:
# Load and preprocess a sample image
img_path = 'iguana.jpg'
original_img = cv2.imread(img_path)

# Get the dimensions of the image
height, width = original_img.shape[:2]

# Calculate the dimensions of each quarter
quarter_height, quarter_width = height // 2, width // 2

# Crop the top-left quarter
top_left_quarter = original_img[0:quarter_height, 0:quarter_width]

# Upsample the cropped image
scale_factor = 2
upsampled_img = cv2.resize(top_left_quarter, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC)

# Save the cropped image
cropped_img_path = 'iguana_cropped.jpg'
cv2.imwrite(cropped_img_path, upsampled_img)

# Set up a grid for displaying images
fig, axs = plt.subplots(1, 2, figsize=(10, 10))

# Display non-resized heatmap in Row 1, Col 1
axs[0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[0].set_title("Original Image")
axs[0].axis('off')

# Display resized heatmap in Row 1, Col 2
axs[1].imshow(cv2.cvtColor(upsampled_img, cv2.COLOR_BGR2RGB))
axs[1].set_title("Cropped Image")
axs[1].axis('off')
Out[ ]:
(-0.5, 383.5, 511.5, -0.5)
No description has been provided for this image
In [ ]:
# Load and preprocess the image
img_path = 'iguana_cropped.jpg'
input_image = preprocess_image(img_path)
No description has been provided for this image
In [ ]:
# Perform a forward pass and get the top 3 predictions
with torch.no_grad():
    output = model(input_image)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top3_prob, top3_catid = torch.topk(probabilities, 3)

# Load ImageNet class labels
with open("imagenet_classes.json") as f:
    imagenet_classes = json.load(f)

# Display results
results = [(imagenet_classes[str(catid.item())], top3_prob[idx].item()) for idx, catid in enumerate(top3_catid)]
for class_id, (label, probability) in zip(top3_catid, results):
    print(f"{class_id}: {label} ({probability*100:.2f})")
280: ['n02120505', 'grey_fox'] (40.81)
43: ['n01688243', 'frilled_lizard'] (16.03)
39: ['n01677366', 'common_iguana'] (5.67)
In [ ]:
# Forward pass again to get the top pred class
output = model(input_image)

# Get the predicted class
predicted_class = output.argmax().item()

# Zero out any previous gradients
model.zero_grad()

# Perform backward pass on the output of the predicted class
output[0, predicted_class].backward()
In [ ]:
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

# Weight the channels by corresponding gradients
for i in range(activations.shape[1]):
    activations[:, i, :, :] *= pooled_gradients[i]

# Generate the heatmap by averaging across channels and applying ReLU
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max()  # Normalize heatmap

plt.imshow(heatmap, cmap='cool')
plt.axis('off')
plt.show()
No description has been provided for this image
In [ ]:
# Load the image and convert to NumPy format
original_img = cv2.imread(img_path)

# Resize the heatmap to the original image size
resized_heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)

# Convert non-resized and resized heatmaps to color maps
non_resized_heatmap = np.uint8(255 * heatmap)  # Non-resized heatmap
non_resized_heatmap_colormap = cv2.applyColorMap(non_resized_heatmap, cv2.COLORMAP_JET)

resized_heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)

# Superimpose resized heatmap on original image
superimposed_img = cv2.addWeighted(original_img, 0.6, resized_heatmap_colormap, 0.4, 0)

# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# Display non-resized heatmap in Row 1, Col 1
axs[0, 0].imshow(cv2.cvtColor(non_resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 0].set_title("Non-resized Heatmap")
axs[0, 0].axis('off')

# Display resized heatmap in Row 1, Col 2
axs[0, 1].imshow(cv2.cvtColor(resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 1].set_title("Resized Heatmap")
axs[0, 1].axis('off')

# Display original image in Row 2, Col 1
axs[1, 0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[1, 0].set_title("Original Image")
axs[1, 0].axis('off')

# Display image with heatmap overlay in Row 2, Col 2
axs[1, 1].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
axs[1, 1].set_title("Image with Heatmap Overlay")
axs[1, 1].axis('off')

# Adjust layout and show the plot
plt.tight_layout()
plt.show()
No description has been provided for this image

Live demo with webcam¶

Note that this code needs to be executed locally.

Colab doesn't work well with webcam captures or cv2 windows.

In [ ]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load a pretrained model
model = models.resnet50(pretrained=True)
model.eval()


# Preprocessing function for the input image
def preprocess_image(img):
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img_tensor = preprocess(img).unsqueeze(0)  # Add batch dimension
    return img_tensor


# Hook functions to get activations and gradients for Grad-CAM
activations = None
gradients = None


def save_activation_grad(module, input, output):
    global activations
    activations = output


def save_gradient(module, grad_in, grad_out):
    global gradients
    gradients = grad_out[0]


# Hook the last convolutional layer
layer = model.layer4[2].conv3  # Adjust for the last convolutional layer of your model
layer.register_forward_hook(save_activation_grad)
layer.register_backward_hook(save_gradient)


# Grad-CAM computation
def compute_gradcam(input_image, class_idx):
    # Zero out previous gradients
    model.zero_grad()

    # Perform backward pass on the predicted class
    output = model(input_image)
    output[0, class_idx].backward()

    # Pool the gradients
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

    # Weight the channels by gradients
    for i in range(activations.shape[1]):
        activations[:, i, :, :] *= pooled_gradients[i]

    # Generate heatmap
    heatmap = torch.mean(activations, dim=1).squeeze()
    heatmap = np.maximum(heatmap.detach().numpy(), 0)
    heatmap = heatmap / heatmap.max()  # Normalize heatmap
    return heatmap


# Initialize webcam capture
cap = cv2.VideoCapture(0)

# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Loop to continuously capture webcam frames
while True:
    # Capture frame-by-frame
    ret, frame = cap.read()

    # If no frame is captured, skip
    if not ret:
        continue

    # Convert the frame to PIL image for processing
    pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    # Preprocess the image and get the prediction
    input_image = preprocess_image(pil_img)

    with torch.no_grad():
        output = model(input_image)
        predicted_class = output.argmax().item()

    # Compute the Grad-CAM heatmap
    heatmap = compute_gradcam(input_image, predicted_class)

    # Resize the heatmap to match the frame size
    resized_heatmap = cv2.resize(heatmap, (frame.shape[1], frame.shape[0]))
    resized_heatmap = np.uint8(255 * resized_heatmap)
    heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)

    # Superimpose the heatmap on the original frame
    superimposed_img = cv2.addWeighted(frame, 0.6, heatmap_colormap, 0.4, 0)

    # Close the previous windows if already open
    cv2.destroyAllWindows()

    # Display image with heatmap overlay
    cv2.imshow(f'Image with Heatmap Overlay, {predicted_class}', superimposed_img)

    # Exit condition: Close the windows if 'q' is pressed
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release the capture and close the window
cap.release()
cv2.destroyAllWindows()