Complete Implementation Example: fine-tuning a pretrained ResNet-50 image classifier in PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 1. Data Preparation
def get_data_loaders(batch_size=32, train_dir='path/to/train',
                     val_dir='path/to/val', num_workers=4):
    """Build training and validation DataLoaders over ImageFolder datasets.

    Args:
        batch_size: Samples per batch for both loaders.
        train_dir: Root directory of the training ImageFolder
            (one subdirectory per class).
        val_dir: Root directory of the validation ImageFolder.
        num_workers: Worker processes per DataLoader.

    Returns:
        Tuple of (train_loader, val_loader, num_classes).
    """
    # ImageNet statistics — the backbone in create_model() is
    # ImageNet-pretrained, so inputs must be normalized the same way.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # NOTE: Resize must get a (H, W) tuple. Resize(224) with a bare int only
    # scales the shorter edge, so images with different aspect ratios yield
    # different tensor sizes and the default collate_fn fails to stack them.
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        # Augmentation applies only to the training split; validation below
        # stays deterministic.
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = torchvision.datasets.ImageFolder(
        train_dir, transform=transform_train)
    val_dataset = torchvision.datasets.ImageFolder(
        val_dir, transform=transform_val)
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers,
                              pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=num_workers,
                            pin_memory=True)
    return train_loader, val_loader, len(train_dataset.classes)
# 2. Model Setup
def create_model(num_classes, pretrained=True, freeze_backbone=True):
    """Create a ResNet-50 with a fresh classification head for fine-tuning.

    Args:
        num_classes: Number of output classes for the new head.
        pretrained: Load ImageNet weights for the backbone.
            NOTE(review): the `pretrained` flag is deprecated in recent
            torchvision in favor of the `weights=` enum API — confirm the
            installed torchvision version before upgrading this call.
        freeze_backbone: Freeze all but the last few backbone parameter
            tensors so only the network tail and the new head train.
            (Previously this freeze was unconditional; default preserves
            that behavior.)

    Returns:
        The configured nn.Module.
    """
    model = torchvision.models.resnet50(pretrained=pretrained)
    if freeze_backbone:
        # Freezes by position in the parameter list; the slice [:-10] is
        # architecture-dependent and fragile — the replaced fc below is
        # always trainable regardless, since its parameters are created new.
        for param in list(model.parameters())[:-10]:
            param.requires_grad = False
    # Replace the 1000-way ImageNet head with one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
# 3. Training Function
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over train_loader.

    Returns:
        Tuple of (mean batch loss, accuracy in percent) for the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct, n_seen = 0, 0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Clip the global gradient norm to stabilize updates.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_sum += loss.item()
        preds = outputs.data.max(1)[1]
        n_seen += labels.size(0)
        n_correct += (preds == labels).sum().item()

        # Periodic progress report (includes the very first batch).
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}/{len(train_loader)}, '
                  f'Loss: {loss.item():.4f}, '
                  f'Acc: {100.*n_correct/n_seen:.2f}%')
    return loss_sum / len(train_loader), 100. * n_correct / n_seen
# 4. Validation Function
def validate(model, val_loader, criterion, device):
    """Evaluate model on val_loader without updating weights.

    Returns:
        Tuple of (mean batch loss, accuracy in percent).
    """
    model.eval()
    loss_sum = 0.0
    n_correct, n_seen = 0, 0
    with torch.no_grad():  # skip autograd bookkeeping during evaluation
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item()
            preds = outputs.data.max(1)[1]
            n_seen += labels.size(0)
            n_correct += (preds == labels).sum().item()
    return loss_sum / len(val_loader), 100. * n_correct / n_seen
# 5. Main Training Loop
def train_cnn(num_epochs=50, batch_size=32, lr=0.001,
              checkpoint_path='best_model.pth'):
    """Train the CNN end to end, checkpointing the best validation model.

    Args:
        num_epochs: Number of epochs to run (was hard-coded to 50).
        batch_size: Batch size passed to get_data_loaders (was hard-coded).
        lr: Initial Adam learning rate (was hard-coded to 0.001).
        checkpoint_path: Where to save the best-validation-accuracy
            checkpoint (was hard-coded to 'best_model.pth').

    Returns:
        Tuple of (model, train_losses, val_losses, train_accs, val_accs)
        with one metric entry per epoch.
    """
    # Setup: prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, num_classes = get_data_loaders(batch_size=batch_size)
    model = create_model(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # Decay the learning rate by 10x every 10 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_val_acc = 0.0
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 30)
        # Train, then validate on the held-out split.
        train_loss, train_acc = train_epoch(model, train_loader, criterion,
                                            optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        # Step the scheduler once per epoch, after the optimizer steps.
        scheduler.step()
        # Checkpoint whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, checkpoint_path)
        # Log per-epoch metrics for plotting by the caller.
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Best Val Acc: {best_val_acc:.2f}%')
        print()
    return model, train_losses, val_losses, train_accs, val_accs
# 6. Run Training
if __name__ == "__main__":
    model, train_losses, val_losses, train_accs, val_accs = train_cnn()

    # Plot the loss and accuracy curves side by side.
    plt.figure(figsize=(12, 4))
    panels = [
        (1, train_losses, 'Train Loss', val_losses, 'Val Loss',
         'Loss', 'Training and Validation Loss'),
        (2, train_accs, 'Train Acc', val_accs, 'Val Acc',
         'Accuracy (%)', 'Training and Validation Accuracy'),
    ]
    for pos, t_series, t_label, v_series, v_label, ylab, title in panels:
        plt.subplot(1, 2, pos)
        plt.plot(t_series, label=t_label)
        plt.plot(v_series, label=v_label)
        plt.xlabel('Epoch')
        plt.ylabel(ylab)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.show()
🎯 Best Practices Checklist

- Data Handling — use appropriate transforms, handle class imbalance, and validate data quality.
- Model Architecture — start simple, use proven architectures, and consider computational constraints.
- Training Strategy — monitor for overfitting, use a held-out validation set, and implement early stopping.
- Hyperparameters — tune the learning rate, batch size, and regularization systematically.
- Debugging — check gradients, visualize learned features, and start with small datasets.
- Reproducibility — set random seeds, use version control, and document experiments.