CS5720 - Week 5
Slide 98 of 100

CNN Visualization Techniques

Understanding what CNNs learn is crucial for building trust and improving model performance. These visualization techniques help us peek inside the "black box" of deep neural networks.
🔍
Feature Maps
Layer-by-layer visualization
Visualize the activation patterns of different layers to understand what features the network detects at each stage.
Layer Analysis
See what each convolutional layer learns
Feature Evolution
Track how features become more complex
Filter Visualization
Understand individual filter responses
Activation Maximization
Generate optimal inputs
Create synthetic images that maximally activate specific neurons or layers to understand what they're looking for.
Neuron Preferences
Discover what excites individual neurons
Class Visualization
Generate prototypical class examples
Deep Dreams
Create artistic interpretations
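The gradient-ascent idea behind activation maximization can be sketched as follows. This is a minimal illustration: the tiny untrained CNN, the layer choice, and the hyperparameters are assumptions for demonstration, not the course's model. The same loop applied to a trained network's output neuron yields prototypical class images.

```python
import torch
import torch.nn as nn

def maximize_activation(model, layer, filter_idx, input_size=(1, 1, 28, 28),
                        steps=50, lr=0.1):
    """Gradient-ascend on a random input to maximize one filter's mean activation."""
    activations = {}

    def hook(module, inp, out):
        activations["value"] = out

    handle = layer.register_forward_hook(hook)
    img = torch.randn(input_size, requires_grad=True)
    optimizer = torch.optim.Adam([img], lr=lr)

    for _ in range(steps):
        optimizer.zero_grad()
        model(img)
        # Maximize the chosen filter's mean activation (minimize its negative)
        loss = -activations["value"][0, filter_idx].mean()
        loss.backward()
        optimizer.step()

    handle.remove()
    return img.detach()

# Toy untrained model, only to demonstrate the mechanics
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
model.eval()
synthetic = maximize_activation(model, model[2], filter_idx=0)
print(synthetic.shape)  # torch.Size([1, 1, 28, 28])
```

In practice you would add regularizers (e.g. image blurring or an L2 penalty on the input) to keep the synthesized image natural-looking, which is essentially what Deep Dream does.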
📍
Grad-CAM
Gradient-based localization
Use gradients to highlight which parts of an image are most important for the network's prediction.
Decision Regions
Show where the model is "looking"
Class Discrimination
Understand class-specific features
Error Analysis
Debug misclassifications
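A compact sketch of the Grad-CAM computation: global-average-pool the gradients of the class score with respect to a convolutional layer's feature maps, use them as channel weights, and keep only the positive evidence. The toy untrained model and layer choice below are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def grad_cam(model, target_layer, input_tensor, class_idx=None):
    """Grad-CAM: weight the target layer's feature maps by pooled gradients."""
    feats, grads = {}, {}

    def fwd_hook(module, inp, out):
        feats["v"] = out

    def bwd_hook(module, grad_in, grad_out):
        grads["v"] = grad_out[0]

    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)

    logits = model(input_tensor)
    if class_idx is None:
        class_idx = logits.argmax(dim=1).item()
    model.zero_grad()
    logits[0, class_idx].backward()
    h1.remove()
    h2.remove()

    # Pooled gradients act as per-channel importance weights
    weights = grads["v"].mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * feats["v"]).sum(dim=1, keepdim=True))
    # Upsample to the input resolution and normalize to [0, 1]
    cam = F.interpolate(cam, size=input_tensor.shape[2:],
                        mode="bilinear", align_corners=False)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cam[0, 0], class_idx

# Toy untrained classifier just to exercise the function
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(4), nn.Flatten(), nn.Linear(8 * 16, 10),
)
model.eval()
x = torch.randn(1, 3, 32, 32)
heatmap, pred = grad_cam(model, model[0], x)
print(heatmap.shape)  # torch.Size([32, 32])
```

The resulting heatmap is typically overlaid on the input image with an alpha-blended colormap to show where the model is "looking".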
🎯
Saliency Maps
Pixel-level importance
Compute gradients with respect to input pixels to identify which pixels most influence the final prediction.
Pixel Importance
Rank pixels by their influence
Input Attribution
Trace decisions back to inputs
Adversarial Detection
Identify potential vulnerabilities
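A saliency map is a single backward pass: take the gradient of the class score with respect to the input pixels, then reduce over color channels. The toy untrained model below is an assumption for demonstration; with a trained classifier the bright pixels highlight the evidence for the prediction.

```python
import torch
import torch.nn as nn

def saliency_map(model, input_tensor, class_idx=None):
    """|d(class score)/d(pixel)|, reduced by a max over channels."""
    input_tensor = input_tensor.clone().requires_grad_(True)
    logits = model(input_tensor)
    if class_idx is None:
        class_idx = logits.argmax(dim=1).item()
    model.zero_grad()
    logits[0, class_idx].backward()
    # Per-pixel importance: absolute gradient, max over the channel dimension
    return input_tensor.grad[0].abs().max(dim=0).values

# Toy untrained classifier just to exercise the function
model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(4 * 32 * 32, 10),
)
model.eval()
x = torch.randn(1, 3, 32, 32)
sal = saliency_map(model, x)
print(sal.shape)  # torch.Size([32, 32])
```

Because the same gradient tells an attacker which pixels to perturb, saliency maps are also a starting point for studying adversarial vulnerabilities.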

Interactive Visualization Demo

Original Image
📷
Input Image
The original input image that we'll analyze with various visualization techniques.
Feature Maps
🔍
Layer Activations
Activations from different convolutional layers showing detected features.
Model Prediction
🎯
Classification
Result: 92.5%
The model's final prediction with confidence scores for top classes.

Implementation Examples

# Feature Map Visualization
import torch
import torch.nn as nn
import torchvision  # needed for the pretrained ResNet in the usage example
import matplotlib.pyplot as plt

class FeatureMapVisualizer:
    def __init__(self, model):
        self.model = model
        self.feature_maps = {}
        self.hooks = []

    def register_hooks(self, layers_to_visualize):
        """Register forward hooks to capture feature maps."""
        def hook_fn(name):
            def hook(module, input, output):
                self.feature_maps[name] = output.detach()
            return hook

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d) and name in layers_to_visualize:
                handle = module.register_forward_hook(hook_fn(name))
                self.hooks.append(handle)

    def visualize_feature_maps(self, input_tensor, layer_name, max_filters=16):
        """Visualize feature maps from a specific layer."""
        # Clear previous feature maps, then run a forward pass
        self.feature_maps.clear()
        with torch.no_grad():
            _ = self.model(input_tensor)

        if layer_name not in self.feature_maps:
            print(f"Layer {layer_name} not found in captured feature maps")
            return

        # Select the first image from the batch: shape [channels, height, width]
        feature_map = self.feature_maps[layer_name][0]
        num_filters = min(max_filters, feature_map.shape[0])

        # Create a subplot grid
        cols = 4
        rows = (num_filters + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(12, 3 * rows))
        if rows == 1:
            axes = axes.reshape(1, -1)

        for i in range(num_filters):
            row, col = divmod(i, cols)
            # Normalize each map to [0, 1] for display
            fmap = feature_map[i].cpu().numpy()
            fmap = (fmap - fmap.min()) / (fmap.max() - fmap.min() + 1e-8)
            axes[row, col].imshow(fmap, cmap='viridis')
            axes[row, col].set_title(f'Filter {i}')
            axes[row, col].axis('off')

        # Hide unused subplots
        for i in range(num_filters, rows * cols):
            row, col = divmod(i, cols)
            axes[row, col].axis('off')

        plt.suptitle(f'Feature Maps from {layer_name}')
        plt.tight_layout()
        plt.show()

    def compare_layers(self, input_tensor, layer_names):
        """Compare feature maps across different layers."""
        self.feature_maps.clear()
        with torch.no_grad():
            _ = self.model(input_tensor)

        fig, axes = plt.subplots(1, len(layer_names),
                                 figsize=(5 * len(layer_names), 5))
        if len(layer_names) == 1:
            axes = [axes]

        for i, layer_name in enumerate(layer_names):
            if layer_name in self.feature_maps:
                # Mean across all channels gives a per-layer overview
                feature_map = self.feature_maps[layer_name][0]  # first image
                mean_activation = torch.mean(feature_map, dim=0).cpu().numpy()
                mean_activation = (mean_activation - mean_activation.min()) / \
                                  (mean_activation.max() - mean_activation.min() + 1e-8)
                axes[i].imshow(mean_activation, cmap='hot')
                axes[i].set_title(f'{layer_name}\n({feature_map.shape[0]} channels)')
                axes[i].axis('off')

        plt.suptitle('Feature Map Evolution Across Layers')
        plt.tight_layout()
        plt.show()

    def cleanup(self):
        """Remove all hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()


# Usage example
model = torchvision.models.resnet50(pretrained=True)
model.eval()

visualizer = FeatureMapVisualizer(model)

# Register hooks for specific layers
layers_to_visualize = ['layer1.0.conv1', 'layer2.0.conv1',
                       'layer3.0.conv1', 'layer4.0.conv1']
visualizer.register_hooks(layers_to_visualize)

# Load and preprocess the image
image = load_and_preprocess_image('path/to/image.jpg')  # your preprocessing function
input_tensor = image.unsqueeze(0)  # add batch dimension

# Visualize feature maps from different layers
visualizer.visualize_feature_maps(input_tensor, 'layer1.0.conv1')
visualizer.compare_layers(input_tensor, layers_to_visualize)

# Cleanup
visualizer.cleanup()
Prepared by Dr. Gorkem Kar