🚀 High-level changes:
- IMPORTANT: Both `Variable` and `Variable::Impl` are removed, and `at::Tensor` is always the tensor that's passed around in PyTorch; it can record autograd history when its autograd metadata (`AutogradMeta`) is not null.
- IMPORTANT: Autograd-related function implementations in Variable will be moved to VariableType.
- Autograd metadata now lives in an `AutogradMeta` struct that `TensorImpl` has a pointer to, and the `AutogradMeta` is only populated when the `at::Tensor` requires gradient.
- We decide whether to dispatch to VariableType / non-VariableType functions using the `at::AutoNonVariableTypeMode` guard in appropriate places internally. (We only dispatch to VariableType functions if we need profiling / JIT-tracing / autograd.)
- Common Tensor functions (e.g. `numel()` / `sizes()` / `dim()`) are de-virtualized in TensorImpl and have their runtime reduced by 43%-86%.
- `tensor.is_variable()` and `options.is_variable()` always return true, because every `at::Tensor` is a variable (and can record autograd history when its `AutogradMeta` is not null). (We keep `options.is_variable(...)` for backward compatibility, and raise a warning if it's set to false.)
- API behavior change: changing shape/storage on `tensor.data` in Python or `tensor.data()` in C++ will no longer update `tensor`.
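A rough structural sketch of this layout, assuming the class and field names used in this issue (`AutogradMeta`, `AutogradMetaInterface`, `autograd_meta_`); everything else is illustrative and not the actual PyTorch definitions:

```cpp
#include <memory>

// Illustrative sketch only.
struct AutogradMetaInterface {
  virtual ~AutogradMetaInterface() = default;
};

struct AutogradMeta : AutogradMetaInterface {
  // autograd-specific state lives here: grad_, grad_fn_, grad_accumulator_, ...
};

struct TensorImpl {
  // sizes_, strides_, storage_, and the other usual tensor fields ...

  // Null      => the tensor does not record autograd history.
  // Populated => the tensor requires grad and records history.
  std::unique_ptr<AutogradMetaInterface> autograd_meta_;
};
```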
Pitch
Currently, the distinction between `at::Tensor` and `Variable` (a subclass of `at::Tensor` that carries autograd metadata and functions) creates unnecessary cognitive overhead for PyTorch core development. We want to remove this distinction and make it possible to use `at::Tensor` everywhere in PyTorch. After merging Variable into `at::Tensor`, here are the common end-user APIs:
- When a C++ user wants to create a non-history-recording `at::Tensor` from another `at::Tensor`:

Current API (unchanged):

```cpp
auto t = torch::ones({2, 2}, torch::requires_grad()); // t is recording history
auto t_detached = t.detach(); // t_detached is the non-history-recording version of t
```

When the user calls `t.detach()`, we do the following under the hood:
- We do a shallow copy of `t`'s TensorImpl, which copies the storage pointer and all other TensorImpl fields (e.g. size / stride).
  - Note that subclasses of TensorImpl (e.g. `SparseTensorImpl`) need to know how to make a shallow copy of themselves, and we dispatch this operation to each TensorImpl subclass's own `shallow_copy_and_detach()` function (by making `shallow_copy_and_detach()` virtual in TensorImpl and overriding it in TensorImpl subclasses).
- We set the `AutogradMeta` pointer to NULL, to indicate that it doesn't need to record history.
- We return an `at::Tensor` that wraps the new TensorImpl.
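A small runnable illustration of those steps from the user's side; the comments describe the under-the-hood behavior listed above (a sketch, not a spec of the final implementation):

```cpp
#include <torch/torch.h>

int main() {
  auto t = torch::ones({2, 2}, torch::requires_grad()); // t records history

  // Under the hood, detach() (1) shallow-copies t's TensorImpl (storage pointer,
  // sizes, strides, ...) via the virtual shallow_copy_and_detach(), (2) leaves the
  // AutogradMeta pointer null, and (3) wraps the new TensorImpl in an at::Tensor.
  auto t_detached = t.detach();

  AT_ASSERT(t_detached.data_ptr() == t.data_ptr()); // the storage is shared
  AT_ASSERT(!t_detached.requires_grad());           // but no autograd history
  return 0;
}
```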
- When a C++ user wants to enable/disable history-recording for an `at::Tensor`:

Proposed API:

```cpp
auto t = torch::ones({2, 2});  // t is not recording history (this already works)
t.requires_grad_(true);        // t is recording history now (new API)
t.requires_grad_(false);       // t is not recording history anymore (new API)
```

When the user calls `t.requires_grad_(true)`, we do the following under the hood:
- We initialize a struct called `AutogradMeta`, which stores autograd-specific fields (such as `grad_` / `grad_fn_` / `grad_accumulator_`).
- We assign the struct to the `AutogradMeta` pointer in `t`'s TensorImpl.

When the user calls `t.requires_grad_(false)`, we do the following under the hood:

- We set the `AutogradMeta` pointer in `t`'s TensorImpl to NULL.
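A hypothetical sketch of this toggle, using minimal stand-in types so the snippet is self-contained; the real `AutogradMeta`, `TensorImpl`, and `requires_grad_` are richer than this:

```cpp
#include <memory>

// Hypothetical stand-ins for the real classes, to keep the sketch self-contained.
struct AutogradMetaInterface { virtual ~AutogradMetaInterface() = default; };
struct AutogradMeta : AutogradMetaInterface { /* grad_, grad_fn_, grad_accumulator_, ... */ };
struct TensorImplSketch {
  std::unique_ptr<AutogradMetaInterface> autograd_meta_;
};

// Proposed behavior of requires_grad_(bool) as described above.
void requires_grad_(TensorImplSketch& impl, bool requires_grad) {
  if (requires_grad) {
    // Populate the autograd-specific state and hang it off the TensorImpl.
    impl.autograd_meta_ = std::make_unique<AutogradMeta>();
  } else {
    // A null AutogradMeta pointer means the tensor records no history.
    impl.autograd_meta_.reset();
  }
}
```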
- When a C++ user wants to call non-Variable operations on an `at::Tensor` when dispatching through `type()`:

Proposed API:

```cpp
{
  auto t_type = t.type(); // `t_type` is a Variable type if `t` contains AutogradMeta
}
{
  at::AutoNonVariableTypeMode grad_mode(false); // thread-local guard (new API)
  auto non_var_type = t.type(); // `non_var_type` is a non-Variable type
}
{
  at::AutoNonVariableTypeMode grad_mode(true); // thread-local guard (new API)
  auto var_type = t.type(); // `var_type` is a Variable type
}
```

Under the hood, `type()` checks whether the `at::AutoNonVariableTypeMode` thread-local guard is enabled when determining the type of the variable.
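One plausible sketch of that decision, written with made-up stand-ins (`sketch::Type`, `type_for`, the raw `thread_local` flag) rather than the real dispatch machinery, and assuming the guard's "enabled" state selects the non-Variable type, consistent with the `at::NonVariableTypeMode::is_enabled()` check quoted in the Finished changes list below:

```cpp
namespace sketch {

thread_local bool non_variable_type_mode = false;  // toggled by the RAII guard

struct Type {};
Type variable_type, base_type;

// type() picks the Variable type only when the tensor has AutogradMeta
// and the thread-local non-Variable-type mode is not active.
Type& type_for(bool has_autograd_meta) {
  if (has_autograd_meta && !non_variable_type_mode) {
    return variable_type;
  }
  return base_type;
}

}  // namespace sketch
```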
- When a C++ user wants to change the content of an `at::Tensor` that has AutogradMeta, without affecting the tensor's `grad_fn` or `version_counter_`:

Proposed behavior:

```cpp
auto t = torch::ones({2, 2});
t.requires_grad_(true);
AT_ASSERT(t.current_version() == 0);
t.data().add_(1); // Consistent with Python `.data` behavior: changing `.data` of a tensor in Python doesn't affect the tensor's `grad_fn` or `version_counter_`
AT_ASSERT(t.current_version() == 0);
```

Motivation
- Overly complex OOP design: Currently the distinction between `Variable` and `Tensor` is hard to grasp: `Variable::Impl` is a subclass of TensorImpl, but it also has an `at::Tensor` data member which internally wraps another TensorImpl. This co-existence of "is-a" and "has-a" relationships makes the code complicated and adds cognitive overhead. In particular, it's difficult to track which functions we have overridden in `Variable::Impl`, and which functions are applicable to Tensor vs. Variable (e.g. `is_wrapped_number()` is only valid on Tensor, not Variable) (for more context, also see the note "We regret making Variable hold a Tensor"). Ideally, we want to use the same tensor type everywhere in PyTorch code.
- Unused data members in `Variable::Impl` take up cache/memory space: Since `Variable::Impl` is a subclass of TensorImpl, it contains all of the data members that a normal TensorImpl would have (such as `sizes_` / `strides_` / etc.). However, the `Variable::Impl` functions always call into the underlying `at::Tensor` and ignore the rest of the fields, which wastes a lot of cache/memory space.
- Virtual functions are slow: We care about how much time it takes to execute common Tensor functions such as `numel()` / `sizes()` / `dim()`. Currently, these functions are `virtual` in TensorImpl, so that `Variable::Impl` (a subclass of TensorImpl) can override them and dispatch those calls to `Variable::Impl`'s underlying `at::Tensor`. Virtual function calls are slow because they involve an extra vtable lookup. Specifically, we did the following comparison on the most common Tensor functions (all timings are in ns):
| Benchmark | Time (no flush) | Time (flush L1) | Time (flush L1+L2) | Time (flush L1+L2+L3) |
|---|---|---|---|---|
| Tensor.dim() - non-virtual | 1.3 | 3.33 | 7.6 | 58 |
| Variable.dim() - virtual | 4.5 | 24.4 | 52 | 173.67 |
| Runtime Savings | -71.11111% | -86.35246% | -85.38462% | -66.60333% |
| Tensor.numel() - non-virtual | 22.6 | 63.89 | 109.22 | 294.5 |
| Variable.numel() - virtual | 80.33 | 133.1 | 192 | 810.9 |
| Runtime Savings | -71.86605% | -51.9985% | -43.11458% | -63.68233% |
| Tensor.size(0) - non-virtual | 30.4 | 60.1 | 100.44 | 384.3 |
| Variable.size(0) - virtual | 75.4 | 127.67 | 203.8 | 875.9 |
| Runtime Savings | -59.6817% | -52.92551% | -50.71639% | -56.12513% |
| Tensor.sizes() - non-virtual | 2 | 4.25 | 13.25 | 67.6 |
| Variable.sizes() - virtual | 5.2 | 28.44 | 62.1 | 254.78 |
| Runtime Savings | -61.53846% | -85.05626% | -78.66345% | -73.46731% |
| Tensor.resize_({0}) no-op - non-virtual | 23.11 | 86.44 | 105.44 | 332.33 |
| Variable.resize_({0}) no-op - virtual | 168.4 | 254.22 | 348.56 | 890.9 |
| Runtime Savings | -86.27672% | -65.99795% | -69.74983% | -62.69727% |
| Tensor.resize_({64, 2048}) no-op - non-virtual | 33.4 | 102.56 | 129.56 | 407.22 |
| Variable.resize_({64, 2048}) no-op - virtual | 193 | 278.1 | 364.9 | 936.6 |
| Runtime Savings | -82.6943% | -63.12118% | -64.49438% | -56.52146% |
Benchmarked commit: f000101
Benchmark script: https://github.com/yf225/benchmark/blob/tensor_functions/timing/cpp2/benchmarks/aten_overheads.cpp
Non-virtual code: master...yf225:nonvirtual_tensorimpl
Virtual code: master...yf225:virtual_tensorimpl
Based on our current implementation, the runtime difference for dim(), numel(), size(), sizes(), and no-op resize() comes from the virtual function call overhead and the at::Tensor data member indirection in Variable::Impl. If we de-virtualize those functions, we would be able to cut the runtime by 43%-86% on the most common Tensor functions.
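For reference, a minimal sketch of the kind of Google Benchmark microbenchmark behind these numbers; the linked aten_overheads.cpp is the authoritative version, and the cache-flush variants are omitted here:

```cpp
#include <benchmark/benchmark.h>
#include <ATen/ATen.h>

// Measure the per-call overhead of Tensor::dim() on a plain (non-Variable) tensor.
static void BM_TensorDim(benchmark::State& state) {
  auto tensor = at::ones({64, 2048});
  for (auto _ : state) {
    benchmark::DoNotOptimize(tensor.dim());
  }
}
BENCHMARK(BM_TensorDim);

BENCHMARK_MAIN();
```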
Breaking changes
Note that this change will break the current API in the following way:
In the old world, whenever we want to create a Variable that shares the same data with another Variable, we simply do `auto var_new = make_variable(var.data())` or `auto var_new = var.detach()`, and any shape / data / storage-pointer change to `var_new` will be reflected in `var` automatically, because internally they share the same underlying `at::Tensor`.
However, in the new world, there is no concept of the "underlying at::Tensor" of a Variable, since the Variable itself is the Tensor. When we want to create an `at::Tensor` that shares the same data with another `at::Tensor`, we can still call `auto t_new = t.detach()`, but in this case only the tensor storage data is shared (via a ref-counted pointer) between `t_new` and `t`; the tensor size / stride information is copied by value. In other words, changing anything in the detached Tensor (`t_new`) that is not bits inside the tensor storage (e.g. size / stride / storage_ptr) won't update the original Tensor (`t`), and we should no longer expect that data to be shared.
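To make the new sharing rules concrete, a small illustrative example (a sketch of the behavior described above; metadata-changing ops on the detached tensor may additionally be rejected outright by the `allow_tensor_metadata_change_` flag discussed below):

```cpp
#include <torch/torch.h>

int main() {
  auto t = torch::ones({2, 2});
  auto t_new = t.detach();   // shares t's storage, but has its own TensorImpl

  t_new.add_(1);             // writes into the shared storage: visible through t too
  AT_ASSERT(t[0][0].item<float>() == 2);

  // Size / stride / storage-pointer changes on t_new are no longer reflected in t
  // (and may additionally be rejected by the allow_tensor_metadata_change_ flag
  // discussed below), e.g. t_new.resize_({1, 4}) would leave t's shape at {2, 2}.
  return 0;
}
```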
This has implications for Python call sites that do

```python
tensor.data.in_place_operation_()
```

or

```python
tensor_detached = tensor.detach()
tensor_detached.in_place_operation_()
```

If `in_place_operation_()` only updates the data inside the tensor (such as `zero_()`), the operation will still work properly; if the in-place operation changes the size, stride or the storage pointer inside the TensorImpl (e.g. `resize_` / `resize_as_` / `set_` / `transpose_`), such an operation on `tensor.data` or `tensor_detached` will no longer update `tensor`. We will address this inconsistency in the following ways:
- Add an `allow_tensor_metadata_change_` flag to `TensorImpl` to disallow size/stride/storage_ptr changes from in-place operations such as `resize_` / `resize_as_` / `set_` / `transpose_`, and set this flag to true when people call `tensor.data` in Python (see the sketch after this list).
- Write text in the docs to actively discourage changing the shape or storage of `tensor_detached` and expecting `tensor` to also be updated.
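A hypothetical sketch of how such a flag could gate metadata-changing operations; the field and accessor names follow the issue text, while the class and the check shown are illustrative rather than the actual implementation:

```cpp
// Illustrative sketch only.
class TensorImplSketch {
 public:
  void set_allow_tensor_metadata_change(bool value) {
    allow_tensor_metadata_change_ = value;
  }
  bool allow_tensor_metadata_change() const {
    return allow_tensor_metadata_change_;
  }

 private:
  // Set to true for the `tensor.data` escape hatch described above;
  // when false, metadata-changing in-place ops are rejected.
  bool allow_tensor_metadata_change_ = true;
};

// Inside metadata-changing ops such as resize_ / set_ / transpose_, something like:
//   AT_CHECK(impl->allow_tensor_metadata_change(),
//            "this operation is not allowed on a Tensor created from .data or .detach()");
```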
Finished changes
- Add a flag to `TensorImpl` to disallow size/stride/storage_ptr changes from in-place operations such as `resize_` / `resize_as_` / `set_` / `transpose_`, and set this flag to true when people call `tensor.data` in Python.
- Write text in the docs to actively discourage changing the shape or storage of `tensor_detached` and expecting `tensor` to also be updated.
- Move `Variable::Impl` data members into TensorImpl as an `AutogradMeta` struct
- Change `Variable::Impl` functions to use data members in the `AutogradMeta` struct
- Add a `shallow_copy()` function to each subclass of TensorImpl
- Do a shallow copy when the user calls `make_variable(tensor)` / `variable.detach()` (Reason: now that autograd metadata lives in TensorImpl, in order to create a new history for the Variable returned from `variable.detach()` we not only need to create a new AutogradMeta struct, but we also need to create a new TensorImpl object that stores a pointer to the new AutogradMeta struct (which we obtain by shallow-copying the original TensorImpl). Otherwise, changing the history of the detached Variable would also change the history of the original Variable, which is not the correct behavior.)
- Add an `AutogradMetaInterface` class, and make `AutogradMeta` a subclass of it, so that we can make `autograd_meta_` a unique_ptr in TensorImpl
- Move `set_requires_grad()` / `requires_grad()` / `grad()` from `Variable::Impl` to `AutogradMeta`
- Move `Variable::Impl` functions such as `backward()` / `rebase_history()` / `grad_accumulator()` / `grad_fn()` out of `Variable::Impl` and into `AutogradMeta`.
  - Note: we need to make these changes so that we can remove the `Variable::Impl` class in the next PR.
- Add a thread-local guard (`at::AutoNonVariableTypeMode`) to make sure that in VariableType.cpp the operations on baseType still dispatch to non-Variable type, even if the parameters are now Variables
- Make `gesv_out` return the original input tensor instead of a new tensor (currently by copying the result tensor into the original input tensor, because a true in-place `gesv` is more difficult to implement. NOTE: also open an issue for this).
- In VariableType.cpp, after each in-place function on the "unpacked" tensor, check pointer address equality for storage in the original input variable's TensorImpl (check this for all arguments in `unpacked_args`)
- Remove `.type()` calls as much as possible, to reduce the need for the `at::AutoNonVariableTypeMode` guard
- Make JIT attributes `t_` and `ts_` store Variable instead of Tensor (and in `t_` and `ts_` use sites, don't wrap the tensor into Variable again) (global-search `make_variable(` in jit/ to find places where we are doing double-wrapping for `t_` and `ts_` attributes)
- `tril_` and `triu_` should not change the input tensor's TensorImpl pointer
- Move `pyobj_` to TensorImpl itself, because we always need to be able to convert to and from the Python representation.
- Move `version_counter_` to storage or TensorImpl, because we may capture non-requires-grad variables inside an autograd function, and we need a working version counter in these cases.
  - We should not share the version counter in `shallow_copy_and_detach()`, because a pure Tensor doesn't have the concept of a version counter; it's managed by autograd instead.
  - We should preserve the API semantics of `tensor.data` in Python, and allow it as an escape route for in-place operations without bumping the version counter.
- `tensor.is_variable()` should check whether the TensorImpl has AutogradMeta. `is_variable_` should be removed.
- PR: [BC-breaking] Fix version counter sharing in Variable.set_data() (#20391)
- PR: Move at::NonVariableTypeMode to TensorImpl, and check it in is_variable() (#20392)
- PR: Require passing version_counter and allow_tensor_metadata_change to shallow_copy_and_detach() (#20496)
- PR: [BC-breaking] Shallow-copy `indices` and `values` in the sparse tensor constructor (#20330)
- PR: [BC-breaking] Remove Variable::Impl and DifferentiableViewImpl (#17072)
  - Remove the `at::Tensor` data member (`data_`) from `Variable::Impl`
  - In Variable construction and in `Variable.set_data()`, copy all data from `data.impl` to the variable's TensorImpl.
  - Make `Variable.data()` have the same semantics as `tensor.data` in Python. Notice breakage in any `Variable.data()` call sites.
  - Remove the `Variable::Impl` class and the `DifferentiableViewImpl` class
  - Remove mentions of `Variable::Impl` and `DifferentiableViewImpl`
  - Fix the comments in [Tensor versus Variable in C++], [We regret making Variable hold a Tensor], and [Autograd View Variables]. Go through all comments in variable.h and variable.cpp and fix any inconsistency.
  - NOTE: we don't need to add a `SparseVariableImpl` that handles how to copy `SparseTensorImpl`, because `SparseTensorImpl` already implements the `shallow_copy_and_detach()` function that Variable factory functions can call.
  - In places where we need to ensure the tensor is not requiring gradient, we should check `!requires_grad() || at::NonVariableTypeMode::is_enabled()`, instead of `!requires_grad() || !at::GradMode::is_enabled()`, because we don't want to move `at::GradMode` to ATen.
Changes remaining:
- Make AutogradMeta optional, so that Variable and Tensor become the same. (Tracking issue: Proposal: Optional AutogradMeta for Variable #23032)
- Miscellaneous cleanup
  - Remove `unpack()` in VariableType*.cpp.
  - Clean up the `unpack_args` logic in gen_variable_type.py, since we are not doing unpack anymore.
  - Fix comments for `use_derived` in gen_variable_type.py.
  - Remove `requires_tensor: True` in native_functions.yaml. Figure out how to fix the _dimV / _dimS case (`torch.randn(2, 3)._dimV()` shouldn't hit that error).
- TensorImpl de-virtualization (tracking issue: TensorImpl de-virtualization #22815)
- Sparse invariant fix (tracking issue: In-place updating the original value tensor should also update version counter of sparse tensor's values_ tensor #22778)
- Remove the `tensor_data()` API (@yf225 is working on it)
- Python / C++ Tensor API parity (@yf225 is working on it)
  - Any Python Tensor API should also work on C++ Tensor, without explicit casting to Variable
- C++ API doc fix (@yf225 is working on it)
  - Remove the https://pytorch.org/cppdocs/#aten section, replace all `at::Tensor` with `torch::Tensor`, and remove/fix all mentions of ATen in cpp docs and tutorials.