Imports¶
import torch
import cv2
import json
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.utils import data
from torchvision.models import vgg19
from torchvision import transforms
from torchvision import datasets
from PIL import Image
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_classes.json
!wget https://miro.medium.com/v2/resize:fit:640/format:webp/1*kc-k_j53HOJH_sifhg4lHg.jpeg -O elephant.jpg
!wget https://miro.medium.com/v2/resize:fit:640/format:webp/1*XbnzdczNru6HsX6qPZaXLg.jpeg -O shark.jpg
!wget https://miro.medium.com/v2/resize:fit:640/format:webp/1*oRpjlGC3sUy5yQJtpwclwg.jpeg -O iguana.jpg
‘imagenet_classes.json’ saved [35363/35363]
‘elephant.jpg’ saved [20322/20322]
‘shark.jpg’ saved [10886/10886]
‘iguana.jpg’ saved [64572/64572]
GradCAM¶
Grad-CAM, or Gradient-weighted Class Activation Mapping, is a visualization technique that highlights the regions in an image that a CNN-based model considers important for predicting a certain class. The intuition behind Grad-CAM is that the gradients of the model's prediction with respect to the final convolutional layer's activations help in understanding which features influenced the model's decision.
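Concretely, the Grad-CAM map is built by weighting each feature map of the last convolutional layer by the spatial average of its gradient, summing over channels, and applying a ReLU. A minimal sketch of that core computation (the function name and tensor shapes here are only for illustration; the rest of the notebook builds the same steps explicitly):
def gradcam_map(acts, grads):
    # acts, grads: tensors of shape [1, C, H, W] from the last conv layer (illustrative names)
    weights = grads.mean(dim=(2, 3), keepdim=True)   # per-channel weights: spatially averaged gradients
    cam = (weights * acts).sum(dim=1)                # weighted combination of the feature maps
    cam = torch.relu(cam)                            # keep only positively contributing regions
    return cam / (cam.max() + 1e-8)                  # normalize to [0, 1]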
# Loading the pretrained VGG19 model
model = vgg19(pretrained=True)
model.eval()
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:06<00:00, 88.9MB/s]
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace=True)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace=True)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace=True)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace=True)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace=True)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace=True)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
Preprocessing the Input Image¶
Since the VGG19 model was trained on ImageNet, we need to preprocess the input image accordingly: resizing it to 224×224, converting it to a tensor, and normalizing it with the ImageNet mean and standard deviation.
# Function to preprocess the image
def preprocess_image(img_path):
# Load the image
img = Image.open(img_path).convert('RGB')
# Plot the image
plt.imshow(img)
plt.axis('off')
plt.show()
# Define preprocessing transformations
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Apply transformations and add batch dimension
img_tensor = preprocess(img).unsqueeze(0)
return img_tensor
# Load and preprocess a sample image
img_path = 'elephant.jpg'
input_image = preprocess_image(img_path)
Forward Hooking for Activations¶
Why are gradients important? In Grad-CAM, the gradients help identify which parts of an image contribute most to a specific decision made by the model. By calling backward() on the most probable output (logit), we propagate this decision backward through the network to understand which features were most significant.
Why do we need a "hook"? When backpropagation runs, PyTorch retains gradients only for "leaf" tensors, i.e. the tensors that hold parameters such as weights and biases. Gradients of intermediate activations, such as the outputs of hidden layers, are computed during the backward pass but discarded immediately afterwards. Since Grad-CAM requires the gradients with respect to intermediate activations (specifically the output of the last convolutional layer), we need a way to "catch" these gradients before they are thrown away.
How do hooks work in PyTorch? PyTorch's hook functions let us register callbacks that run during the forward or backward pass and can store or modify the tensors flowing through a layer. A backward hook fires as gradients pass through a layer during backpropagation. For Grad-CAM, we attach a backward hook to the last convolutional layer, allowing us to capture the gradient of the class score with respect to that layer's output before it is discarded.
Setting up a hook on the last convolutional layer: By registering a backward hook to the final convolutional layer, we ensure that we can access and manipulate these specific gradients. The captured gradients are then averaged (pooled) across spatial locations to create "weights" for each feature map, emphasizing which aspects of the image were most influential in the model’s decision.
Hooks are essential here because they let us inspect activations and gradients layer by layer as they are computed, without modifying the core model structure.
To access the intermediate activations of the last convolutional layer, we set up a "hook" in PyTorch. Hooks let us register functions to capture gradients or activations at specific layers.
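As a quick standalone illustration of why the hook is needed (a toy example, not part of the Grad-CAM pipeline): by default, only leaf tensors keep their gradients after backward().
x = torch.randn(1, 3, requires_grad=True)  # leaf tensor (like a weight)
h = x * 2                                  # intermediate, non-leaf tensor (like an activation)
y = h.sum()
y.backward()
print(x.grad)  # populated: gradients of leaf tensors are retained
print(h.grad)  # None (with a warning): intermediate gradients are discarded
# Calling h.retain_grad() before backward(), or registering a hook on the producing module, would keep it.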
# Registering hooks to capture the output and gradients
activations = None
gradients = None
def save_activation_grad(module, input, output):
global activations
activations = output # Save the activations (output of the layer)
def save_gradient(module, grad_in, grad_out):
global gradients
gradients = grad_out[0] # Save the gradient w.r.t the output
# Hook the last convolutional layer
layer = model.features[35]
layer.register_forward_hook(save_activation_grad)
layer.register_backward_hook(save_gradient)
<torch.utils.hooks.RemovableHandle at 0x7ea96214b610>
Forward Pass and Prediction¶
Perform a forward pass through the network to get the class predictions. This will help us identify the predicted label, which is required for calculating the gradients.
# Perform a forward pass and get the top 3 predictions
with torch.no_grad():
output = model(input_image)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top3_prob, top3_catid = torch.topk(probabilities, 3)
# Load ImageNet class labels
with open("imagenet_classes.json") as f:
imagenet_classes = json.load(f)
# Display results
results = [(imagenet_classes[str(catid.item())], top3_prob[idx].item()) for idx, catid in enumerate(top3_catid)]
for class_id, (label, probability) in zip(top3_catid, results):
print(f"{class_id}: {label} ({probability*100:.2f})")
386: ['n02504458', 'African_elephant'] (94.35)
101: ['n01871265', 'tusker'] (5.25)
385: ['n02504013', 'Indian_elephant'] (0.28)
Backward Pass for Gradients¶
We calculate the gradients of the output logit (associated with the predicted class) with respect to the activations of the chosen convolutional layer. These gradients highlight the importance of each activation channel.
# Forward pass again to get the top pred class
output = model(input_image)
# Get the predicted class
predicted_class = output.argmax().item()
# Zero out any previous gradients
model.zero_grad()
# Perform backward pass on the output of the predicted class
output[0, predicted_class].backward()
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1827: FutureWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior. self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
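The FutureWarning above comes from register_backward_hook, which newer PyTorch versions deprecate in favor of register_full_backward_hook. A minimal sketch of the replacement, assuming the same global gradients variable (everything else in the notebook stays unchanged):
def save_gradient_full(module, grad_input, grad_output):
    global gradients
    gradients = grad_output[0]  # gradient of the class score w.r.t. this layer's output
layer.register_full_backward_hook(save_gradient_full)  # use instead of register_backward_hook above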
Generating the Heatmap¶
To generate the heatmap, we take the average of the gradients over the spatial dimensions and weight each activation map by this average gradient. This weighted sum of activations is our Grad-CAM heatmap.
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
# Weight the channels by corresponding gradients
for i in range(activations.shape[1]):
activations[:, i, :, :] *= pooled_gradients[i]
# Generate the heatmap by averaging across channels and applying ReLU
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max() # Normalize heatmap
plt.imshow(heatmap, cmap='cool')
plt.axis('off')
plt.show()
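As an aside, the per-channel loop above can be replaced by a single broadcasted expression. A minimal sketch that should produce the same normalized heatmap, provided it is run instead of the loop rather than after it (the loop modifies activations in place):
pooled = gradients.mean(dim=(2, 3), keepdim=True)               # [1, C, 1, 1] channel weights
cam = torch.relu((pooled * activations).mean(dim=1)).squeeze()  # weighted average over channels
cam = (cam / cam.max()).detach().numpy()                        # normalize to [0, 1]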
Displaying the Heatmap on the Original Image¶
The heatmap is then resized to the original image dimensions and superimposed, allowing us to visualize the regions that influenced the model’s decision.
# Load the image and convert to NumPy format
original_img = cv2.imread(img_path)
# Resize the heatmap to the original image size
resized_heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)
# Convert non-resized and resized heatmaps to color maps
non_resized_heatmap = np.uint8(255 * heatmap) # Non-resized heatmap
non_resized_heatmap_colormap = cv2.applyColorMap(non_resized_heatmap, cv2.COLORMAP_JET)
resized_heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)
# Superimpose resized heatmap on original image
superimposed_img = cv2.addWeighted(original_img, 0.6, resized_heatmap_colormap, 0.4, 0)
# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# Display non-resized heatmap in Row 1, Col 1
axs[0, 0].imshow(cv2.cvtColor(non_resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 0].set_title("Non-resized Heatmap")
axs[0, 0].axis('off')
# Display resized heatmap in Row 1, Col 2
axs[0, 1].imshow(cv2.cvtColor(resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 1].set_title("Resized Heatmap")
axs[0, 1].axis('off')
# Display original image in Row 2, Col 1
axs[1, 0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[1, 0].set_title("Original Image")
axs[1, 0].axis('off')
# Display image with heatmap overlay in Row 2, Col 2
axs[1, 1].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
axs[1, 1].set_title("Image with Heatmap Overlay")
axs[1, 1].axis('off')
# Adjust layout and show the plot
plt.tight_layout()
plt.show()
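Since the same forward pass, backward pass, heatmap, and overlay steps are repeated for every test image below, they could also be wrapped in a small helper. A minimal sketch (the gradcam_overlay name and structure are an illustration, not part of the original notebook; it relies on the hooks registered above):
def gradcam_overlay(img_path):
    input_image = preprocess_image(img_path)
    output = model(input_image)                     # forward pass; the hook stores activations
    model.zero_grad()
    output[0, output.argmax().item()].backward()    # backward pass; the hook stores gradients
    pooled = gradients.mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((pooled * activations).mean(dim=1)).squeeze()
    cam = (cam / cam.max()).detach().numpy()
    original = cv2.imread(img_path)
    cam_resized = np.uint8(255 * cv2.resize(cam, (original.shape[1], original.shape[0])))
    colormap = cv2.applyColorMap(cam_resized, cv2.COLORMAP_JET)
    return cam, cv2.addWeighted(original, 0.6, colormap, 0.4, 0)
For example, gradcam_overlay('shark.jpg') would reproduce the shark overlay shown in the next section.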
Testing GradCAM with VGG19 on other images¶
Shark¶
# Load and preprocess a sample image
img_path = 'shark.jpg'
input_image = preprocess_image(img_path)
# Perform a forward pass and get the top 3 predictions
with torch.no_grad():
output = model(input_image)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top3_prob, top3_catid = torch.topk(probabilities, 3)
# Load ImageNet class labels
with open("imagenet_classes.json") as f:
imagenet_classes = json.load(f)
# Display results
results = [(imagenet_classes[str(catid.item())], top3_prob[idx].item()) for idx, catid in enumerate(top3_catid)]
for class_id, (label, probability) in zip(top3_catid, results):
print(f"{class_id}: {label} ({probability*100:.2f})")
2: ['n01484850', 'great_white_shark'] (96.19)
3: ['n01491361', 'tiger_shark'] (3.49)
4: ['n01494475', 'hammerhead'] (0.12)
# Forward pass again to get the top pred class
output = model(input_image)
# Get the predicted class
predicted_class = output.argmax().item()
# Zero out any previous gradients
model.zero_grad()
# Perform backward pass on the output of the predicted class
output[0, predicted_class].backward()
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1827: FutureWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior. self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
# Weight the channels by corresponding gradients
for i in range(activations.shape[1]):
activations[:, i, :, :] *= pooled_gradients[i]
# Generate the heatmap by averaging across channels and applying ReLU
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max() # Normalize heatmap
plt.imshow(heatmap, cmap='cool')
plt.axis('off')
plt.show()
# Load the image and convert to NumPy format
original_img = cv2.imread(img_path)
# Resize the heatmap to the original image size
resized_heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)
# Convert non-resized and resized heatmaps to color maps
non_resized_heatmap = np.uint8(255 * heatmap) # Non-resized heatmap
non_resized_heatmap_colormap = cv2.applyColorMap(non_resized_heatmap, cv2.COLORMAP_JET)
resized_heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)
# Superimpose resized heatmap on original image
superimposed_img = cv2.addWeighted(original_img, 0.6, resized_heatmap_colormap, 0.4, 0)
# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# Display non-resized heatmap in Row 1, Col 1
axs[0, 0].imshow(cv2.cvtColor(non_resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 0].set_title("Non-resized Heatmap")
axs[0, 0].axis('off')
# Display resized heatmap in Row 1, Col 2
axs[0, 1].imshow(cv2.cvtColor(resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 1].set_title("Resized Heatmap")
axs[0, 1].axis('off')
# Display original image in Row 2, Col 1
axs[1, 0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[1, 0].set_title("Original Image")
axs[1, 0].axis('off')
# Display image with heatmap overlay in Row 2, Col 2
axs[1, 1].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
axs[1, 1].set_title("Image with Heatmap Overlay")
axs[1, 1].axis('off')
# Adjust layout and show the plot
plt.tight_layout()
plt.show()
Iguana¶
# Load and preprocess a sample image
img_path = 'iguana.jpg'
input_image = preprocess_image(img_path)
# Perform a forward pass and get the top 3 predictions
with torch.no_grad():
output = model(input_image)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top3_prob, top3_catid = torch.topk(probabilities, 3)
# Load ImageNet class labels
with open("imagenet_classes.json") as f:
imagenet_classes = json.load(f)
# Display results
results = [(imagenet_classes[str(catid.item())], top3_prob[idx].item()) for idx, catid in enumerate(top3_catid)]
for class_id, (label, probability) in zip(top3_catid, results):
print(f"{class_id}: {label} ({probability*100:.2f})")
768: ['n04118538', 'rugby_ball'] (17.07)
981: ['n09835506', 'ballplayer'] (12.61)
805: ['n04254680', 'soccer_ball'] (11.12)
# Forward pass again to get the top pred class
output = model(input_image)
# Get the predicted class
predicted_class = output.argmax().item()
# Zero out any previous gradients
model.zero_grad()
# Perform backward pass on the output of the predicted class
output[0, predicted_class].backward()
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
# Weight the channels by corresponding gradients
for i in range(activations.shape[1]):
activations[:, i, :, :] *= pooled_gradients[i]
# Generate the heatmap by averaging across channels and applying ReLU
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max() # Normalize heatmap
plt.imshow(heatmap, cmap='cool')
plt.axis('off')
plt.show()
# Load the image and convert to NumPy format
original_img = cv2.imread(img_path)
# Resize the heatmap to the original image size
resized_heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)
# Convert non-resized and resized heatmaps to color maps
non_resized_heatmap = np.uint8(255 * heatmap) # Non-resized heatmap
non_resized_heatmap_colormap = cv2.applyColorMap(non_resized_heatmap, cv2.COLORMAP_JET)
resized_heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)
# Superimpose resized heatmap on original image
superimposed_img = cv2.addWeighted(original_img, 0.6, resized_heatmap_colormap, 0.4, 0)
# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# Display non-resized heatmap in Row 1, Col 1
axs[0, 0].imshow(cv2.cvtColor(non_resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 0].set_title("Non-resized Heatmap")
axs[0, 0].axis('off')
# Display resized heatmap in Row 1, Col 2
axs[0, 1].imshow(cv2.cvtColor(resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 1].set_title("Resized Heatmap")
axs[0, 1].axis('off')
# Display original image in Row 2, Col 1
axs[1, 0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[1, 0].set_title("Original Image")
axs[1, 0].axis('off')
# Display image with heatmap overlay in Row 2, Col 2
axs[1, 1].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
axs[1, 1].set_title("Image with Heatmap Overlay")
axs[1, 1].axis('off')
# Adjust layout and show the plot
plt.tight_layout()
plt.show()
Iguana (cropped)¶
# Load and preprocess a sample image
img_path = 'iguana.jpg'
original_img = cv2.imread(img_path)
# Get the dimensions of the image
height, width = original_img.shape[:2]
# Calculate the dimensions of each quarter
quarter_height, quarter_width = height // 2, width // 2
# Crop the top-left quarter
top_left_quarter = original_img[0:quarter_height, 0:quarter_width]
# Upsample the cropped image
scale_factor = 2
upsampled_img = cv2.resize(top_left_quarter, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC)
# Save the cropped image
cropped_img_path = 'iguana_cropped.jpg'
cv2.imwrite(cropped_img_path, upsampled_img)
# Set up a grid for displaying images
fig, axs = plt.subplots(1, 2, figsize=(10, 10))
# Display the original image on the left
axs[0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[0].set_title("Original Image")
axs[0].axis('off')
# Display the cropped and upsampled image on the right
axs[1].imshow(cv2.cvtColor(upsampled_img, cv2.COLOR_BGR2RGB))
axs[1].set_title("Cropped Image")
axs[1].axis('off')
# Load and preprocess the image
img_path = 'iguana_cropped.jpg'
input_image = preprocess_image(img_path)
# Perform a forward pass and get the top 3 predictions
with torch.no_grad():
output = model(input_image)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top3_prob, top3_catid = torch.topk(probabilities, 3)
# Load ImageNet class labels
with open("imagenet_classes.json") as f:
imagenet_classes = json.load(f)
# Display results
results = [(imagenet_classes[str(catid.item())], top3_prob[idx].item()) for idx, catid in enumerate(top3_catid)]
for class_id, (label, probability) in zip(top3_catid, results):
print(f"{class_id}: {label} ({probability*100:.2f})")
280: ['n02120505', 'grey_fox'] (40.81)
43: ['n01688243', 'frilled_lizard'] (16.03)
39: ['n01677366', 'common_iguana'] (5.67)
# Forward pass again to get the top pred class
output = model(input_image)
# Get the predicted class
predicted_class = output.argmax().item()
# Zero out any previous gradients
model.zero_grad()
# Perform backward pass on the output of the predicted class
output[0, predicted_class].backward()
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
# Weight the channels by corresponding gradients
for i in range(activations.shape[1]):
activations[:, i, :, :] *= pooled_gradients[i]
# Generate the heatmap by averaging across channels and applying ReLU
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max() # Normalize heatmap
plt.imshow(heatmap, cmap='cool')
plt.axis('off')
plt.show()
# Load the image and convert to NumPy format
original_img = cv2.imread(img_path)
# Resize the heatmap to the original image size
resized_heatmap = cv2.resize(heatmap, (original_img.shape[1], original_img.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)
# Convert non-resized and resized heatmaps to color maps
non_resized_heatmap = np.uint8(255 * heatmap) # Non-resized heatmap
non_resized_heatmap_colormap = cv2.applyColorMap(non_resized_heatmap, cv2.COLORMAP_JET)
resized_heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)
# Superimpose resized heatmap on original image
superimposed_img = cv2.addWeighted(original_img, 0.6, resized_heatmap_colormap, 0.4, 0)
# Set up a 2x2 grid for displaying images
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
# Display non-resized heatmap in Row 1, Col 1
axs[0, 0].imshow(cv2.cvtColor(non_resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 0].set_title("Non-resized Heatmap")
axs[0, 0].axis('off')
# Display resized heatmap in Row 1, Col 2
axs[0, 1].imshow(cv2.cvtColor(resized_heatmap_colormap, cv2.COLOR_BGR2RGB))
axs[0, 1].set_title("Resized Heatmap")
axs[0, 1].axis('off')
# Display original image in Row 2, Col 1
axs[1, 0].imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
axs[1, 0].set_title("Original Image")
axs[1, 0].axis('off')
# Display image with heatmap overlay in Row 2, Col 2
axs[1, 1].imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
axs[1, 1].set_title("Image with Heatmap Overlay")
axs[1, 1].axis('off')
# Adjust layout and show the plot
plt.tight_layout()
plt.show()
Live demo with webcam¶
Note that this code needs to be executed locally.
Colab doesn't work well with webcam captures or cv2 windows.
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load a pretrained model
model = models.resnet50(pretrained=True)
model.eval()
# Preprocessing function for the input image
def preprocess_image(img):
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img_tensor = preprocess(img).unsqueeze(0) # Add batch dimension
return img_tensor
# Hook functions to get activations and gradients for Grad-CAM
activations = None
gradients = None
def save_activation_grad(module, input, output):
global activations
activations = output
def save_gradient(module, grad_in, grad_out):
global gradients
gradients = grad_out[0]
# Hook the last convolutional layer
layer = model.layer4[2].conv3 # Adjust for the last convolutional layer of your model
layer.register_forward_hook(save_activation_grad)
layer.register_backward_hook(save_gradient)
# Grad-CAM computation
def compute_gradcam(input_image, class_idx):
# Zero out previous gradients
model.zero_grad()
# Perform backward pass on the predicted class
output = model(input_image)
output[0, class_idx].backward()
# Pool the gradients
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
# Weight the channels by gradients
for i in range(activations.shape[1]):
activations[:, i, :, :] *= pooled_gradients[i]
# Generate heatmap
heatmap = torch.mean(activations, dim=1).squeeze()
heatmap = np.maximum(heatmap.detach().numpy(), 0)
heatmap = heatmap / heatmap.max() # Normalize heatmap
return heatmap
# Initialize webcam capture
cap = cv2.VideoCapture(0)
# Loop to continuously capture webcam frames
while True:
# Capture frame-by-frame
ret, frame = cap.read()
# If no frame is captured, skip
if not ret:
continue
# Convert the frame to PIL image for processing
pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# Preprocess the image and get the prediction
input_image = preprocess_image(pil_img)
with torch.no_grad():
output = model(input_image)
predicted_class = output.argmax().item()
# Compute the Grad-CAM heatmap
heatmap = compute_gradcam(input_image, predicted_class)
# Resize the heatmap to match the frame size
resized_heatmap = cv2.resize(heatmap, (frame.shape[1], frame.shape[0]))
resized_heatmap = np.uint8(255 * resized_heatmap)
heatmap_colormap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)
# Superimpose the heatmap on the original frame
superimposed_img = cv2.addWeighted(frame, 0.6, heatmap_colormap, 0.4, 0)
# Close the previous windows if already open
cv2.destroyAllWindows()
# Display image with heatmap overlay
cv2.imshow(f'Image with Heatmap Overlay, {predicted_class}', superimposed_img)
# Exit condition: Close the windows if 'q' is pressed
if cv2.waitKey(1) & 0xFF == ord('q'):
break
# Release the capture and close the window
cap.release()
cv2.destroyAllWindows()