Ex. 8
Image augmentation using GANs
Program:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameters
latent_dim = 100 # Size of the latent vector for the generator input
img_size = 28 # Image size (28x28 for MNIST)
batch_size = 64
num_epochs = 100
learning_rate = 0.0002
# Create directory to save augmented images
os.makedirs("gan_augmented_images", exist_ok=True)
# Image transformations (e.g., for MNIST images)
transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # Normalize images between -1 and 1
])
# Load dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Discriminator network
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(img_size * img_size, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(x.size(0), -1)
return self.model(x)
# Generator network
class Generator(nn.Module):
def __init__(self, latent_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, img_size * img_size),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), 1, img_size, img_size)
return img
# Initialize generator and discriminator
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
# Training loop
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(train_loader):
real_imgs = imgs.to(device)
batch_size = real_imgs.size(0) # Get actual batch size for the last incomplete batch if any
# Update real and fake labels to match batch size
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# Train Discriminator
real_outputs = discriminator(real_imgs)
d_loss_real = criterion(real_outputs, real_labels)
# Generate fake images
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z)
fake_outputs = discriminator(fake_imgs.detach())
d_loss_fake = criterion(fake_outputs, fake_labels)
# Total discriminator loss
d_loss = d_loss_real + d_loss_fake
# Backpropagation for discriminator
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()
# Train Generator
gen_labels = torch.ones(batch_size, 1).to(device) # Generator aims for these to be classified as real
fake_outputs = discriminator(fake_imgs)
g_loss = criterion(fake_outputs, gen_labels)
# Backpropagation for generator
optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()
# Print training progress
if (i + 1) % 100 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(train_loader)}], D Loss: {d_loss.item()}, G
Loss: {g_loss.item()}")
# Save fake images for every epoch
fake_imgs = fake_imgs.reshape(fake_imgs.size(0), 1, img_size, img_size)
save_image(fake_imgs, f"gan_augmented_images/fake_images_epoch_{epoch+1}.png", normalize=True)
print("Training completed! Generated images saved in 'gan_augmented_images' directory.")
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
import glob
# Load and display images generated in each epoch
images = sorted(glob.glob("gan_augmented_images/*.png")) # Get the saved images sorted by epoch
for img_path in images:
img = Image.open(img_path)
plt.imshow(img)
plt.title(f"Generated Image - {img_path}")
plt.axis("off")
plt.show()
# Save the trained models
torch.save(generator.state_dict(), "generator.pth")
torch.save(discriminator.state_dict(), "discriminator.pth")
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
# Load the saved models safely
generator.load_state_dict(torch.load("generator.pth", weights_only=True))
discriminator.load_state_dict(torch.load("discriminator.pth", weights_only=True))
Output:
<All keys matched successfully>