Skip to content

Conversation

@wanchaol
Copy link
Collaborator

@wanchaol wanchaol commented Mar 29, 2019

Fixes #17962

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 29, 2019
@wanchaol wanchaol force-pushed the autogradzero branch 2 times, most recently from bb43f7c to 657f44b Compare March 29, 2019 22:19
@wanchaol wanchaol marked this pull request as ready for review March 29, 2019 22:20
@wanchaol wanchaol requested a review from ailzhang March 29, 2019 22:20
@wanchaol wanchaol changed the title [WIP] Fix contiguous AD and Autogradzero inconsistency Fix contiguous AD and Autogradzero inconsistency Mar 29, 2019
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ailzhang
Copy link
Contributor

ailzhang commented Apr 1, 2019

The test change for AD support has been merged, would you mind rebasing and see if all tests passed? Thanks a lot!

test/test_jit.py Outdated
def foo(x):
return 3 + x.contiguous()

x = torch.rand(1, dtype=torch.float, requires_grad=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x is contiguous here, we want to test against both when x is /isn't contiguous I think?

Copy link
Contributor

@ailzhang ailzhang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!! Please land if CI is green.

apaszke
apaszke previously requested changes Apr 2, 2019
if self.is_contiguous():
return grad_output
else:
return grad_output.clone()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should just return grad_output. If an op downstream needs it to be contiguous, it will do it on its own.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess if we simply return grad_output, it should be fine to make the test simpler.

test/test_jit.py Outdated

grad = torch.randn(5, 5, dtype=torch.float)
out.backward(grad)
self.assertEqual(x.grad, grad.transpose(1, 0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, but what are we testing here? This should be covered as part of autograd tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are testing the AD formula, the autograd test does not cover the Ad formula I believe.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, please add an autograd test! It's a differentiable op just like any other.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep just added them to autograd tests :)

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@wanchaol wanchaol dismissed apaszke’s stale review April 3, 2019 18:02

fixed the AD formula to only return grad_output, and make test covered in autograd

@facebook-github-bot
Copy link
Contributor

@wanchaol merged this pull request in a21e256.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[JIT] state[input] != State::Unknown ASSERT FAILED at /pytorch/torch/csrc/jit/passes/specialize_autogradzero.cpp:57

4 participants