SRGAN
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt
# Residual Dense Block: the basic building block of the ESRGAN generator
# (ESRGAN's full Residual-in-Residual Dense Block stacks three of these)
class ResidualDenseBlock(nn.Module):
    def __init__(self, in_channels=64, growth_channels=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, growth_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels + growth_channels, growth_channels, 3, 1, 1)
        self.conv3 = nn.Conv2d(in_channels + 2 * growth_channels, growth_channels, 3, 1, 1)
        self.conv4 = nn.Conv2d(in_channels + 3 * growth_channels, growth_channels, 3, 1, 1)
        self.conv5 = nn.Conv2d(in_channels + 4 * growth_channels, in_channels, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out1 = self.lrelu(self.conv1(x))
        out2 = self.lrelu(self.conv2(torch.cat((x, out1), 1)))
        out3 = self.lrelu(self.conv3(torch.cat((x, out1, out2), 1)))
        out4 = self.lrelu(self.conv4(torch.cat((x, out1, out2, out3), 1)))
        out5 = self.conv5(torch.cat((x, out1, out2, out3, out4), 1))
        return out5 * 0.2 + x  # residual scaling, as in ESRGAN
# Full ESRGAN-style generator with upsampling layers
class ESRGAN_Generator(nn.Module):
    def __init__(self, num_blocks=23):
        super(ESRGAN_Generator, self).__init__()
        self.conv_first = nn.Conv2d(3, 64, 3, 1, 1)
        # Stack of residual dense blocks
        self.residual_blocks = nn.Sequential(
            *[ResidualDenseBlock() for _ in range(num_blocks)]
        )
        self.conv_last = nn.Conv2d(64, 64, 3, 1, 1)
        # Upsampling layers (each PixelShuffle stage doubles the spatial resolution)
        self.upsample1 = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),  # upsample by factor of 2
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.upsample2 = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),  # upsample by factor of 2
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv_hr = nn.Conv2d(64, 3, 3, 1, 1)  # output 3 channels (RGB)

    def forward(self, x):
        out1 = self.conv_first(x)
        residual = self.residual_blocks(out1)
        out2 = self.conv_last(residual)
        out = torch.add(out1, out2)  # global residual connection
        out = self.upsample1(out)
        out = self.upsample2(out)
        out = self.conv_hr(out)
        return out
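# Quick sanity check (optional sketch, not part of the original listing): the two
# PixelShuffle stages give a 4x overall upscale, so a 24x24 input becomes 96x96.
# A small num_blocks is used here only to keep the check fast.
_check_gen = ESRGAN_Generator(num_blocks=2)
_check_out = _check_gen(torch.randn(1, 3, 24, 24))
print(_check_out.shape)  # expected: torch.Size([1, 3, 96, 96])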
# Function to preprocess the image (convert to tensor)
def preprocess_image(image_path):
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = transforms.ToTensor()(img).unsqueeze(0)  # Add batch dimension
    return img
# Function to perform super-resolution
def super_resolve(image_path, model):
    img = preprocess_image(image_path)
    img = img.to('cuda' if torch.cuda.is_available() else 'cpu')
    # Super-resolve image
    with torch.no_grad():
        sr_img = model(img)
    sr_img = sr_img.squeeze(0).cpu().numpy()  # Remove batch dimension and move to CPU
    sr_img = np.clip(sr_img, 0, 1)  # Ensure values are between 0 and 1
    sr_img = sr_img.transpose(1, 2, 0)  # Convert to HWC format
    return sr_img
# Display the input low-resolution and the super-resolved high-resolution images
def display_images(lr_image_path, sr_image):
    lr_img = cv2.imread(lr_image_path)
    lr_img = cv2.cvtColor(lr_img, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(12, 6))
    # Low-resolution image
    plt.subplot(1, 2, 1)
    plt.title('Low Resolution')
    plt.imshow(lr_img)
    # Super-resolved image
    plt.subplot(1, 2, 2)
    plt.title('Super-Resolved')
    plt.imshow(sr_image)
    plt.show()
# Initialize the model (randomly initialized here; see the weight-loading sketch below)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = ESRGAN_Generator().to(device)
generator.eval()

# Super-resolve the image
low_res_image = "dog_image.jpg"  # Replace this with the path to your low-res dog image
sr_img = super_resolve(low_res_image, generator)

# Display the result
display_images(low_res_image, sr_img)
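Note that, as written, generator has random weights, so the "super-resolved" output is effectively noise. A trained checkpoint must be loaded before the super_resolve call above. The following is a minimal sketch of that step under stated assumptions: the path "esrgan_generator.pth" is a placeholder, and its state dict is assumed to match this simplified architecture (official ESRGAN releases target the full RRDB generator and will not load into this model as-is).

# Hypothetical checkpoint path; a state dict trained for this exact (simplified)
# architecture is assumed.
checkpoint = torch.load("esrgan_generator.pth", map_location=device)
generator.load_state_dict(checkpoint)
generator.eval()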
import torch
from torch import nn
class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = (
            nn.LeakyReLU(0.2, inplace=True)
            if discriminator
            else nn.PReLU(num_parameters=out_channels)
        )

    def forward(self, x):
        return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))
class UpsampleBlock(nn.Module):
    def __init__(self, in_c, scale_factor):
        super().__init__()
        self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, 3, 1, 1)
        self.ps = nn.PixelShuffle(scale_factor)  # (in_c * 4, H, W) -> (in_c, H*2, W*2)
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        return self.act(self.ps(self.conv(x)))
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False,
        )

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        return out + x  # skip connection
class Generator(nn.Module):
    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
        self.residuals = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
        self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
        self.upsamples = nn.Sequential(UpsampleBlock(num_channels, 2), UpsampleBlock(num_channels, 2))
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        initial = self.initial(x)
        x = self.residuals(initial)
        x = self.convblock(x) + initial
        x = self.upsamples(x)
        return torch.tanh(self.final(x))
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,  # alternate stride 1 and 2 to downsample
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=False if idx == 0 else True,
                )
            )
            in_channels = feature
        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)
def test():
    low_resolution = 24  # 24x24 low-resolution input -> 96x96 output after 4x upscaling
    with torch.cuda.amp.autocast():
        x = torch.randn((5, 3, low_resolution, low_resolution))
        gen = Generator()
        gen_out = gen(x)
        disc = Discriminator()
        disc_out = disc(gen_out)
        print(gen_out.shape)
        print(disc_out.shape)


if __name__ == "__main__":
    test()
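The listing above only defines the SRGAN generator and discriminator. For completeness, here is a minimal, hedged sketch of one SRGAN training step, combining a VGG-based content (perceptual) loss with an adversarial loss. The helper names (VGGLoss, train_step), the feature slice features[:36], and the weighting factors 0.006 and 1e-3 are common choices from the SRGAN paper and popular re-implementations rather than part of the listing above; the data pipeline and image normalization are omitted.

import torch
import torch.nn as nn
from torchvision.models import vgg19


class VGGLoss(nn.Module):
    """Content loss: MSE between VGG19 feature maps of the SR and HR images."""
    def __init__(self, device):
        super().__init__()
        # Pretrained VGG19 features up to conv5_4 (the `weights=` API needs torchvision >= 0.13)
        self.vgg = vgg19(weights="DEFAULT").features[:36].eval().to(device)
        self.mse = nn.MSELoss()
        for p in self.vgg.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        return self.mse(self.vgg(sr), self.vgg(hr))


def train_step(low_res, high_res, gen, disc, opt_gen, opt_disc, vgg_loss, bce):
    # Discriminator update: real HR images vs. generated SR images
    fake = gen(low_res)
    disc_real = disc(high_res)
    disc_fake = disc(fake.detach())
    loss_disc = (bce(disc_real, torch.ones_like(disc_real))
                 + bce(disc_fake, torch.zeros_like(disc_fake)))
    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # Generator update: content (perceptual) loss + weighted adversarial loss
    disc_fake = disc(fake)
    adversarial = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
    content = 0.006 * vgg_loss(fake, high_res)
    loss_gen = content + adversarial
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()
    return loss_disc.item(), loss_gen.item()

With gen = Generator(), disc = Discriminator(), two Adam optimizers (the paper uses lr=1e-4), bce = nn.BCEWithLogitsLoss() (the discriminator outputs raw logits), and vgg_loss = VGGLoss(device), train_step would be called once per (low_res, high_res) batch.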