Commit 4a8c4ac

ptrblck authored and mcarilli committed
Add DCGAN example (#413)
* initial commit
* add default O1 mode, enable other modes, add README
* add carilli's review suggestions to README
1 parent 3ef01fa commit 4a8c4ac

File tree

2 files changed: +315 -1 lines changed


examples/dcgan/README.md

Lines changed: 41 additions & 1 deletion
@@ -1 +1,41 @@
-Under construction...

# Mixed Precision DCGAN Training in PyTorch

`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/dcgan](https://github.com/pytorch/examples/tree/master/dcgan).
It implements Automatic Mixed Precision (Amp) training of the DCGAN example for different datasets. Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s. For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).

We introduce these changes to the PyTorch DCGAN example as described in the [Multiple models/optimizers/losses](https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses) section of the documentation:
```
# Added after models and optimizers construction
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
...
# loss.backward() changed to:
with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
    errD_real_scaled.backward()
...
with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
    errD_fake_scaled.backward()
...
with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
    errG_scaled.backward()
```

Note that we use a different `loss_scaler` for each computed loss.
Using a separate loss scaler per loss is [optional, not required](https://nvidia.github.io/apex/advanced.html#optionally-have-amp-use-a-different-loss-scaler-per-loss).
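
For comparison, here is a minimal sketch of the same calls with a single shared loss scaler (this is *not* what `main_amp.py` does): `num_losses` is simply left at its default of 1 and `loss_id` is omitted.
```
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level)
...
# every backward pass then shares the same (default) loss scaler
with amp.scale_loss(errD_real, optimizerD) as errD_real_scaled:
    errD_real_scaled.backward()
```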

To improve numerical stability, we swapped `nn.Sigmoid() + nn.BCELoss()` for `nn.BCEWithLogitsLoss()`.
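
As a standalone illustration of the swap (the tensors below are arbitrary and not taken from `main_amp.py`): both formulations compute the same binary cross entropy, but `nn.BCEWithLogitsLoss()` fuses the sigmoid into the loss, which is more robust when intermediate values are in reduced precision.
```
import torch
import torch.nn as nn

logits = torch.randn(8)   # raw, unbounded discriminator outputs
target = torch.ones(8)    # "real" labels

# Original DCGAN example: sigmoid inside the model, BCE on probabilities
loss_bce = nn.BCELoss()(torch.sigmoid(logits), target)

# This example: keep raw logits, let the loss apply the sigmoid internally
loss_bce_logits = nn.BCEWithLogitsLoss()(logits, target)

print(loss_bce.item(), loss_bce_logits.item())  # agree up to rounding
```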

With the new Amp API **you never need to explicitly convert your model, or the input data, to half().**
"Pure FP32" training:
31+
```
32+
$ python main_amp.py --opt-level O0
33+
```
34+
Recommended mixed precision training:
35+
```
36+
$ python main_amp.py --opt-level O1
37+
```

Have a look at the original [DCGAN example](https://github.com/pytorch/examples/tree/master/dcgan) for more information about the arguments used.

To enable mixed precision training, we introduce the `--opt_level` argument.
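
The remaining `opt_level`s from the Amp API guide can be selected the same way, for example (`O2` is "almost FP16" mixed precision and `O3` is pure FP16, mainly useful as a speed baseline):
```
$ python main_amp.py --opt_level O2
$ python main_amp.py --opt_level O3
```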

examples/dcgan/main_amp.py

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')

opt = parser.parse_args()
print(opt)


try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc=3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc=3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc=3

elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc=1

elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc=3

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

device = torch.device("cuda:0")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output


netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1).squeeze(1)


netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# Initialize Amp: both models and both optimizers are wrapped in a single call,
# and num_losses=3 requests a separate loss scaler per computed loss.
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

