-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Closed
Description
I cannot compute higher-order gradients of an nn.Module that contains non-linear activations (e.g. nn.ReLU).
def calc_gradient_penalty(netD, real_data, fake_data):
    """Return a gradient penalty for ``netD`` evaluated on random interpolates.

    Each sample is a random convex combination of the matching rows of
    ``real_data`` and ``fake_data``.  The penalty is the mean squared
    deviation from 1 of the L2 norm (taken over dim 1) of the gradient of
    ``netD``'s output with respect to those samples, scaled by the
    module-level ``LAMBDA``.

    Args:
        netD: callable module applied to the interpolated samples.
        real_data: tensor whose first dimension is the batch dimension.
        fake_data: tensor with the same size as ``real_data``.

    Returns:
        The penalty, still attached to the autograd graph
        (``create_graph=True``) so that it can be back-propagated.  Calling
        ``backward()`` on it requires every operation inside ``netD`` to
        support being differentiated twice.

    Relies on the module-level ``use_cuda`` flag to decide whether helper
    tensors are moved to the GPU.
    """
    # One mixing coefficient per sample, broadcast over that sample's
    # features.  Sized from the data itself (not the global BATCH_SIZE) so
    # that batches of any size are handled.
    alpha = torch.rand(real_data.size(0), 1)
    alpha = alpha.expand(real_data.size())
    if use_cuda:
        alpha = alpha.cuda()

    # alpha already lives on the right device, so the result does too; no
    # additional .cuda() call is needed here.
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = netD(interpolates)

    # Seed gradient of ones, matching the size of the network output.
    grad_outputs = torch.ones(disc_interpolates.size())
    if use_cuda:
        grad_outputs = grad_outputs.cuda()

    # create_graph=True keeps the gradient computation itself differentiable,
    # which is what allows the penalty to be trained on.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        only_inputs=True,
        retain_graph=True,
    )[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty
# Minimal reproduction: build a small ReLU network, compute the gradient
# penalty on two random batches, and back-propagate through it.
use_cuda = False
BATCH_SIZE = 256
LAMBDA = 0.1
DIM = 512

# Two independent random batches of 2-D points.
noise = torch.randn(BATCH_SIZE, 2)
if use_cuda:
    noise = noise.cuda()
noisev = autograd.Variable(noise)

noise1 = torch.randn(BATCH_SIZE, 2)
if use_cuda:
    noise1 = noise1.cuda()
noise1v = autograd.Variable(noise1)

# nn.ReLU(True) enables the in-place variant of the activation.
netD = nn.Sequential(
    nn.Linear(2, DIM),
    nn.ReLU(True),
    nn.Linear(DIM, DIM),
    nn.ReLU(True),
    nn.Linear(DIM, DIM),
    nn.ReLU(True),
    nn.Linear(DIM, 1),
)
if use_cuda:
    # Keep the network on the same device as its inputs.
    netD = netD.cuda()
netD.zero_grad()
print(netD)

gp = calc_gradient_penalty(netD, noisev.data, noise1v.data)
# Back-propagating the penalty needs second-order gradients of every layer;
# this is the call that raises
# "RuntimeError: Threshold is not differentiable twice".
gp.backward()
Running this raises the following RuntimeError:
RuntimeError                              Traceback (most recent call last)
<ipython-input-108-e8fda420b53c> in <module>()
27 # print p.grad
28 gp = calc_gradient_penalty(netD, noisev.data, noise1v.data)
---> 29 gp.backward()
30 # for p in netD.parameters():
31 # print p.grad
/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/variable.pyc in backward(self, gradient, retain_variables)
142 raise TypeError("gradient has to be a Tensor, Variable or None")
143 gradient = Variable(gradient, volatile=True)
--> 144 self._execution_engine.run_backward((self,), (gradient,), retain_variables)
145
146 def register_hook(self, hook):
RuntimeError: Threshold is not differentiable twice
So how can I solve this?
Metadata
Metadata
Assignees
Labels
No labels