Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 46a5cee

Browse files
stu1130sandeepkrishnamurthy-dev
authored andcommitted
[MXNET-580] Add SN-GAN example (#12419)
* update sn-gan example * fix naming * add more comments * fix naming and refine comments * make power iteration as one hyperparameter * deal with divided by zero problem * replace 0.00000001 with EPSILON * refactor the example * add README * address the feedback * refine the composing * fix the typo, delete the redundant piece of code and update the result image * update folder name to align with others * update image name * add the variable back * remove the redundant piece of code and fix typo
1 parent 619bc3e commit 46a5cee

File tree

7 files changed

+423
-0
lines changed

7 files changed

+423
-0
lines changed

example/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ If your tutorial depends on specific packages, simply add them to this provision
9595
* [Gluon Examples](gluon) - several examples using the Gluon API
9696
* [Style Transfer](gluon/style_transfer) - a style transfer example using gluon
9797
* [Word Language Model](gluon/word_language_model) - an example that trains a multi-layer RNN on the Penn Treebank language modeling benchmark
98+
* [SN-GAN](gluon/sn-gan) - an example that utilizes spectral normalization to train GAN(Generative adversarial network) using Gluon API
9899
* [Image Classification with R](image-classification) - image classification on MNIST,CIFAR,ImageNet-1k,ImageNet-Full, with multiple GPU and distributed training.
99100
* [Kaggle 1st national data science bowl](kaggle-ndsb1) - a MXnet example for Kaggle Nation Data Science Bowl 1
100101
* [Kaggle 2nd national data science bowl](kaggle-ndsb2) - a tutorial for Kaggle Second Nation Data Science Bowl

example/gluon/sn_gan/README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Spectral Normalization GAN
2+
3+
This example implements [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957) based on [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
4+
5+
## Usage
6+
7+
Example runs and the results:
8+
9+
```python
10+
python train.py --use-gpu --data-path=data
11+
```
12+
13+
* Note that the program would download the CIFAR10 for you
14+
15+
`python train.py --help` gives the following arguments:
16+
17+
```bash
18+
optional arguments:
19+
-h, --help show this help message and exit
20+
--data-path DATA_PATH
21+
path of data.
22+
--batch-size BATCH_SIZE
23+
training batch size. default is 64.
24+
--epochs EPOCHS number of training epochs. default is 100.
25+
--lr LR learning rate. default is 0.0001.
26+
--lr-beta LR_BETA learning rate for the beta in margin based loss.
27+
default is 0.5.
28+
--use-gpu use gpu for training.
29+
--clip_gr CLIP_GR Clip the gradient by projecting onto the box. default
30+
is 10.0.
31+
--z-dim Z_DIM dimension of the latent z vector. default is 100.
32+
```
33+
34+
## Result
35+
36+
![SN-GAN](sn_gan_output.png)
37+
38+
## Learned Spectral Normalization
39+
40+
![alt text](https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/master/assests/sn.png)
41+
42+
## Reference
43+
44+
[Simple Tensorflow Implementation](https://github.com/taki0112/Spectral_Normalization-Tensorflow)

example/gluon/sn_gan/data.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
19+
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
20+
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
21+
22+
import numpy as np
23+
24+
import mxnet as mx
25+
from mxnet import gluon
26+
from mxnet.gluon.data.vision import CIFAR10
27+
28+
IMAGE_SIZE = 64
29+
30+
def transformer(data, label):
31+
""" data preparation """
32+
data = mx.image.imresize(data, IMAGE_SIZE, IMAGE_SIZE)
33+
data = mx.nd.transpose(data, (2, 0, 1))
34+
data = data.astype(np.float32) / 128.0 - 1
35+
return data, label
36+
37+
38+
def get_training_data(batch_size):
39+
""" helper function to get dataloader"""
40+
return gluon.data.DataLoader(
41+
CIFAR10(train=True, transform=transformer),
42+
batch_size=batch_size, shuffle=True, last_batch='discard')

example/gluon/sn_gan/model.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
19+
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
20+
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
21+
22+
import mxnet as mx
23+
from mxnet import nd
24+
from mxnet import gluon
25+
from mxnet.gluon import Block
26+
27+
28+
EPSILON = 1e-08
29+
POWER_ITERATION = 1
30+
31+
class SNConv2D(Block):
32+
""" Customized Conv2D to feed the conv with the weight that we apply spectral normalization """
33+
34+
def __init__(self, num_filter, kernel_size,
35+
strides, padding, in_channels,
36+
ctx=mx.cpu(), iterations=1):
37+
38+
super(SNConv2D, self).__init__()
39+
40+
self.num_filter = num_filter
41+
self.kernel_size = kernel_size
42+
self.strides = strides
43+
self.padding = padding
44+
self.in_channels = in_channels
45+
self.iterations = iterations
46+
self.ctx = ctx
47+
48+
with self.name_scope():
49+
# init the weight
50+
self.weight = self.params.get('weight', shape=(
51+
num_filter, in_channels, kernel_size, kernel_size))
52+
self.u = self.params.get(
53+
'u', init=mx.init.Normal(), shape=(1, num_filter))
54+
55+
def _spectral_norm(self):
56+
""" spectral normalization """
57+
w = self.params.get('weight').data(self.ctx)
58+
w_mat = nd.reshape(w, [w.shape[0], -1])
59+
60+
_u = self.u.data(self.ctx)
61+
_v = None
62+
63+
for _ in range(POWER_ITERATION):
64+
_v = nd.L2Normalization(nd.dot(_u, w_mat))
65+
_u = nd.L2Normalization(nd.dot(_v, w_mat.T))
66+
67+
sigma = nd.sum(nd.dot(_u, w_mat) * _v)
68+
if sigma == 0.:
69+
sigma = EPSILON
70+
71+
self.params.setattr('u', _u)
72+
73+
return w / sigma
74+
75+
def forward(self, x):
76+
# x shape is batch_size x in_channels x height x width
77+
return nd.Convolution(
78+
data=x,
79+
weight=self._spectral_norm(),
80+
kernel=(self.kernel_size, self.kernel_size),
81+
pad=(self.padding, self.padding),
82+
stride=(self.strides, self.strides),
83+
num_filter=self.num_filter,
84+
no_bias=True
85+
)
86+
87+
88+
def get_generator():
89+
""" construct and return generator """
90+
g_net = gluon.nn.Sequential()
91+
with g_net.name_scope():
92+
93+
g_net.add(gluon.nn.Conv2DTranspose(
94+
channels=512, kernel_size=4, strides=1, padding=0, use_bias=False))
95+
g_net.add(gluon.nn.BatchNorm())
96+
g_net.add(gluon.nn.LeakyReLU(0.2))
97+
98+
g_net.add(gluon.nn.Conv2DTranspose(
99+
channels=256, kernel_size=4, strides=2, padding=1, use_bias=False))
100+
g_net.add(gluon.nn.BatchNorm())
101+
g_net.add(gluon.nn.LeakyReLU(0.2))
102+
103+
g_net.add(gluon.nn.Conv2DTranspose(
104+
channels=128, kernel_size=4, strides=2, padding=1, use_bias=False))
105+
g_net.add(gluon.nn.BatchNorm())
106+
g_net.add(gluon.nn.LeakyReLU(0.2))
107+
108+
g_net.add(gluon.nn.Conv2DTranspose(
109+
channels=64, kernel_size=4, strides=2, padding=1, use_bias=False))
110+
g_net.add(gluon.nn.BatchNorm())
111+
g_net.add(gluon.nn.LeakyReLU(0.2))
112+
113+
g_net.add(gluon.nn.Conv2DTranspose(channels=3, kernel_size=4, strides=2, padding=1, use_bias=False))
114+
g_net.add(gluon.nn.Activation('tanh'))
115+
116+
return g_net
117+
118+
119+
def get_descriptor(ctx):
120+
""" construct and return descriptor """
121+
d_net = gluon.nn.Sequential()
122+
with d_net.name_scope():
123+
124+
d_net.add(SNConv2D(num_filter=64, kernel_size=4, strides=2, padding=1, in_channels=3, ctx=ctx))
125+
d_net.add(gluon.nn.LeakyReLU(0.2))
126+
127+
d_net.add(SNConv2D(num_filter=128, kernel_size=4, strides=2, padding=1, in_channels=64, ctx=ctx))
128+
d_net.add(gluon.nn.LeakyReLU(0.2))
129+
130+
d_net.add(SNConv2D(num_filter=256, kernel_size=4, strides=2, padding=1, in_channels=128, ctx=ctx))
131+
d_net.add(gluon.nn.LeakyReLU(0.2))
132+
133+
d_net.add(SNConv2D(num_filter=512, kernel_size=4, strides=2, padding=1, in_channels=256, ctx=ctx))
134+
d_net.add(gluon.nn.LeakyReLU(0.2))
135+
136+
d_net.add(SNConv2D(num_filter=1, kernel_size=4, strides=1, padding=0, in_channels=512, ctx=ctx))
137+
138+
return d_net
395 KB
Loading

example/gluon/sn_gan/train.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
19+
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
20+
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py
21+
22+
23+
import os
24+
import random
25+
import logging
26+
import argparse
27+
28+
from data import get_training_data
29+
from model import get_generator, get_descriptor
30+
from utils import save_image
31+
32+
import mxnet as mx
33+
from mxnet import nd, autograd
34+
from mxnet import gluon
35+
36+
# CLI
37+
parser = argparse.ArgumentParser(
38+
description='train a model for Spectral Normalization GAN.')
39+
parser.add_argument('--data-path', type=str, default='./data',
40+
help='path of data.')
41+
parser.add_argument('--batch-size', type=int, default=64,
42+
help='training batch size. default is 64.')
43+
parser.add_argument('--epochs', type=int, default=100,
44+
help='number of training epochs. default is 100.')
45+
parser.add_argument('--lr', type=float, default=0.0001,
46+
help='learning rate. default is 0.0001.')
47+
parser.add_argument('--lr-beta', type=float, default=0.5,
48+
help='learning rate for the beta in margin based loss. default is 0.5.')
49+
parser.add_argument('--use-gpu', action='store_true',
50+
help='use gpu for training.')
51+
parser.add_argument('--clip_gr', type=float, default=10.0,
52+
help='Clip the gradient by projecting onto the box. default is 10.0.')
53+
parser.add_argument('--z-dim', type=int, default=10,
54+
help='dimension of the latent z vector. default is 100.')
55+
opt = parser.parse_args()
56+
57+
BATCH_SIZE = opt.batch_size
58+
Z_DIM = opt.z_dim
59+
NUM_EPOCHS = opt.epochs
60+
LEARNING_RATE = opt.lr
61+
BETA = opt.lr_beta
62+
OUTPUT_DIR = opt.data_path
63+
CTX = mx.gpu() if opt.use_gpu else mx.cpu()
64+
CLIP_GRADIENT = opt.clip_gr
65+
IMAGE_SIZE = 64
66+
67+
68+
def facc(label, pred):
69+
""" evaluate accuracy """
70+
pred = pred.ravel()
71+
label = label.ravel()
72+
return ((pred > 0.5) == label).mean()
73+
74+
75+
# setting
76+
mx.random.seed(random.randint(1, 10000))
77+
logging.basicConfig(level=logging.DEBUG)
78+
79+
# create output dir
80+
try:
81+
os.makedirs(opt.data_path)
82+
except OSError:
83+
pass
84+
85+
# get training data
86+
train_data = get_training_data(opt.batch_size)
87+
88+
# get model
89+
g_net = get_generator()
90+
d_net = get_descriptor(CTX)
91+
92+
# define loss function
93+
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
94+
95+
# initialization
96+
g_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX)
97+
d_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX)
98+
g_trainer = gluon.Trainer(
99+
g_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT})
100+
d_trainer = gluon.Trainer(
101+
d_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT})
102+
g_net.collect_params().zero_grad()
103+
d_net.collect_params().zero_grad()
104+
# define evaluation metric
105+
metric = mx.metric.CustomMetric(facc)
106+
# initialize labels
107+
real_label = nd.ones(BATCH_SIZE, CTX)
108+
fake_label = nd.zeros(BATCH_SIZE, CTX)
109+
110+
for epoch in range(NUM_EPOCHS):
111+
for i, (d, _) in enumerate(train_data):
112+
# update D
113+
data = d.as_in_context(CTX)
114+
noise = nd.normal(loc=0, scale=1, shape=(
115+
BATCH_SIZE, Z_DIM, 1, 1), ctx=CTX)
116+
with autograd.record():
117+
# train with real image
118+
output = d_net(data).reshape((-1, 1))
119+
errD_real = loss(output, real_label)
120+
metric.update([real_label, ], [output, ])
121+
122+
# train with fake image
123+
fake_image = g_net(noise)
124+
output = d_net(fake_image.detach()).reshape((-1, 1))
125+
errD_fake = loss(output, fake_label)
126+
errD = errD_real + errD_fake
127+
errD.backward()
128+
metric.update([fake_label, ], [output, ])
129+
130+
d_trainer.step(BATCH_SIZE)
131+
# update G
132+
with autograd.record():
133+
fake_image = g_net(noise)
134+
output = d_net(fake_image).reshape(-1, 1)
135+
errG = loss(output, real_label)
136+
errG.backward()
137+
138+
g_trainer.step(BATCH_SIZE)
139+
140+
# print log infomation every 100 batches
141+
if i % 100 == 0:
142+
name, acc = metric.get()
143+
logging.info('discriminator loss = %f, generator loss = %f, \
144+
binary training acc = %f at iter %d epoch %d',
145+
nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, i, epoch)
146+
if i == 0:
147+
save_image(fake_image, epoch, IMAGE_SIZE, BATCH_SIZE, OUTPUT_DIR)
148+
149+
metric.reset()

0 commit comments

Comments
 (0)