import torch
import torch.nn as nn
class Autoencoder(nn.Module):
    """Fully connected autoencoder: an MLP encoder paired with a mirrored MLP decoder.

    Architecture with the default arguments::

        encoder: Linear(784, 128) -> ReLU -> Linear(128, 64) -> ReLU -> Linear(64, 32)
        decoder: Linear(32, 64) -> ReLU -> Linear(64, 128) -> ReLU -> Linear(128, 784) -> Sigmoid

    The latent code has no activation. The reconstruction passes through
    ``nn.Sigmoid``, so its values lie in [0, 1].

    Args:
        input_dim: Number of features in the last dimension of the input.
        encoding_dim: Size of the latent code produced by the encoder.
        hidden_dims: Widths of the encoder's hidden layers, listed from the
            input side inward; the decoder uses them in reverse order. The
            default ``(128, 64)`` reproduces the original fixed architecture,
            including its ``state_dict`` key names.
    """

    def __init__(self, input_dim=784, encoding_dim=32, hidden_dims=(128, 64)):
        super().__init__()
        # Accept any iterable of widths; reversed() below needs a sequence.
        hidden_dims = tuple(hidden_dims)
        # Encoder: input_dim -> *hidden_dims -> encoding_dim (no activation on the code).
        self.encoder = self._build_mlp([input_dim, *hidden_dims, encoding_dim])
        # Decoder mirrors the encoder. Sigmoid bounds outputs to [0, 1]
        # (presumably to match inputs normalized to that range -- confirm with callers).
        self.decoder = self._build_mlp(
            [encoding_dim, *reversed(hidden_dims), input_dim],
            final_activation=nn.Sigmoid(),
        )

    @staticmethod
    def _build_mlp(widths, final_activation=None):
        """Return an ``nn.Sequential`` of Linear layers over consecutive ``widths``.

        A ReLU is inserted between consecutive Linear layers but not after the
        last one. ``final_activation``, if given, is appended at the end. The
        resulting layer order (and therefore the ``nn.Sequential`` indices)
        matches the original hand-written stacks.
        """
        layers = []
        last = len(widths) - 2  # index of the final Linear layer
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if i < last:
                layers.append(nn.ReLU())
        if final_activation is not None:
            layers.append(final_activation)
        return nn.Sequential(*layers)

    def forward(self, x):
        """Reconstruct ``x`` by encoding it to the latent space and decoding it back.

        ``x`` must have ``input_dim`` features in its last dimension. The output
        keeps the leading dimensions of ``x`` and has ``input_dim`` features.
        """
        return self.decode(self.encode(x))

    def encode(self, x):
        """Map ``x`` (``input_dim`` features in the last dim) to its ``encoding_dim`` latent code."""
        return self.encoder(x)

    def decode(self, z):
        """Map a latent code ``z`` (``encoding_dim`` features) to a reconstruction in [0, 1]."""
        return self.decoder(z)
# Instantiate the default 784 -> 32 autoencoder and report how many scalar
# values its parameters hold in total (every tensor yielded by .parameters()).
model = Autoencoder(encoding_dim=32, input_dim=784)
print(f"Total parameters: {sum(map(torch.Tensor.numel, model.parameters()))}")