Description
PyTorch has numerous operators in its public API that need to be removed or changed in a BC-breaking way. This tracking issue provides guidance on implementing these changes and tracks deprecation tasks.
Staging deprecations
Deprecations are disruptive. When a behavior is deprecated it means that a PyTorch program will no longer function as it did, and this requires users to update their programs to the latest behavior, which is often frustrating and time-consuming. Deprecations are also necessary, however, to keep PyTorch modern and flexible. Staging a deprecation across multiple releases is how PyTorch tries to minimize how disruptive these changes are while staying flexible.
A typical deprecation process occurs over three releases:
- in the first release, a warning is thrown the first time the deprecated behavior occurs (use TORCH_WARN_ONCE)
  - the warning should tell users how to avoid the deprecated behavior, how they can adopt the new behavior (if any), and how they can continue to use the current behavior (if possible)
  - PyTorch code should be updated to use the alternative behavior(s), where required; PyTorch should not throw warnings when performing non-deprecated operations
- in the second release, the deprecated behavior is disabled
  - the previous warning should be updated to an error
  - disabling the behavior for a release is important to prevent "silent correctness" issues, where the behavior of a program changes unexpectedly (PyTorch does not assume that warnings are sufficient to prevent this)
- if there is a new behavior, that behavior is enabled in the third release
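The warn-then-error staging can be sketched in plain Python. This is a hypothetical operator, not a PyTorch API, and every name in it (some_op, old_flag, legacy_some_op) is made up for illustration; PyTorch's C++ code uses TORCH_WARN_ONCE for the warning step:

```python
import warnings

def some_op(x, old_flag=None):
    """Hypothetical operator whose `old_flag` kwarg is being deprecated."""
    if old_flag is not None:
        # Release 1: warn, telling users both the replacement and the
        # non-deprecated way to keep the current behavior.
        # Release 2: this warning is upgraded to an error.
        warnings.warn(
            "some_op(..., old_flag=...) is deprecated; call some_op(x) "
            "for the new behavior, or legacy_some_op(x) to keep the "
            "current behavior",
            DeprecationWarning,
            stacklevel=2,
        )
    return x
```

The warning message follows the checklist above: it names the deprecated pattern, the new behavior, and a non-deprecated way to keep the current behavior.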
For example, torch.div() had its behavior changed. In PyTorch 1.4 torch.div() acted like division in C++, where the result of integer division is rounded towards zero. In PyTorch 1.5 a warning was added in the documentation and the code. The warning explained the future behavior, how to use the future behavior, and how to use the current behavior in a non-deprecated way. In PyTorch 1.6 this warning became an error, and in PyTorch 1.7 the behavior was changed.
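The rounding difference at the center of that change can be reproduced with plain Python numbers (illustration only; no torch required):

```python
# PyTorch <= 1.5: torch.div on two integer tensors truncated toward zero,
# matching C++ integer division.
old_style = int(-7 / 2)   # -> -3
# PyTorch >= 1.7: torch.div always performs true division.
new_style = -7 / 2        # -> -3.5
```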
Preserving the behavior of serialized TorchScript
Some PyTorch models are scripted (using TorchScript) and then serialized. The serialization format contains version information, so for these models it's possible to preserve the behavior they had when they were serialized by writing a "versioned symbol" or "adapter." See the note "Versioned Symbols" for details:
// Note [Versioned Symbols]
Returning to the deprecation of torch.div(), multiple adapters were written to preserve the behavior of PyTorch models serialized prior to PyTorch 1.6. These adapters can be found at the snippet beginning:
auto div_tensor = R"SCRIPT(
Writing an adapter isn't required for every deprecation, but deprecations that impact many models may want an adapter to minimize how disruptive they are.
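The adapter mechanism can be sketched as version-gated dispatch. This is a hypothetical Python sketch with made-up names and a made-up version cutoff; the real mechanism rewrites serialized TorchScript graphs:

```python
def div_historic(a, b):
    # Pre-1.6 torch.div semantics: integer inputs divide like C++,
    # truncating toward zero.
    if isinstance(a, int) and isinstance(b, int):
        return int(a / b)
    return a / b

# Map (op name, last serialized version with the old semantics) to the
# adapter preserving that behavior. The version number here is invented.
ADAPTERS = {("div", 3): div_historic}

def resolve(op_name, current_impl, file_version):
    """Return the historic adapter when the serialized model predates the change."""
    for (name, max_version), historic in ADAPTERS.items():
        if name == op_name and file_version <= max_version:
            return historic
    return current_impl
```

Under this sketch, a model serialized at or before the cutoff version keeps truncation division, while newer models get true division.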
Current deprecations
- (floor_divide truncates instead of flooring #43874) floor_divide actually performs truncation division
- Alternatives implemented in PyTorch 1.8
- Docs and code warn in PyTorch 1.9
- TODO (PyTorch 1.10): torch.floor_divide() can be removed, needs review to see if adapter is required
- TODO (PyTorch 1.11): torch.floor_divide() to be restored (with correct behavior)
- (Deprecate torch.stft returning real-valued tensors and torch.istft accepting real-valued inputs #55948) torch.stft can produce real-valued outputs and torch.istft accepts real-valued inputs
- There is no alternative because we want people to use complex tensors, not floating point tensors mimicking complex tensors
- Docs and code warn in PyTorch 1.8, and users are required to opt-in to the deprecated behavior
- TODO (PyTorch 1.9): stop torch.stft producing real-valued outputs and stop torch.istft accepting real-valued inputs
- (Update linspace and logspace to throw an error when steps is not provided #55951) torch.linspace and torch.logspace don't require the user to specify the steps argument
- The alternative is to specify the historic default value
- Docs and code warn in PyTorch 1.8
- TODO (PyTorch 1.9): require the steps argument
- (torch.var and torch.std are not compatible with np.var and np.std #50010) torch.std and torch.var have an "unbiased" kwarg
- TODO (PyTorch 1.9): Implement a ddof alternative and warn that unbiased is deprecated (see std/var: Deprecate overloads with "unbiased" argument #55681)
- TODO (PyTorch 1.10): Remove the unbiased argument
- (Update trapz() to be trapezoid() #52606) torch.trapezoid() is preferred to torch.trapz()
- TODO: implement torch.trapezoid() and alias torch.trapz() to it
- TODO: remove the documentation for torch.trapz() (making it a "silent alias")
- (torch.meshgrid is divergent from np.meshgrid #50276) torch.meshgrid has a different (implicit) default than np.meshgrid
- TODO: add the "indexing" kwarg to torch.meshgrid and update its documentation to clarify its behavior
- TODO: warn when the indexing kwarg is not set
- TODO (Release 2): require the indexing kwarg be set
- TODO (Release 3): change the default value of the indexing kwarg
- (torch.Tensor.repeat is divergent from np.repeat #50013) torch.repeat is actually np.tile, not np.repeat
- TODO: warn that users should use tile instead of repeat
- TODO (Release 2): remove torch.repeat
- TODO (Release 3): implement torch.repeat to be compatible with np.repeat
- (Finish deprecating torch.range #55964) torch.range is divergent from Python's range builtin
- TODO: remove torch.range
- (Finish deprecation cycle for inplace view error checks #50617) finish deprecation of view/inplace error handling
  - DONE (Finish deprecation cycle for inplace view error checks #56093): make all deprecation warnings errors
- (Remove deprecated codepath for old-style autograd.Function #30696) remove code and logic for old-style custom autograd Function
  - DONE (Remove code and logic for old style custom autograd Function #57357): remove the remainder of the Python and C++ code that supports old-style custom Functions
  - TODO: (PyTorch 1.10) finish deprecation of instantiating custom autograd Function
- (Remove deprecated torch.set_deterministic and torch.is_deterministic #58096) Remove deprecated torch.set_deterministic and torch.is_deterministic
  - TODO: (PyTorch 1.10) remove torch.set_deterministic and torch.is_deterministic
- torch.chain_matmul is deprecated in favor of torch.linalg.multi_dot
- TODO: review with Python Array API if "multi_dot" is the best name for that operation
- TODO: (PyTorch 1.10) remove torch.chain_matmul
Remove deprecated torch.chain_matmul #70978
- torch.cholesky is deprecated in favor of torch.linalg.cholesky
- TODO: (PyTorch 1.10) remove torch.cholesky
Remove deprecated torch.cholesky #70979
- torch.eig is deprecated in favor of torch.linalg.eig
- TODO: (PyTorch 1.10) remove torch.eig
Remove deprecated torch.eig #70982
- torch.ger is deprecated in favor of torch.outer
- TODO: (PyTorch 1.10) remove torch.ger
- torch.lstsq is deprecated in favor of torch.linalg.lstsq
- TODO: (PyTorch 1.10) remove torch.lstsq
Remove deprecated torch.lstsq #70980
- torch.matrix_rank is deprecated in favor of torch.linalg.matrix_rank
- TODO: (PyTorch 1.10) remove torch.matrix_rank
Remove deprecated torch.matrix_rank #70981
- torch.norm is deprecated in the docs in favor of torch.linalg.vector_norm, torch.linalg.matrix_norm and torch.linalg.norm
- TODO: (PyTorch 1.10) coordinate with PyTorch/XLA and start warning when torch.norm is called
- TODO: (PyTorch 1.11) remove torch.norm
- torch.qr is deprecated in favor of torch.linalg.qr
- TODO: fix Warning with torch::nn::init::orthogonal_ with LibTorch 1.9.0 #60060
- TODO: (PyTorch 1.10) remove torch.qr
Remove deprecated torch.qr #70989
- torch.solve is deprecated in favor of torch.linalg.solve
- TODO: (PyTorch 1.10) remove torch.solve
Remove deprecated torch.solve #70986
- torch.svd is deprecated in favor of torch.linalg.svd
- TODO: (PyTorch 1.10): remove torch.svd
- torch.symeig is deprecated in favor of torch.linalg.eigh
DONE: remove torch.symeig
Remove deprecated torch.symeig #70988
- torch.matrix_power is deprecated in favor of torch.linalg.matrix_power
- TODO: ensure a warning is thrown in PyTorch 1.9
- TODO: if a warning is thrown, remove in PyTorch 1.10
- deprecate using torch.Tensor as a tensor constructor
- TODO: review current issues to develop deprecation plan
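Several of the semantic mismatches tracked above reduce to small, checkable differences. The following plain-Python sketch restates them (illustration only; these helpers are not PyTorch APIs, and the steps=100 historic default is stated from memory):

```python
import math

# floor_divide (#43874): the op truncated toward zero, while its name
# promises flooring. The two disagree for negative quotients.
trunc_q = math.trunc(-7 / 2)   # old behavior: -3
floor_q = math.floor(-7 / 2)   # behavior the name implies: -4

# linspace/logspace (#55951): the historic implicit default was steps=100;
# the new API requires steps to be passed explicitly.
def linspace(start, end, steps):
    if steps == 1:
        return [float(start)]
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]

# std/var (#50010): the boolean `unbiased` kwarg maps onto NumPy-style ddof
# (unbiased=True <-> ddof=1, unbiased=False <-> ddof=0).
def var(xs, ddof=0):
    n = len(xs)
    mean = sum(xs) / n
    return sum((x - mean) ** 2 for x in xs) / (n - ddof)

# meshgrid (#50276): 'ij' (torch's implicit default) and 'xy' (NumPy's
# default) produce transposed grids.
def meshgrid(xs, ys, indexing="ij"):
    if indexing == "ij":
        return ([[x for _ in ys] for x in xs],
                [[y for y in ys] for _ in xs])
    return ([[x for x in xs] for _ in ys],
            [[y for _ in xs] for y in ys])

# Tensor.repeat (#50013): torch's repeat tiles the whole sequence (np.tile),
# while np.repeat repeats each element in place.
def tile(xs, n):
    return xs * n

def repeat_elements(xs, n):
    return [x for x in xs for _ in range(n)]
```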
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411