
Conversation

@yf225 (Contributor) commented Jan 31, 2019

As discussed with @zdevito, we want to use Variable (with set_requires_grad(false)) instead of Tensor in all parts of the JIT, to eliminate the distinction and the conceptual overhead of figuring out which one to use.

This also helps with the Variable/Tensor merge work tracked in #13638, which, when finished, will make common functions on Variable (such as numel(), sizes(), and dim()) much faster.

@yf225 requested review from gchanan and zdevito on January 31, 2019 04:03
@facebook-github-bot added the oncall: jit label on Jan 31, 2019
@yf225 mentioned this pull request on Jan 31, 2019
@zdevito (Contributor) left a comment:

Some notes inline; otherwise this looks good, pending CI.

}
if (ref.is_variable()) {
ref = autograd::Variable(ref).data();
if (!ref.is_variable()) {
@zdevito (Contributor) commented:

I'd prefer we just AT_CHECK(ref.is_variable()) and fix the places that pass something else, but this is OK for now if you put the rest in a follow-up patch.

@yf225 (Contributor, Author) replied:

Added a comment in the code:

// TODO: fix all cases where we are not passing in a variable,
// and then change this to an AT_ASSERT

if (!ref.is_variable()) {
ref = autograd::make_variable(ref, /*requires_grad=*/false);
} else {
ref.set_requires_grad(false);
@zdevito (Contributor) replied:

No one should be passing a requires_grad tensor here, so this should be an assert as well. I don't want to silently fix places like this, because it covers up larger logic errors in how constants are created. If it requires grad, it is also likely that someone else might mutate it, which would be bad as well.
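
For reference, a minimal sketch of the stricter checks being asked for in this thread. This is a hypothetical form, not the patch that landed, and the error-message text is an assumption:

// Reject anything that is not a Variable, per the earlier comment.
AT_CHECK(ref.is_variable(), "expected a Variable when inserting a constant");  // message text is illustrative
// Reject gradient-recording tensors outright instead of silently fixing them.
AT_ASSERT(!ref.requires_grad());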

@yf225 (Contributor, Author) replied on Feb 2, 2019:

Currently, test_call_python_mod_from_tracing_fn, test_call_script_mod_from_tracing_fn, and test_call_traced_mod_from_tracing_fn fail if we assert !ref.requires_grad() here. Sample error:

======================================================================
ERROR: test_call_python_mod_from_tracing_fn (__main__.TestScript)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/test_jit.py", line 7676, in test_call_python_mod_from_tracing_fn
    @_trace(torch.rand(3, 4))
  File "test/test_jit.py", line 231, in wrapper
    return torch.jit.trace(func, args, **kwargs)
  File "/data/miniconda3/envs/working/lib/python3.7/site-packages/torch/jit/__init__.py", line 637, in trace
    var_lookup_fn, _force_outplace)
  File "test/test_jit.py", line 7678, in traced_fn
    return pm(x) + 1.0
  File "/data/miniconda3/envs/working/lib/python3.7/site-packages/torch/nn/modules/module.py", line 490, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/data/miniconda3/envs/working/lib/python3.7/site-packages/torch/nn/modules/module.py", line 480, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "test/test_jit.py", line 7672, in forward
    return torch.mm(x, self.param)
RuntimeError: !ref.requires_grad() ASSERT FAILED at ../torch/csrc/jit/constants.cpp:28, please report a bug to PyTorch.

Here is the test code:

def test_call_python_mod_from_tracing_fn(self):
    class PythonMod(torch.nn.Module):
        def __init__(self):
            super(PythonMod, self).__init__()
            # PROBLEM: this has requires_grad=True by default
            self.param = torch.nn.Parameter(torch.rand(4, 3))

        def forward(self, x):
            # PROBLEM: since self.param.requires_grad=True,
            # the `!ref.requires_grad()` check during constant insertion fails
            return torch.mm(x, self.param)

    pm = PythonMod()

    @_trace(torch.rand(3, 4))
    def traced_fn(x):
        return pm(x) + 1.0

    # Note: the parameter self.param from the Python module is inlined
    # into the graph
    self.assertExpected(canonical(traced_fn.graph))

I think the original intention is to insert self.param as a constant into the graph and ignore its gradients. We could use ref = autograd::make_variable(autograd::Variable(ref).data(), /*requires_grad=*/false) here to emulate this behavior; it creates a new variable that shares the same data storage as the original variable but has requires_grad=false. Do you think this would be a good idea here?
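
For reference, a minimal sketch of the branch being proposed above, restating the same expression (a sketch only, not the code that was committed):

if (ref.is_variable()) {
  // Variable(ref).data() shares the parameter's storage but carries no
  // autograd history, so the resulting constant never records gradients.
  ref = autograd::make_variable(
      autograd::Variable(ref).data(), /*requires_grad=*/false);
}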

@zdevito (Contributor) replied:

Actually, I think this test is bogus. We should not be silently turning parameters into constants. What should happen here is a nice user-facing error saying that the code is attempting to use a gradient-recording tensor as a constant in a trace, and that the user likely wanted to trace the module rather than a function calling the module. It is arguable that we should not be implicitly turning anything into constants, but that is a larger change to consider.
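
For reference, one possible shape for such a user-facing check, as a sketch only; the message wording and exact placement are assumptions, not the error that eventually shipped:

// Hypothetical check at the point where a traced value is turned into a constant.
AT_CHECK(
    !ref.requires_grad(),
    "Cannot insert a tensor that requires grad as a constant in a trace. ",
    "You likely wanted to trace the module itself rather than a function ",
    "that calls it.");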

@yf225 (Contributor, Author) replied:

@zdevito Got it; I can remove the test_call_python_mod_from_tracing_fn test. The other two failing tests look like the following:

def test_call_script_mod_from_tracing_fn(self):
    class ScriptMod(torch.jit.ScriptModule):
        def __init__(self):
            super(ScriptMod, self).__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 3))

        @torch.jit.script_method
        def forward(self, x):
            return torch.mm(x, self.param)

    sm = ScriptMod()

    @_trace(torch.rand(3, 4))
    def traced_fn(x):
        return sm(x) + 1.0

    self.assertExpected(canonical(traced_fn.graph))

def test_call_traced_mod_from_tracing_fn(self):
    class TracedModule(torch.nn.Module):
        def __init__(self):
            super(TracedModule, self).__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 3))

        def forward(self, x):
            return torch.mm(x, self.param)

    tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

    @_trace(torch.rand(3, 4))
    def traced_fn(x):
        return tm(x) + 1.0

    # Note: the parameter self.param from the Python module is inlined
    # into the graph
    self.assertExpected(canonical(traced_fn.graph))

These two tests have the same issue: self.param = torch.nn.Parameter(torch.rand(4, 3)) has requires_grad=True by default (and if I manually set self.param to requires_grad=False, the tests pass). Are the use cases in these two tests valid, or should we also remove them?

@zdevito (Contributor) replied:

We should keep these tests because they check that certain interactions between tracing and script work. However, it is fine to explicitly set requires_grad=False to make them pass; gradient behavior is not what they were testing anyway.

@zdevito (Contributor) added:

That might be true of the first test as well. We can just turn off requires_grad to make them all pass.

@yf225 force-pushed the t_ts_store_variable branch from df4edcc to c3815ba on February 2, 2019 19:33
@yf225 force-pushed the t_ts_store_variable branch from 028c12e to 7994b41 on February 6, 2019 22:40
@facebook-github-bot (Contributor) left a comment:

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
