Make JIT attributes t_ and ts_ store Variable instead of Tensor #16596
Conversation
zdevito left a comment:
Some notes in line, otherwise this looks good, pending CI.
Diff excerpt:

  }
- if (ref.is_variable()) {
-   ref = autograd::Variable(ref).data();
+ if (!ref.is_variable()) {
I'd prefer we just AT_CHECK(ref.is_variable()) and fix the places that pass something else, but this is ok for now if you put the rest in a follow-up patch.
Added comment in code:
// TODO: fix all cases where we are not passing in a variable,
// and then change this to an AT_ASSERT
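For context, a minimal sketch of how this spot might read now versus after that follow-up, based only on the diff and comments above (the surrounding code is an assumption, not part of this PR):

// Current behavior shown in the diff: coerce a plain Tensor into a
// Variable so everything downstream sees a Variable.
// TODO: fix all cases where we are not passing in a variable,
// and then change this to an AT_ASSERT
if (!ref.is_variable()) {
  ref = autograd::make_variable(ref, /*requires_grad=*/false);
}

// After the follow-up patch, this would shrink to:
// AT_ASSERT(ref.is_variable());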
torch/csrc/jit/constants.cpp (outdated diff):
if (!ref.is_variable()) {
  ref = autograd::make_variable(ref, /*requires_grad=*/false);
} else {
  ref.set_requires_grad(false);
No one should be passing a requires_grad tensor here, so this should be an assert as well. I don't want to silently fix places like this, because it covers up larger logic errors in how constants are created. If it requires grad, it is also likely that someone else might mutate it, which would be bad as well.
Currently test_call_python_mod_from_tracing_fn / test_call_script_mod_from_tracing_fn / test_call_traced_mod_from_tracing_fn are failing if we assert !ref.requires_grad() here. Sample error:
======================================================================
ERROR: test_call_python_mod_from_tracing_fn (__main__.TestScript)
----------------------------------------------------------------------
Traceback (most recent call last):
File "test/test_jit.py", line 7676, in test_call_python_mod_from_tracing_fn
@_trace(torch.rand(3, 4))
File "test/test_jit.py", line 231, in wrapper
return torch.jit.trace(func, args, **kwargs)
File "/data/miniconda3/envs/working/lib/python3.7/site-packages/torch/jit/__init__.py", line 637, in trace
var_lookup_fn, _force_outplace)
File "test/test_jit.py", line 7678, in traced_fn
return pm(x) + 1.0
File "/data/miniconda3/envs/working/lib/python3.7/site-packages/torch/nn/modules/module.py", line 490, in __call__
result = self._slow_forward(*input, **kwargs)
File "/data/miniconda3/envs/working/lib/python3.7/site-packages/torch/nn/modules/module.py", line 480, in _slow_forward
result = self.forward(*input, **kwargs)
File "test/test_jit.py", line 7672, in forward
return torch.mm(x, self.param)
RuntimeError: !ref.requires_grad() ASSERT FAILED at ../torch/csrc/jit/constants.cpp:28, please report a bug to PyTorch.
Here is the test code:
def test_call_python_mod_from_tracing_fn(self):
    class PythonMod(torch.nn.Module):
        def __init__(self):
            super(PythonMod, self).__init__()
            # PROBLEM: this has requires_grad=True by default
            self.param = torch.nn.Parameter(torch.rand(4, 3))

        def forward(self, x):
            # PROBLEM: since self.param.requires_grad=True,
            # the `!ref.requires_grad()` check during constant insertion fails
            return torch.mm(x, self.param)

    pm = PythonMod()

    @_trace(torch.rand(3, 4))
    def traced_fn(x):
        return pm(x) + 1.0

    # Note: the parameter self.param from the Python module is inlined
    # into the graph
    self.assertExpected(canonical(traced_fn.graph))

I think the original intention is to insert self.param as a constant into the graph and ignore its gradients, and we can use ref = autograd::make_variable(autograd::Variable(ref).data(), /*requires_grad=*/false) here to emulate this behavior (which creates a new variable that contains the same data storage as the original variable, with requires_grad=false). Do you think this would be a good idea here?
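For illustration, a sketch of the handling suggested above, assuming `ref` is the incoming tensor in the constant-insertion code in torch/csrc/jit/constants.cpp (a sketch only, not the code in this PR):

// Re-wrap the same storage in a fresh Variable that does not record
// gradients, instead of asserting on requires_grad.
if (ref.is_variable() && ref.requires_grad()) {
  ref = autograd::make_variable(autograd::Variable(ref).data(),
                                /*requires_grad=*/false);
}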
Actually, I think this test is bogus. We should not be silently turning parameters into constants. I think what should happen here is a nice user-facing error that says that code is attempting to use a gradient recording tensor as a constant in a trace, and that the user likely wanted to trace the module rather than a function calling the module. I think it is arguable that we should not be implicitly turning anything into constants, but that is a larger change to consider.
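A sketch of the kind of user-facing check described above; the exact wording and error macro are assumptions, not something decided in this thread:

// Reject gradient-recording tensors with a readable error instead of
// silently baking them into the trace as constants.
if (ref.requires_grad()) {
  AT_ERROR(
      "Cannot insert a tensor that requires grad as a constant; "
      "you likely wanted to trace the module itself rather than a "
      "function that calls the module.");
}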
@zdevito Got it and I can remove the test_call_python_mod_from_tracing_fn test.
The other two failing tests look like the following:
def test_call_script_mod_from_tracing_fn(self):
    class ScriptMod(torch.jit.ScriptModule):
        def __init__(self):
            super(ScriptMod, self).__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 3))

        @torch.jit.script_method
        def forward(self, x):
            return torch.mm(x, self.param)

    sm = ScriptMod()

    @_trace(torch.rand(3, 4))
    def traced_fn(x):
        return sm(x) + 1.0

    self.assertExpected(canonical(traced_fn.graph))

def test_call_traced_mod_from_tracing_fn(self):
    class TracedModule(torch.nn.Module):
        def __init__(self):
            super(TracedModule, self).__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 3))

        def forward(self, x):
            return torch.mm(x, self.param)

    tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

    @_trace(torch.rand(3, 4))
    def traced_fn(x):
        return tm(x) + 1.0

    # Note: the parameter self.param from the Python module is inlined
    # into the graph
    self.assertExpected(canonical(traced_fn.graph))

These two tests have the same issue in that self.param = torch.nn.Parameter(torch.rand(4, 3)) has requires_grad=True by default (and if I manually set self.param to requires_grad=False, the test will pass). Are the use cases in these two tests valid / should we also remove them?
We should keep these tests because they are checking that some interactions between tracing and script work. However, it is fine to explicitly set requires_grad=False to make them pass. That is not what they were testing before.
That might be true of the first test as well. We can just turn off the requires_grad to make them all pass.
Force-pushed: df4edcc to c3815ba, and 028c12e to 7994b41.
facebook-github-bot left a comment:
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Discussed with @zdevito and we want to use Variable (with set_requires_grad(false)) instead of Tensor in all parts of the JIT, to eliminate the distinction and the conceptual overhead of figuring out which one to use. This also helps with the Variable/Tensor merge work tracked at #13638, which will make common functions (such as numel()/sizes()/dim()) on Variable much faster when finished.
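As a rough illustration of that convention (the helper name is hypothetical, and the `autograd` namespace alias follows the diffs above), ensuring a stored value is a non-gradient-recording Variable could look like:

// Sketch: make sure a value stored on a JIT attribute is a Variable
// that does not record gradients, mirroring the diff shown earlier.
at::Tensor ensure_nograd_variable(at::Tensor t) {
  if (!t.is_variable()) {
    // Wrap a plain Tensor, as the diff above does.
    return autograd::make_variable(t, /*requires_grad=*/false);
  }
  // Already a Variable: just make sure it does not require grad.
  auto v = autograd::Variable(t);
  v.set_requires_grad(false);
  return v;
}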