
Conversation

caogang (Contributor) commented May 8, 2017

This adds higher-order gradient support for new-style functions, addressing issue #1483. So far I have finished the following operators (a condensed usage sketch follows the list):

  • Threshold
  • ReLU
  • Sqrt
  • Norm
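
For context, the target use case looks roughly like the sketch below (condensed from the test code later in this thread; it is illustrative only and assumes a build that already includes these changes):

import torch
import torch.nn as nn
from torch import autograd

# A tiny ReLU network standing in for a discriminator.
net = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))
x = autograd.Variable(torch.randn(4, 3), requires_grad=True)
out = net(x)

# First-order gradients, kept differentiable with create_graph=True.
grads, = autograd.grad(out, x, grad_outputs=torch.ones(out.size()), create_graph=True)

# A scalar function of those gradients; calling backward() here needs
# double backward through ReLU and Norm.
penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()
penalty.backward()
print(x.grad)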

caogang changed the title from "Add high order grad support for ReLU and Threshold operator" to "[WIP] Add high order grad support for Some operator" on May 8, 2017
caogang changed the title from "[WIP] Add high order grad support for Some operator" to "Add high order grad support for Some operator" on May 8, 2017

apaszke (Contributor) left a comment

That looks great! Thanks for the PR

caogang (Contributor, Author) commented May 9, 2017

@apaszke Thank you for your suggestions

All suggestions except the last one, grad_input = grad_output.masked_fill(input > ctx.threshold, 0), have been applied in the new commit.

I tried using masked_fill before, but I found it doesn't work when backward is performed twice.

Here is my test code:

import torch
import torch.nn as nn
import torch.autograd as autograd

# WGAN-GP style gradient penalty: requires backprop through the gradient itself.
def calc_gradient_penalty(netD, real_data, fake_data):
    alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda() if use_cuda else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    if use_cuda:
        interpolates = interpolates.cuda()
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = netD(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda else torch.ones(disc_interpolates.size()),
                              create_graph=True, only_inputs=True, retain_graph=True)[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

use_cuda = False
BATCH_SIZE=256
LAMBDA = 0.1
DIM = 256
noise = torch.neg(torch.randn(BATCH_SIZE, 2))
if use_cuda:
    noise = noise.cuda()
noisev = autograd.Variable(noise)

noise1 = torch.randn(BATCH_SIZE, 2)
if use_cuda:
    noise1 = noise1.cuda()
noise1v = autograd.Variable(noise1)

netD = nn.Sequential(
            nn.Linear(2, DIM),
            nn.ReLU(True),
            nn.Linear(DIM, DIM),
            nn.ReLU(True),
            nn.Linear(DIM, DIM),
            nn.ReLU(True),
            nn.Linear(DIM, 1),
        )
netD.zero_grad()
print netD
gp = calc_gradient_penalty(netD, noisev.data, noise1v.data)
gp.backward()
for p in netD.parameters():
    print p.grad

Then I get an error:

NotImplementedErrorTraceback (most recent call last)
<ipython-input-2-a841c9970f31> in <module>()
     44 print netD
     45 gp = calc_gradient_penalty(netD, noisev.data, noise1v.data)
---> 46 gp.backward()
     47 for p in netD.parameters():
     48     print p.grad

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/variable.pyc in backward(self, gradient, retain_variables)
    150                 raise TypeError("gradient has to be a Tensor, Variable or None")
    151             gradient = Variable(gradient, volatile=True)
--> 152         self._execution_engine.run_backward((self,), (gradient,), retain_variables)
    153 
    154     def register_hook(self, hook):

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/function.pyc in backward(*grad_outputs)
    170         be the gradient w.r.t. the corresponding input.
    171         """
--> 172         raise NotImplementedError
    173 
    174 

NotImplementedError: 

The only difference from the current commit is the change to grad_input = grad_output.masked_fill(input > ctx.threshold, 0).
I have no idea what is wrong, so for now I still use grad_input = mask.type_as(grad_output) * grad_output

apaszke (Contributor) left a comment

I looked into that masked_fill issue and it turns out that it was a bug (I've included a fix in #1506). Can you please change it as I said? The tests will fail now, but should be ok once my PR is merged.
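
As a side note, the double-backward path can also be exercised with a quick numerical check along these lines (a sketch; it assumes gradcheck and gradgradcheck are available in the build being tested):

import torch
import torch.nn.functional as F
from torch.autograd import Variable, gradcheck, gradgradcheck

# Numerically verify first- and second-order gradients of relu at random points.
inputs = (Variable(torch.randn(5, 5).double(), requires_grad=True),)
print(gradcheck(F.relu, inputs))
print(gradgradcheck(F.relu, inputs))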

caogang (Contributor, Author) commented May 9, 2017

OK, the changes have been made. I'd like to know whether bug #1517 will also be fixed in your pull request #1506 @apaszke

apaszke (Contributor) commented May 9, 2017

Yes, I'll push that soon as well

apaszke (Contributor) commented May 10, 2017

Can you please fix the conflicts? My branch is merged now.

* master: (26 commits)
  Fix Linear function
  Fix comparison functions
  Expose variable attribute of AccumulateGrad
  Don't modify non-volatile grads in zero_grad
  Minor fix in Prod backward
  Add new flags to Variable.backward
  Replace retain_variables with retain_graph
  Improve output wrapping logic in autograd
  Remove spurious memo argument in Module.parameters() (pytorch#1527)
  Make torch.cat not synchronize the host and device
  Reference counting documentation. (pytorch#1520)
  Restore examples with keepdim=True default.
  Explicitly pass keepdim=False for tests that require it.
  Change keepdim default to False.
  Fix test_normalize NN test.
  Add a keepdim test to torch_test.
  Make (non-legacy) nn backwards compatible.
  Add autograd tests for keepdim
  Add documentation for keepdim.
  Change all legacy/nn modules to use keepdim=True (even if tests don't fail).
  ...

# Conflicts:
#	torch/autograd/_functions/reduce.py
#	torch/autograd/variable.py
caogang (Contributor, Author) commented May 11, 2017

Hi @apaszke, the current commit may still have a problem with masked_fill, because everything works fine when I use mask.type_as(grad_output) * grad_output.

With the same test code as above, I got another error:

TypeErrorTraceback (most recent call last)
<ipython-input-3-a841c9970f31> in <module>()
     44 print netD
     45 gp = calc_gradient_penalty(netD, noisev.data, noise1v.data)
---> 46 gp.backward()
     47 for p in netD.parameters():
     48     print p.grad

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/variable.pyc in backward(self, gradient, retain_graph, create_graph, retain_variables)
    143             Defaults to False, unless ``gradient`` is a volatile Variable.
    144         """
--> 145         torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
    146 
    147     def register_hook(self, hook):

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/__init__.pyc in backward(variables, grad_variables, retain_graph, create_graph, retain_variables)
     96 
     97     Variable._execution_engine.run_backward(
---> 98         variables, grad_variables, retain_graph)
     99 
    100 

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/function.pyc in apply(self, *args)
     88 
     89     def apply(self, *args):
---> 90         return self._forward_cls.backward(self, *args)
     91 
     92 

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/nn/_functions/linear.pyc in backward(ctx, grad_output)
     21         grad_input = grad_weight = grad_bias = None
     22         if ctx.needs_input_grad[0]:
---> 23             grad_input = torch.mm(grad_output, weight)
     24         if ctx.needs_input_grad[1]:
     25             grad_weight = torch.mm(grad_output.t(), input)

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/variable.pyc in mm(self, matrix)
    526     def mm(self, matrix):
    527         output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
--> 528         return self._static_blas(Addmm, (output, 0, 1, self, matrix), False)
    529 
    530     def bmm(self, batch):

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/variable.pyc in _static_blas(cls, args, inplace)
    519         if num_args == 4:
    520             alpha = args[1]
--> 521         return cls.apply(*(args[:1] + args[-2:] + (alpha, beta, inplace)))
    522 
    523     def _blas(self, cls, args, inplace):

/home/users/gang.cao/env/lib/python2.7/site-packages/torch/autograd/_functions/blas.pyc in forward(ctx, add_matrix, matrix1, matrix2, alpha, beta, inplace)
     22         output = _get_output(ctx, add_matrix, inplace=inplace)
     23         return torch.addmm(alpha, add_matrix, beta,
---> 24                            matrix1, matrix2, out=output)
     25 
     26     @staticmethod

TypeError: torch.addmm received an invalid combination of arguments - got (int, torch.ByteTensor, int, torch.ByteTensor, torch.FloatTensor, out=torch.ByteTensor), but expected one of:
 * (torch.ByteTensor source, torch.ByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (torch.ByteTensor source, torch.SparseByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (int beta, torch.ByteTensor source, torch.ByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (torch.ByteTensor source, int alpha, torch.ByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (int beta, torch.ByteTensor source, torch.SparseByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (torch.ByteTensor source, int alpha, torch.SparseByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (int beta, torch.ByteTensor source, int alpha, torch.ByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
      didn't match because some of the arguments have invalid types: (int, torch.ByteTensor, int, torch.ByteTensor, !torch.FloatTensor!, out=torch.ByteTensor)
 * (int beta, torch.ByteTensor source, int alpha, torch.SparseByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
      didn't match because some of the arguments have invalid types: (int, torch.ByteTensor, int, !torch.ByteTensor!, !torch.FloatTensor!, out=torch.ByteTensor)

caogang (Contributor, Author) commented May 13, 2017

Hi @apaszke, when will this PR be merged?

apaszke (Contributor) commented May 13, 2017

I've been busy working on other things. I'll try to review it this weekend

apaszke (Contributor) commented May 14, 2017

@pytorchbot test this please

apaszke merged commit 0ba2043 into pytorch:master on May 14, 2017
apaszke (Contributor) commented May 14, 2017

Thank you!

Jiaming-Liu pushed a commit to Jiaming-Liu/pytorch that referenced this pull request May 18, 2017
EthanZhu90 commented

Hi @caogang @apaszke, I am trying the WGAN-GP implementation with PyTorch installed from the latest git, but I still get the same error as @caogang. Any idea how to solve it?
Thank you in advance.

TypeError: torch.addmm received an invalid combination of arguments - got (int, torch.ByteTensor, int, torch.ByteTensor, torch.FloatTensor, out=torch.ByteTensor), but expected one of:
 * (torch.ByteTensor source, torch.ByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (torch.ByteTensor source, torch.SparseByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (int beta, torch.ByteTensor source, torch.ByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (torch.ByteTensor source, int alpha, torch.ByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (int beta, torch.ByteTensor source, torch.SparseByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (torch.ByteTensor source, int alpha, torch.SparseByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
 * (int beta, torch.ByteTensor source, int alpha, torch.ByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
      didn't match because some of the arguments have invalid types: (int, torch.ByteTensor, int, torch.ByteTensor, !torch.FloatTensor!, out=torch.ByteTensor)
 * (int beta, torch.ByteTensor source, int alpha, torch.SparseByteTensor mat1, torch.ByteTensor mat2, *, torch.ByteTensor out)
      didn't match because some of the arguments have invalid types: (int, torch.ByteTensor, int, !torch.ByteTensor!, !torch.FloatTensor!, out=torch.ByteTensor)

caogang (Contributor, Author) commented May 31, 2017

@EthanZhu90, this may be a bug in the current branch.
Here is a temporary fix to make the error go away until the bug is properly solved: change the source code as below and recompile.

torch/nn/_functions/thnn/activation.py

         else:
-            grad_input = grad_output.masked_fill(input > ctx.threshold, 0)
+            mask = input > ctx.threshold
+            grad_input = mask.type_as(grad_output) * grad_output
         return grad_input, None, None, None
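
A rough illustration (not code from this PR) of why the cast matters: the comparison produces a byte mask, and the traceback above shows a ByteTensor reaching torch.addmm, so type_as is used to convert the mask to the gradient's floating-point type before it enters the graph.

import torch

a = torch.randn(3)
mask = a > 0             # comparison yields a byte/bool mask, not a float tensor
fmask = mask.type_as(a)  # cast to a FloatTensor of zeros and ones, matching the gradient's type
print(fmask)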

caogang deleted the develop branch on June 8, 2017 08:11
houseroad added a commit to houseroad/pytorch that referenced this pull request Oct 16, 2018
…d49783 (pytorch#12676)

Summary:
Pull Request resolved: pytorch#12676

Previous import was 06f6d63d5529e3a94533c9f34c402be1793420b1

Included changes:
- **[1cbe274](onnx/onnx@1cbe274)**: fix the optimizer (pytorch#1510) <Lu Fang>
- **[481ad99](onnx/onnx@481ad99)**: Fix TensorProto int32_data comment (pytorch#1509) <Lutz Roeder>
- **[f04fbe0](onnx/onnx@f04fbe0)**: fix ninja external (pytorch#1507) <Rui Zhu>

Reviewed By: jamesr66a, wanchaol

Differential Revision: D10388438

fbshipit-source-id: ebc67073ca64daae0591873fcfeadc9885308ef5
facebook-github-bot pushed a commit that referenced this pull request Oct 16, 2018
…d49783 (#12676)

Summary:
Pull Request resolved: #12676

Previous import was 06f6d63d5529e3a94533c9f34c402be1793420b1

Included changes:
- **[1cbe274](onnx/onnx@1cbe274)**: fix the optimizer (#1510) <Lu Fang>
- **[481ad99](onnx/onnx@481ad99)**: Fix TensorProto int32_data comment (#1509) <Lutz Roeder>
- **[f04fbe0](onnx/onnx@f04fbe0)**: fix ninja external (#1507) <Rui Zhu>

Reviewed By: jamesr66a, wanchaol

Differential Revision: D10388438

fbshipit-source-id: 298100589ce226c63d4e58edf185c9227fd52c85