Understanding what CNNs learn is crucial for building trust and improving model performance. These visualization techniques help us peek inside the "black box" of deep neural networks.
🔍
Feature Maps
Layer-by-layer visualization
Visualize the activation patterns of different layers to understand what features the network detects at each stage.
Layer Analysis
See what each convolutional layer learns
Feature Evolution
Track how features become more complex
Filter Visualization
Understand individual filter responses
⚡
Activation Maximization
Generate optimal inputs
Create synthetic images that maximally activate specific neurons or layers to understand what they're looking for.
Neuron Preferences
Discover what excites individual neurons
Class Visualization
Generate prototypical class examples
DeepDream
Create artistic interpretations
📍
Grad-CAM
Gradient-based localization
Weight the last convolutional layer's feature maps by the gradients of the target class score to highlight which regions of an image are most important for the network's prediction.
Decision Regions
Show where the model is "looking"
Class Discrimination
Understand class-specific features
Error Analysis
Debug misclassifications
🎯
Saliency Maps
Pixel-level importance
Compute the gradient of the predicted class score with respect to the input pixels to identify which pixels most influence the final prediction.
Pixel Importance
Rank pixels by their influence
Input Attribution
Trace decisions back to inputs
Adversarial Detection
Identify potential vulnerabilities
Interactive Visualization Demo
Original Image
📷 Input Image (Click to load demo)
The original input image that we'll analyze with various visualization techniques.
Feature Maps
🔍 Layer Activations (Processing...)
Activations from different convolutional layers showing detected features.
Model Prediction
🎯 Classification Result: 92.5%
The model's final prediction with confidence scores for top classes.
Implementation Examples
# Feature Map Visualization
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
class FeatureMapVisualizer:
    """Capture and plot intermediate activations of a PyTorch model.

    Forward hooks added by ``register_hooks`` store each hooked module's
    detached output in ``self.feature_maps``, keyed by the module's qualified
    name from ``model.named_modules()``. Call ``cleanup`` to remove the hooks.
    """

    def __init__(self, model):
        """Wrap ``model``; no hooks are registered until ``register_hooks``."""
        self.model = model
        # Qualified module name -> detached output from the most recent forward pass.
        self.feature_maps = {}
        # Handles from register_forward_hook, kept so cleanup() can remove them.
        self.hooks = []

    def register_hooks(self, layers_to_visualize, layer_types=nn.Conv2d):
        """Register forward hooks to capture feature maps.

        Args:
            layers_to_visualize: A module name, or an iterable of module names,
                as produced by ``model.named_modules()`` (e.g. ``'layer1.0.conv1'``).
            layer_types: Only modules that are instances of this type (or tuple
                of types, anything ``isinstance`` accepts) are hooked. Defaults to
                ``nn.Conv2d``; pass ``None`` to hook any module whose name matches.

        Returns:
            The names of the modules that were hooked, in ``named_modules`` order.
            Requested names that matched no eligible module are reported via
            ``print`` rather than silently ignored.
        """
        # A bare string would otherwise be treated as a container of characters
        # (and `name in "..."` would do substring matching).
        if isinstance(layers_to_visualize, str):
            layers_to_visualize = [layers_to_visualize]
        requested = list(layers_to_visualize)
        wanted = set(requested)

        def hook_fn(name):
            # Factory binds `name` per module, avoiding the late-binding closure
            # pitfall of defining the hook directly inside the loop below.
            def hook(_module, _inputs, output):
                self.feature_maps[name] = output.detach()
            return hook

        registered = []
        for name, module in self.model.named_modules():
            if name not in wanted:
                continue
            if layer_types is not None and not isinstance(module, layer_types):
                continue
            self.hooks.append(module.register_forward_hook(hook_fn(name)))
            registered.append(name)

        hooked = set(registered)
        missing = [name for name in requested if name not in hooked]
        if missing:
            print(f"No hookable layer found for: {', '.join(missing)}")
        return registered

    def _capture(self, input_tensor):
        """Drop stale activations and run one gradient-free forward pass."""
        self.feature_maps.clear()
        with torch.no_grad():
            self.model(input_tensor)

    @staticmethod
    def _normalize(tensor):
        """Return ``tensor`` as a float32 NumPy array min-max scaled to roughly [0, 1]."""
        # Cast to float32 first: NumPy has no bfloat16, and in float16 the 1e-8
        # epsilon underflows to 0, turning constant maps into 0/0 = NaN.
        array = tensor.detach().float().cpu().numpy()
        return (array - array.min()) / (array.max() - array.min() + 1e-8)

    def visualize_feature_maps(self, input_tensor, layer_name, max_filters=16):
        """Plot up to ``max_filters`` channels of one layer's feature map.

        Runs a forward pass on ``input_tensor`` under ``torch.no_grad()``, then
        shows each channel of the first batch element in a 4-column grid,
        min-max normalised per channel.

        Args:
            input_tensor: Batched model input; only element 0 is plotted.
            layer_name: A module name previously passed to ``register_hooks``.
            max_filters: Upper bound on the number of channels shown.

        Prints a message and returns ``None`` without plotting if the layer was
        not captured, its output is not 4-D ``(batch, channels, height, width)``,
        or there are no channels to show.
        """
        self._capture(input_tensor)

        if layer_name not in self.feature_maps:
            print(f"Layer {layer_name} not found in captured feature maps")
            return

        feature_map = self.feature_maps[layer_name]
        if feature_map.dim() != 4:
            print(f"Layer {layer_name} output has shape {tuple(feature_map.shape)}; "
                  "expected (batch, channels, height, width)")
            return

        feature_map = feature_map[0]  # first image: (channels, height, width)
        num_filters = min(max_filters, feature_map.shape[0])
        if num_filters <= 0:
            # plt.subplots(0, cols) would raise; nothing to draw anyway.
            print(f"No filters to display for layer {layer_name}")
            return

        cols = 4
        rows = (num_filters + cols - 1) // cols
        # squeeze=False keeps `axes` 2-D even when rows == 1.
        _, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows), squeeze=False)

        for i, ax in enumerate(axes.flat):
            if i < num_filters:
                ax.imshow(self._normalize(feature_map[i]), cmap='viridis')
                ax.set_title(f'Filter {i}')
            # Unused trailing cells are simply left blank.
            ax.axis('off')

        plt.suptitle(f'Feature Maps from {layer_name}')
        plt.tight_layout()
        plt.show()

    def compare_layers(self, input_tensor, layer_names):
        """Show the channel-averaged activation of several layers side by side.

        Runs one forward pass on ``input_tensor`` under ``torch.no_grad()`` and,
        for the first batch element of each layer, plots the mean over channels
        (min-max normalised) with the 'hot' colormap.

        Args:
            input_tensor: Batched model input; only element 0 is plotted.
            layer_names: A module name, or iterable of names, previously passed
                to ``register_hooks``.

        Layers that were not captured, or whose output is not 4-D, get a
        labelled blank panel and a printed message. An empty ``layer_names``
        prints a message and returns without running the model.
        """
        if isinstance(layer_names, str):
            layer_names = [layer_names]
        layer_names = list(layer_names)
        if not layer_names:
            # plt.subplots(1, 0) would raise.
            print("No layers given to compare")
            return

        self._capture(input_tensor)

        # squeeze=False keeps `axes` 2-D even for a single layer.
        _, axes = plt.subplots(1, len(layer_names), figsize=(5 * len(layer_names), 5),
                               squeeze=False)
        for ax, layer_name in zip(axes[0], layer_names):
            feature_map = self.feature_maps.get(layer_name)
            if feature_map is None:
                print(f"Layer {layer_name} not found in captured feature maps")
                ax.set_title(f'{layer_name}\n(not captured)')
            elif feature_map.dim() != 4:
                print(f"Layer {layer_name} output has shape {tuple(feature_map.shape)}; "
                      "expected (batch, channels, height, width)")
                ax.set_title(f'{layer_name}\n(not a 2-D feature map)')
            else:
                feature_map = feature_map[0]  # first image: (channels, height, width)
                # Average over channels (in float32) for one overview map per layer.
                ax.imshow(self._normalize(feature_map.float().mean(dim=0)), cmap='hot')
                ax.set_title(f'{layer_name}\n({feature_map.shape[0]} channels)')
            ax.axis('off')

        plt.suptitle('Feature Map Evolution Across Layers')
        plt.tight_layout()
        plt.show()

    def cleanup(self):
        """Remove every hook registered by this visualizer."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()
# Usage example (guarded so importing this module does not download weights
# or read files as a side effect)
if __name__ == "__main__":
    import torchvision

    # `weights=` replaces the deprecated `pretrained=True` argument.
    weights = torchvision.models.ResNet50_Weights.DEFAULT
    model = torchvision.models.resnet50(weights=weights)
    model.eval()

    visualizer = FeatureMapVisualizer(model)
    # Register hooks for specific layers (names as listed by model.named_modules())
    layers_to_visualize = ['layer1.0.conv1', 'layer2.0.conv1', 'layer3.0.conv1', 'layer4.0.conv1']
    visualizer.register_hooks(layers_to_visualize)

    try:
        # Load and preprocess image: decode as 3-channel RGB, then apply the
        # preprocessing bundled with the pretrained weights.
        image = torchvision.io.read_image('path/to/image.jpg',
                                          mode=torchvision.io.ImageReadMode.RGB)
        input_tensor = weights.transforms()(image).unsqueeze(0)  # Add batch dimension

        # Visualize feature maps from different layers
        visualizer.visualize_feature_maps(input_tensor, 'layer1.0.conv1')
        visualizer.compare_layers(input_tensor, layers_to_visualize)
    finally:
        # Cleanup: always remove hooks, even if loading or plotting fails
        visualizer.cleanup()