
Conversation

Contributor

@yf225 yf225 commented Jun 25, 2019

As part of the Variable/Tensor merge, variable.tensor_data() should be removed in favor of variable.variable_data() (which has the same semantics as Python tensor.data). This PR removes tensor_data() call sites in torch/csrc/autograd and torch/csrc/cuda.
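
For context, here is a minimal Python sketch of the semantics this PR relies on (variable_data() mirrors Python's tensor.data, while detach() is the tracked alternative); the tensor names are illustrative:

import torch

x = torch.ones(2, requires_grad=True)
y = x.data      # shares storage with x, requires_grad=False, gets its own version counter
z = x.detach()  # also shares storage, but keeps x's version counter

assert y.data_ptr() == x.data_ptr() == z.data_ptr()
assert not y.requires_grad and not z.requires_grad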

@pytorchbot pytorchbot added the caffe2, oncall: jit, module: autograd, module: cuda, oncall: distributed, module: internals, module: numpy, module: pybind, and module: tests labels Jun 25, 2019
@yf225 yf225 removed request for apaszke, mrshenli and pietern June 25, 2019 00:28
Will Feng added 2 commits June 24, 2019 21:48
@yf225 yf225 changed the title from "[WIP] Remove tensor_data() call sites, and rename it to _tensor_data_deprecated()" to "Remove tensor_data() call sites, and rename it to _tensor_data_deprecated()" Jun 25, 2019
@yf225 yf225 force-pushed the remove_tensor_data_callsites branch from e2e6312 to 14bfbd7 Compare June 25, 2019 01:58
@yf225 yf225 changed the title from "Remove tensor_data() call sites, and rename it to _tensor_data_deprecated()" to "Remove some of tensor_data() call sites" Jun 25, 2019
@yf225 yf225 changed the title from "Remove some of tensor_data() call sites" to "Remove tensor_data() call sites in torch/csrc/autograd/ and distributed" Jun 25, 2019
@yf225 yf225 changed the title from "Remove tensor_data() call sites in torch/csrc/autograd/ and distributed" to "Remove tensor_data() call sites in torch/csrc/autograd and distributed" Jun 25, 2019
@yf225 yf225 mentioned this pull request Jun 25, 2019
@yf225 yf225 changed the title from "Remove tensor_data() call sites in torch/csrc/autograd and distributed" to "Remove tensor_data() call sites in torch/csrc/autograd and torch/csrc/cuda" Jun 25, 2019
// they still share the same storage. This works only because we never call
// in-place functions on unpacked variables.
Variable var;
Variable var = as_variable_ref(data_).variable_data();
Contributor

didn't you already set this to variable_data above? So you can just skip this?

Contributor Author

It seems unclear whether we have the invariant that a SavedVariable will only be unpacked once. (I tried removing detach/variable_data here and all tests passed.) If we don't have that invariant, we should always use detach/variable_data here, so that the unpacked Variable is always a shallow copy of the SavedVariable.

auto data = as_variable_ref(r.tensor(1)).tensor_data();
auto var = make_variable(data, r.toBool(2));
auto data = as_variable_ref(r.tensor(1)).variable_data();
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
Contributor

Two points:

  1. Is there any reason this can't re-use the version counter, i.e. detach semantics rather than data?

  2. If we are going with detach semantics, passing a boolean to allow the metadata change seems reasonable -- have we not done a release yet where the "allow_metadata_change" is in?

namespace torch { namespace autograd {

variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
variable_list wrap_outputs(const variable_list& inputs, variable_list&& outputs,
Contributor

I don't understand what's going on here, can you explain?

This is only called for legacy_apply functions? These should always be used with variable_data for some reason?

Contributor Author

This is called for legacy_apply functions and also for DelayedError::apply. The main reason we need to change the signature from tensor_list to variable_list is that the outputs are now Variables instead of Tensors (because we now use .detach() instead of .tensor_data() to build up the outputs list in legacy_apply and DelayedError::apply).

Contributor

I get that we had tensors and now we have variables; I'm trying to understand what the constraints of this code actually are. So this is used with legacy apply, which we want to get rid of, right? (#16947) How does legacy apply actually work? Is the forward wrapped in a no-grad block?

Contributor Author

For legacy autograd functions, the forward is wrapped in a no-grad block:

{
AutoGradMode grad_mode(false);
THPObjectPtr forward_fn(PyObject_GetAttrString((PyObject*)self, "forward"));
if (!forward_fn) return nullptr;
raw_output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
if (!raw_output) return nullptr;
}
legacy_apply is only used in a legacy autograd function's backward pass:
// NOTE: this function is written in a way that assumes it's only called for backward;
// it's used by engine.cpp. This is responsible for forwarding a call from
// C++'s Function::apply to a Python method "apply".
auto PyFunction::apply(variable_list&& inputs) -> variable_list {
AutoGIL gil;
at::OptionalDeviceGuard _device_guard;
THPFunction* py_fn = (THPFunction*)obj;
THPObjectPtr _legacy(PyObject_GetAttrString(obj, "_is_legacy"));
if (_legacy == Py_True) {
return legacy_apply(inputs);
}

In legacy_apply, we have to make a shallow copy of the output from the backward function (via detach or variable_data) before calling wrap_outputs, because for a backward function that directly passes grad_x through:

class PassthroughFunction(Function):
    ...
    def backward(self, grad_x):
        return grad_x

if grad_x requires grad, we will attach a new gradient edge to grad_x (via autograd::create_gradient_edge(output, grad_fn)), which will replace any previous gradient edge that grad_x has (which is incorrect behavior). Hence we should make a shallow copy of the output from the backward function in legacy_apply before calling wrap_outputs.
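
A minimal Python sketch (illustrative only; the actual change is in the C++ legacy_apply) of why the shallow copy protects grad_x's existing graph connection:

import torch

grad_x = torch.ones(2, requires_grad=True) * 2  # has a grad_fn (MulBackward0)
shallow = grad_x.detach()                       # shares storage, but has its own autograd metadata

assert shallow.data_ptr() == grad_x.data_ptr()
assert shallow.grad_fn is None and grad_x.grad_fn is not None
# Attaching a new gradient edge to `shallow` (as wrap_outputs does for outputs that
# require grad) therefore leaves grad_x's original grad_fn untouched.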

// they still share the same storage. This works only because we never call
// in-place functions on unpacked variables.
Variable var;
Variable var = as_variable_ref(data_).detach();
Contributor Author

@yf225 yf225 Jul 1, 2019

It seems unclear whether we have the invariant that a SavedVariable will only be unpacked once. (I tried removing detach/variable_data here and all tests passed.) If we don't have that invariant, we should always use detach/variable_data here.

Contributor

how would you figure out if we have that invariant or not?

Contributor Author

In the following use case, a SavedVariable can be unpacked more than once:

>>> a = torch.zeros(2, 2).fill_(3).requires_grad_()
>>> b = torch.zeros(2, 2).fill_(5).requires_grad_()
>>> c = a*b
# SavedVariable being saved:  3  3
#  3  3
# [ CPUFloatType{2,2} ]
# SavedVariable being saved:  5  5
#  5  5
# [ CPUFloatType{2,2} ]

>>> d = c.sum()
>>> d.backward(retain_graph=True)
# SavedVariable being unpacked:  3  3
#  3  3
# [ CPUFloatType{2,2} ]
# SavedVariable being unpacked:  5  5
#  5  5
# [ CPUFloatType{2,2} ]

>>> d.backward()
# SavedVariable being unpacked:  3  3
#  3  3
# [ CPUFloatType{2,2} ]
# SavedVariable being unpacked:  5  5
#  5  5
# [ CPUFloatType{2,2} ] 

In the above example, a and b are unpacked more than once; hence we should always use detach/variable_data to make a shallow copy of the SavedVariable here, so that each unpacking does not affect the original SavedVariable.

AT_ASSERT(t.is_variable());
Variable var = t;
device_outputs.push_back(make_variable(var.tensor_data(), false));
Variable var = as_variable_ref(t).variable_data();
Contributor Author

We specifically want to use variable_data instead of detach here, because the comment above this function in NOTE [ Version Counter in comm.*_coalesced ] mentions:

// We thus re-wrap these Variables after broadcasting (i.e., effectively doing
// what is equivalent to .data in Python), and give them individual version
// counters. ...

variable_data is the strict equivalent of .data in Python: it creates a new version counter for the returned tensor, which is exactly what we want.
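
A small Python sketch of the version-counter difference being relied on here (illustrative; this mirrors the documented .data vs .detach() behavior rather than the comm code itself):

import torch

a = torch.ones(3, requires_grad=True)
b = a.exp()            # exp() saves its output for the backward pass

b.detach().add_(1)     # detach() shares b's version counter, so autograd notices
try:
    b.sum().backward()
except RuntimeError:
    print("in-place change through detach() is detected")

a2 = torch.ones(3, requires_grad=True)
b2 = a2.exp()
b2.data.add_(1)        # .data (i.e. variable_data) has a fresh version counter: no error,
b2.sum().backward()    # but the backward silently uses the mutated values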

AT_ASSERT(t.is_variable());
Variable var = t;
device_outputs.push_back(make_variable(var.tensor_data(), false));
Variable var = as_variable_ref(t).variable_data();
Contributor Author

Please see comment above.

Contributor

@gchanan gchanan left a comment

The python_legacy_variable and python_variable changes look fine (and you can break them out separately if you want), but they should have clear BC commit messages.

I still don't really understand what is going on with the legacy apply or comm primitive stuff.

Contributor Author

yf225 commented Jul 12, 2019

I moved python_legacy_variable and python_variable changes to another PR: #22821.

@pytorchbot
Collaborator

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
Stale pull requests will automatically be closed 30 days after being marked Stale
