Conversation

@yf225 (Contributor) commented Feb 13, 2019

As part of the Variable/Tensor merge work (#13638), we make the following changes in this PR:

  1. Remove the `Variable::Impl` class and the `DifferentiableViewImpl` class
  2. Change all `Variable.data()` call sites to either use `Variable` directly, or use `Variable.tensor_data()`
  3. Remove the `Variable.data()` API
  4. Add `Variable.variable_data()`, matching `tensor.data` in the Python API: it creates a new `Variable` that shares the same storage and tensor metadata as the original `Variable`, but with a completely new autograd history.

After this PR, Variable no longer wraps a Tensor internally; both Variable and Tensor use the same TensorImpl class as their `impl_`. The only difference is that Variable always has AutogradMeta in its TensorImpl, but Tensor doesn't.
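
A minimal Python sketch of the semantics that `Variable.variable_data()` mirrors (i.e. `tensor.data` in the Python API): the result shares storage and tensor metadata with the source, but carries a fresh, empty autograd history.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.data  # same storage and tensor metadata, brand-new autograd history

assert y.data_ptr() == x.data_ptr()  # storage is shared
assert not y.requires_grad           # but y is detached from x's graph

y.add_(1.0)                # mutates the shared storage, invisibly to autograd
assert x[0].item() == 2.0
```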

Note that this PR is BC-breaking in the following use cases:

Use Case 1:
Previously, `x.data = y` worked even if `x` and `y` were of different TensorImpl types (e.g. `x` is a CPU dense tensor whose impl is of type TensorImpl, while `y` is a CPU sparse tensor whose impl is of type SparseTensorImpl). After this PR, `x.data = y` no longer works if `x` and `y` are of different TensorImpl types, because the underlying implementation `variable.set_data(tensor)` no longer works if `variable` and `tensor` have different TensorImpl types.
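
A minimal sketch of the failing assignment (`to_sparse()` is used here just for brevity):

```python
import torch

x = torch.zeros(2, 3)              # dense CPU tensor -> TensorImpl
y = torch.zeros(2, 3).to_sparse()  # sparse CPU tensor -> SparseTensorImpl

# After this PR, this raises:
# RuntimeError: Attempted to call `variable.set_data(tensor)`, but `variable`
# and `tensor` have different types of TensorImpl.
x.data = y
```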

This especially shows up in the following use case:

```python
import torch
import torch.nn as nn

class TestModule(nn.Module):
    def __init__(self):
        super(TestModule, self).__init__()
        self.fc1 = nn.Linear(20, 10)

m = TestModule()

# Under the hood, `m._apply()` uses `.data =` to change the data of `m`'s
# parameters and their gradients.
m = m._apply(lambda t: torch.sparse_coo_tensor(
    torch.zeros([2, 1]), torch.ones([1]), torch.Size([10, 20])))
# After this PR, this fails with:
# RuntimeError: Attempted to call `variable.set_data(tensor)`, but `variable`
# and `tensor` have different types of TensorImpl.
```

Use Case 2:
If a tensor `x`'s grad is sparse, accumulating dense gradients to `x` will change the tensor that `x.grad` points to. This is better illustrated with the following example:

```python
import torch

params = torch.tensor([1.5, 1.5]).requires_grad_()
with torch.no_grad():
    # Change the gradient to a sparse tensor
    params.grad = torch.sparse_coo_tensor(
        torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]))

grad_saved = params.grad
params.backward(torch.tensor([1.5, 1.5]))
assert id(grad_saved) == id(params.grad)  # This will fail after this PR
```

The assertion in the last line will fail after this PR, because adding dense gradients to sparse gradients will change the `params.grad` tensor reference.
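
A hedged sketch of the pattern to follow under the new behavior: re-read `params.grad` after accumulation instead of holding on to an alias taken before `backward()`.

```python
import torch

params = torch.tensor([1.5, 1.5]).requires_grad_()
with torch.no_grad():
    params.grad = torch.sparse_coo_tensor(
        torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]))

params.backward(torch.tensor([1.5, 1.5]))
# Re-read the attribute after accumulation; it may now reference a different
# tensor object than any alias saved before backward().
grad_now = params.grad
```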

@yf225 requested review from ezyang and gchanan Feb 13, 2019 23:33
@facebook-github-bot added the `oncall: jit` label Feb 13, 2019
```diff
- AT_ASSERT(!indices.is_variable() && !values.is_variable()); // They should be plain tensors! // TODO: change this to check `.requires_grad()` and `GradMode::is_enabled()` when Variable and Tensor are merged
+ AT_ASSERT((!(indices.is_variable() && indices.requires_grad()) &&
+            !(values.is_variable() && values.requires_grad()))
+           || at::NonVariableTypeMode::is_enabled()); // TODO: use `compute_requires_grad()` after it's moved to ATen
```
@yf225 (author):

I will work on moving `compute_requires_grad()` to ATen in the next PR.

@yf225 mentioned this pull request Feb 13, 2019
@facebook-github-bot:
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

```python
with torch.no_grad():
    param = fn(param)
    if param._grad is not None:
        param._grad = fn(param._grad)
```
@yf225 (author):

I need to debug why we need this change

@yf225 (author):

Update: For some test cases, `param._grad` is already a tensor with `allow_tensor_metadata_change` set to false, and changing its storage in `set_data()` will throw an error. To fix this problem, we should just do `param._grad = fn(param._grad)`, and use `with torch.no_grad()` to avoid accumulating gradients.

Contributor:

New code is more idiomatic anyway, sgtm.

```cpp
/*
 * For usage of `version_counter` and `allow_tensor_metadata_change`,
 * see NOTE [ TensorImpl Shallow-Copying ].
 */
virtual void copy_tensor_data(
```
Contributor:

A few issues with this function:

  1. This is just used by the public API functions `shallow_copy_and_detach` and `shallow_copy_from`, right? Then it should be private.
  2. This doesn't need to be virtual; in fact, it never uses `this`, so it doesn't even need to be a non-static method.
  3. Because it's a static method and you aren't using virtual dispatch, you can use the correct static types for the inputs, so you don't have to do scary static_casts in the derived types.

Contributor:

err, you probably need to make it protected so the derived types can call it.

@yf225 (author):

Fixed.

```cpp
 * Shallow-copies data from another TensorImpl into this TensorImpl.
 */
virtual void shallow_copy_from(c10::intrusive_ptr<TensorImpl> impl) {
  copy_tensor_data(
```
Contributor:

  1. It's scary not to have an ASSERT in here -- I know you check the only call site to give a nice error message, but still, you want to guard the unsafe places.
  2. If you follow the arguments above that `copy_tensor_data` should have correct static types, you should do the static_cast of `impl` in here, after the ASSERT.

@yf225 (author):

I added the static_cast and the ASSERT in the TensorImpl derived types' `shallow_copy_from()`.
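
A heavily simplified, hypothetical mock (not the real c10::TensorImpl) of the shape this thread converges on: `copy_tensor_data` is protected, static, and statically typed, and the derived `shallow_copy_from` asserts the runtime type before the static_cast.

```cpp
#include <cassert>
#include <cstdint>
#include <typeinfo>
#include <vector>

// Hypothetical mock of the pattern discussed above, not the real TensorImpl.
struct TensorImpl {
  virtual ~TensorImpl() = default;

  virtual void shallow_copy_from(const TensorImpl* impl) {
    copy_tensor_data(/*src_impl=*/impl, /*dest_impl=*/this);
  }

 protected:
  // Protected + static + statically typed: derived impls can call it, and no
  // virtual dispatch or scary casts happen inside.
  static void copy_tensor_data(const TensorImpl* src_impl, TensorImpl* dest_impl) {
    dest_impl->sizes_ = src_impl->sizes_;
    dest_impl->numel_ = src_impl->numel_;  // keep the cached count in sync
  }

  std::vector<int64_t> sizes_;
  int64_t numel_ = 1;
};

struct SparseTensorImpl : TensorImpl {
  void shallow_copy_from(const TensorImpl* impl) override {
    // Guard the unsafe place: assert before casting.
    assert(typeid(*impl) == typeid(SparseTensorImpl));
    const auto* src = static_cast<const SparseTensorImpl*>(impl);
    copy_tensor_data(src, this);
    indices_ = src->indices_;  // derived types copy their own extra fields
  }

 private:
  std::vector<int64_t> indices_;
};
```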

```cpp
/*
 * For usage of `version_counter` and `allow_tensor_metadata_change`,
 * see NOTE [ TensorImpl Shallow-Copying ].
 */
static void copy_tensor_data(
```
Contributor:

Did it not work to make it protected?

```cpp
    /*src_impl=*/this,
    /*dest_impl=*/impl.get(),
    /*version_counter=*/version_counter,
    /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
```
Contributor:

don't we need to refresh numel here?

@yf225 (author):

Fixed.
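
A hedged illustration of the point being raised: a shallow copy that overwrites sizes must also recompute the cached element count, or `numel()` goes stale (the type and helper below are simplified mocks, not the real TensorImpl).

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Simplified mock, not the real TensorImpl.
struct Impl {
  std::vector<int64_t> sizes_;
  int64_t numel_ = 1;

  // Recompute the cached element count from the (possibly new) sizes.
  void refresh_numel() {
    numel_ = std::accumulate(sizes_.begin(), sizes_.end(),
                             int64_t{1}, std::multiplies<int64_t>());
  }
};

// Usage sketch: after `dest.sizes_ = src.sizes_;`, call `dest.refresh_numel();`.
```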

```cpp
    /*src_impl=*/opaque_impl,
    /*dest_impl=*/this,
    /*version_counter=*/version_counter(),
    /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
```
Contributor:

don't we need to refresh numel here?

@yf225 (author):

Fixed.

@gchanan (Contributor) left a comment:

Nice!

@facebook-github-bot:
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request May 24, 2019

Pull Request resolved: pytorch/pytorch#17072

Differential Revision: D14075257

Pulled By: yf225

fbshipit-source-id: 0e681df641270dea586042dd26db59f2e76b5957
@dzhulgakov (Collaborator): Wow, epic work!!!

@ezyang (Contributor) commented May 24, 2019:

This broke the XLA build. cc @ailzhang

@colesbury (Member): Hooray!

@yf225 mentioned this pull request May 30, 2019
facebook-github-bot pushed a commit that referenced this pull request Jun 15, 2019
Summary:
After #17072, we are allowed to pass Variables into ATen ops, thus there is no need to unwrap input variables in the c10 call path.

Note that since Caffe2 still expects inputs to be pure Tensors, we moved the unwrapping logic to the Caffe2 wrapper.
Pull Request resolved: #21620

Differential Revision: D15763560

Pulled By: yf225

fbshipit-source-id: 5375f0e51eb320f380ae599ebf98e6b259f0bff8
facebook-github-bot pushed a commit that referenced this pull request Jun 19, 2019
…flag, and check it in `nn.Module._apply()` (#21613)

Summary:
#17072 breaks `model.to(xla_device)`, because moving `model` to an XLA device involves changing its parameters' TensorImpl type, and the current implementation of `nn.Module.to()` doesn't support changing module parameters' TensorImpl type:
```python
# https://github.com/pytorch/pytorch/blob/6dc445e1a84dc5d093d640de54f038f021d13227/torch/nn/modules/module.py#L192-L208
def _apply(self, fn):
    ...
    for param in self._parameters.values():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)  # NOTE: this doesn't allow changing `param.data`'s TensorImpl type
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)  # NOTE: this doesn't allow changing `param._grad.data`'s TensorImpl type
    ...
```

yf225 TODO: fix the description here when we finish the implementation

To fix this problem, we introduce a new API `model.to_()` that always assigns new tensors to the parameters (thus supporting changing the parameters to any TensorImpl type), and also bumps the version counter of the original parameters correctly so that they are invalidated in any autograd graph they participate in.

We also add a warning to the current `model.to()` API to inform users about the upcoming behavior change of `model.to()`: in future releases, it will create and return a new model instead of updating the current model in place.
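
A hedged sketch of what "always assigns new tensors to the parameters" could look like (hypothetical helper, not the actual implementation; the version-counter bumping described above is omitted here):

```python
import torch
import torch.nn as nn

def module_to_(module: nn.Module, fn):
    # Hedged sketch of the `to_()` idea -- NOT the actual PyTorch code:
    # assign brand-new Parameter objects instead of mutating `.data` in
    # place, so each parameter's TensorImpl type is free to change.
    for name, param in module._parameters.items():
        if param is not None:
            with torch.no_grad():
                new_param = nn.Parameter(fn(param),
                                         requires_grad=param.requires_grad)
            module._parameters[name] = new_param
    for child in module.children():
        module_to_(child, fn)
    return module

# Usage sketch:
# m = module_to_(nn.Linear(20, 10), lambda t: t.to(torch.device("cpu")))
```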

This unblocks adding XLA to our CI test suite, which also allows XLA to catch up with other changes in our codebase, notably the c10 dispatcher.

[xla ci]

cc. resistor ailzhang
Pull Request resolved: #21613

Differential Revision: D15895387

Pulled By: yf225

fbshipit-source-id: b79f230fb06019122a37fdf0711bf2130a016fe6
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 19, 2019

…flag, and check it in `nn.Module._apply()` (#21613)

Pull Request resolved: pytorch/pytorch#21613

Differential Revision: D15895387

Pulled By: yf225

fbshipit-source-id: b79f230fb06019122a37fdf0711bf2130a016fe6
@ezyang added the merged label Jun 25, 2019