CS5720 - Week 9

Training Loops in PyTorch

Training Loop Fundamentals

A training loop is the core engine that drives the learning process in deep learning models. It orchestrates the forward pass, loss computation, backpropagation, and parameter updates. Each step maps to a single PyTorch call, as the sketch after the list below shows.
  1. Forward Pass: compute predictions by passing data through the model.
  2. Loss Calculation: measure how wrong the predictions are using a loss function.
  3. Backward Pass: compute gradients using automatic differentiation.
  4. Parameter Update: update model weights using the optimizer.
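
In code, each numbered step is a single call. A minimal sketch on one toy regression batch (the model shape, tensors, and learning rate are illustrative placeholders, not course code):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                                    # toy model (illustrative)
criterion = nn.MSELoss()                                    # loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)    # optimizer

x = torch.randn(32, 10)      # a batch of random inputs
y = torch.randn(32, 1)       # matching random targets

output = model(x)            # 1. Forward pass
loss = criterion(output, y)  # 2. Loss calculation
optimizer.zero_grad()        # clear stale gradients from the last iteration
loss.backward()              # 3. Backward pass (automatic differentiation)
optimizer.step()             # 4. Parameter update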

Essential Components

📦 DataLoader: handles batching, shuffling, and parallel data loading
⚙️ Optimizer: updates model parameters using gradients (SGD, Adam, etc.)
🎯 Loss Function: measures prediction error (MSE, CrossEntropy, etc.)
📈 LR Scheduler: adjusts learning rate during training
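
In practice these four components are constructed once, before the loop runs. A minimal setup sketch, where the dataset sizes, model, and hyperparameters are assumptions for illustration:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset: 1000 samples, 20 features, 5 classes (illustrative)
dataset = TensorDataset(torch.randn(1000, 20), torch.randint(0, 5, (1000,)))

train_loader = DataLoader(dataset, batch_size=64, shuffle=True,
                          num_workers=2)                    # batching, shuffling, parallel loading
model = nn.Linear(20, 5)                                    # stand-in model
criterion = nn.CrossEntropyLoss()                           # loss function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=10,
                                            gamma=0.1)      # LR scheduler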
Pro Tip: Modern PyTorch training loops often use automatic mixed precision (AMP) and gradient accumulation for better performance!
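
As a rough illustration of that tip, one common pattern combines torch.amp autocasting with gradient accumulation. This sketch assumes a CUDA device and reuses the model, criterion, optimizer, and train_loader from the setup above; accumulation_steps is an illustrative choice:

import torch

scaler = torch.amp.GradScaler()  # scales losses to avoid fp16 gradient underflow
accumulation_steps = 4           # assumed: update weights every 4 batches

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)

    # Run the forward pass in mixed precision
    with torch.amp.autocast(device_type='cuda'):
        output = model(data)
        loss = criterion(output, target) / accumulation_steps  # average over accumulated batches

    scaler.scale(loss).backward()  # scaled backward pass; gradients accumulate across batches

    if (batch_idx + 1) % accumulation_steps == 0:
        scaler.step(optimizer)     # unscales gradients, then runs optimizer.step()
        scaler.update()            # adjusts the loss scale for the next iteration
        optimizer.zero_grad()      # clear the accumulated gradients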

Complete Training Loop Examples

🔧 Basic Training Loop
# Basic PyTorch training loop
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    for batch_idx, (data, target) in enumerate(train_loader):
        # Move data to device (GPU/CPU)
        data, target = data.to(device), target.to(device)

        # Zero gradients from previous iteration
        optimizer.zero_grad()

        # Forward pass
        output = model(data)

        # Compute loss
        loss = criterion(output, target)

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()
🚀 Advanced Training Loop with Validation
# Advanced training loop with validation and logging
def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs, device, scheduler=None):
    train_losses, val_losses = [], []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # Print progress
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        # Update learning rate
        if scheduler:
            scheduler.step()

        # Log epoch results
        train_losses.append(train_loss / len(train_loader))
        val_losses.append(val_loss / len(val_loader))
        print(f'Epoch {epoch}: Train Loss: {train_losses[-1]:.4f}, '
              f'Val Loss: {val_losses[-1]:.4f}, '
              f'Val Acc: {100 * correct / len(val_loader.dataset):.2f}%')

    return train_losses, val_losses
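
A short usage sketch for train_model, assuming the model, loaders, criterion, optimizer, and scheduler were built as in the setup above:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

train_losses, val_losses = train_model(
    model, train_loader, val_loader, criterion, optimizer,
    num_epochs=10, device=device, scheduler=scheduler,
)

Returning the per-epoch loss lists makes it easy to plot training curves and spot overfitting afterwards.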
Prepared by Dr. Gorkem Kar