import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
    """Small image classifier: three conv/batch-norm/ReLU/max-pool stages plus an MLP head.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images (3 = RGB, the
            value this model previously hard-coded).

    Input is a ``(batch, in_channels, H, W)`` tensor; output is
    ``(batch, num_classes)`` logits. The classifier expects a 128 x 4 x 4
    feature map. On 32 x 32 inputs the three 2x2 max-pools already yield
    exactly 4 x 4. For other sizes (H and W must each be at least 8), an
    adaptive average pool resizes the map to 4 x 4 instead of failing with a
    shape mismatch in the first ``Linear`` layer.
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        # Feature extraction layers: 3x3 kernels with padding=1 keep H and W unchanged.
        self.conv1 = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        # Shared 2x2 max-pool: halves H and W after every conv stage.
        self.pool = nn.MaxPool2d(2, 2)
        # Batch normalization, one per conv stage.
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        # Has no parameters or buffers, so state_dict keys are unchanged. It is an
        # exact identity on the 4x4 maps produced by 32x32 inputs and only
        # resizes maps coming from other input sizes.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
        # Classifier layers operating on the flattened 128 * 4 * 4 features.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for an image batch ``x``."""
        # Feature extraction
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.adaptive_pool(x)
        # Flatten everything except the batch dimension for the classifier.
        x = torch.flatten(x, 1)
        return self.classifier(x)
# Example: instantiate the 10-class CNN and print its module summary.
model = SimpleCNN(10)
print(model)
import torchvision.models as models
import torch.nn as nn
class TransferLearningModel(nn.Module):
    """Torchvision backbone whose final linear layer is replaced by a ``num_classes``-way head.

    Args:
        num_classes: Number of output logits of the new head.
        model_name: One of ``'resnet18'``, ``'vgg16'`` or ``'efficientnet'``
            (EfficientNet-B0).
        pretrained: Forwarded to the torchvision constructor; ``True`` (the
            previous fixed behaviour) loads pretrained weights, ``False``
            builds the architecture only.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    def __init__(self, num_classes, model_name='resnet18', pretrained=True):
        super().__init__()
        # Remembered so freeze_backbone() can locate the head replaced below.
        self.model_name = model_name
        # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`; the legacy keyword is kept for older torchvision releases.
        if model_name == 'resnet18':
            self.backbone = models.resnet18(pretrained=pretrained)
            num_features = self.backbone.fc.in_features
            # Replace final layer
            self.backbone.fc = nn.Linear(num_features, num_classes)
        elif model_name == 'vgg16':
            self.backbone = models.vgg16(pretrained=pretrained)
            # Read the width instead of hard-coding 4096.
            num_features = self.backbone.classifier[6].in_features
            self.backbone.classifier[6] = nn.Linear(num_features, num_classes)
        elif model_name == 'efficientnet':
            self.backbone = models.efficientnet_b0(pretrained=pretrained)
            num_features = self.backbone.classifier[1].in_features
            self.backbone.classifier[1] = nn.Linear(num_features, num_classes)
        else:
            # Previously an unknown name left self.backbone unset and only
            # failed later, inside forward().
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                "expected 'resnet18', 'vgg16' or 'efficientnet'"
            )

    def forward(self, x):
        """Run the wrapped backbone (including the new head) on ``x``."""
        return self.backbone(x)

    def _head(self):
        """Return the output layer that __init__ replaced with the new head."""
        if self.model_name == 'resnet18':
            return self.backbone.fc
        # vgg16 (classifier[6]) and efficientnet (classifier[1]) both keep the
        # new head as the last entry of their classifier Sequential.
        return self.backbone.classifier[-1]

    def freeze_backbone(self):
        """Freeze all backbone parameters except the new head, for feature extraction.

        Only the replaced output layer stays trainable. Previously, for VGG16
        the whole classifier, including two pretrained 4096-wide FC layers,
        was re-enabled.
        """
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in self._head().parameters():
            param.requires_grad = True

    def unfreeze_backbone(self):
        """Unfreeze all parameters for fine-tuning"""
        for param in self.backbone.parameters():
            param.requires_grad = True
# Example: a 5-class ResNet-18 transfer-learning model.
model = TransferLearningModel(5, 'resnet18')
# Phase 1 - feature extraction: backbone frozen, only the replaced fc layer trains.
model.freeze_backbone()
# Phase 2 (later) - fine-tune every parameter of the model.
model.unfreeze_backbone()