WIP: add __torch_function__ API override mechanism#25629

Closed
rgommers wants to merge 35 commits into pytorch:master from Quansight:torch_function

Conversation

@rgommers
Collaborator

@rgommers rgommers commented Sep 4, 2019

This is still a draft: the Python implementation is complete, but the C++ part still needs to be added.

This mechanism allows Tensor-like objects (including Tensor subclasses) to override torch functions with their own implementations.

Closes gh-24015 (see description of that issue for more details).

For a toy example, see the DiagonalTensor class in test/test_overrides.py. The __torch_function__ method and implements decorator there are what a package with a Tensor-like class should implement. It can then override a PyTorch function with its own function decorated with @implements(torch.<funcname>).
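The registration pattern described above can be sketched without PyTorch as follows. This is a minimal illustration in the spirit of the DiagonalTensor toy example; `mock_unique` is a made-up stand-in for an overridable torch function, and the exact `__torch_function__` protocol is defined by the PR itself, not by this sketch:

```python
# Minimal sketch of the DiagonalTensor pattern from test/test_overrides.py,
# written against a stand-in function so it runs without torch.
# All names here are illustrative.
HANDLED_FUNCTIONS = {}

def implements(torch_function):
    """Register an override implementation for `torch_function`."""
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class DiagonalTensor:
    """Tensor-like object: an n x n matrix with `value` on the diagonal."""
    def __init__(self, n, value):
        self._n = n
        self._value = value

    def __torch_function__(self, func, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

def mock_unique(input):
    """Stand-in for torch.unique: checks __torch_function__ before running."""
    if hasattr(type(input), '__torch_function__'):
        result = type(input).__torch_function__(input, mock_unique, (input,))
        if result is not NotImplemented:
            return result
    return sorted(set(input))  # plain-sequence fallback

@implements(mock_unique)
def diagonal_unique(input):
    # A diagonal matrix holds at most two distinct values: 0 and the diagonal.
    return {0, input._value} if input._n > 1 else {input._value}

print(mock_unique(DiagonalTensor(5, 2)))  # {0, 2}
```

A library shipping a Tensor-like class would only need the `__torch_function__` method and its own registry; the dispatch check lives inside the overridable function.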

The current Python implementation of the override mechanism adds on the order of 1-2 µs of overhead for regular use (a small number of input parameters to the function). Benchmark results:

```
$ asv run --python=same
· Discovering benchmarks
· Running 6 total benchmarks (1 commits * 1 environments * 6 benchmarks)
[  0.00%] ·· Benchmarking existing-py_home_rgommers_anaconda3_envs_pytorch-gcc91_bin_python
[  8.33%] ··· Running (bench_overrides.TorchFunction.time_mock_broadcast_tensors_duck--)......
[ 58.33%] ··· ....TorchFunction.time_mock_broadcast_tensors_duck           958±70ns
[ 66.67%] ··· ...TorchFunction.time_mock_broadcast_tensors_torch           962±90ns
[ 75.00%] ··· ...rrides.TorchFunction.time_mock_concatenate_duck        1.59±0.09μs
[ 83.33%] ··· ...rrides.TorchFunction.time_mock_concatenate_many           93.8±8μs
[ 91.67%] ··· ...rides.TorchFunction.time_mock_concatenate_mixed        2.66±0.02μs
[100.00%] ··· ...rides.TorchFunction.time_mock_concatenate_torch           990±10ns
```

As discussed in the Performance considerations section of gh-22402, the goal is that once the dispatch mechanism is moved to C++, the overhead when inputs are Tensor instances is zero, and the overhead for Tensor-like objects is sub-microsecond.

This feature is inspired by and analogous to NumPy's __array_function__ protocol (see NumPy Enhancement Proposal 18).

This PR currently contains:

  • the Tensor.__torch_function__ method
  • tests for __torch_function__ behavior
  • overrides for unique, tensordot, lu, and broadcast_tensors (no complete coverage of even torch.functional, but enough to get a first idea)
  • benchmarks that measure the overhead for a range of scenarios, based on the airspeed velocity benchmarking framework
  • documentation for the override/dispatch mechanism as docstrings

Missing from the PR:

  • C++ implementation of the performance-critical parts of the dispatch-mechanism
  • C++ equivalent of torch_function_dispatch decorator to make a single function overridable
  • overrides for all public functions in the torch and torch.<public_submodule> namespaces
  • higher-level documentation for library authors and end users

prasunanand and others added 30 commits September 3, 2019 16:51
Remove utility code that we can simply import from NumPy for now.
Things import again and can be tested.
Note, it does not get called normally (there's just a check it exists),
the dispatcher calls the function implementation directly.
Manual check of current performance:
```
In [10]: %timeit mock_concatenate([Tensor(1), Tensor(2)])
2.58 µs ± 7.92 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [11]: %timeit mock_broadcast_tensors(Tensor(1))
1.65 µs ± 3.71 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
```

Run with ASV:
```
$ asv run --python=same --dry-run
· Discovering benchmarks
· Running 6 total benchmarks (1 commits * 1 environments * 6 benchmarks)
[  0.00%] ·· Benchmarking existing-py_home_rgommers_anaconda3_envs_pytorch_bin_python
[  8.33%] ··· Running (bench_overrides.TorchFunction.time_mock_broadcast_tensors_duck--)......
[ 58.33%] ··· ...orchFunction.time_mock_broadcast_tensors_duck            793±5ns
[ 66.67%] ··· ...rchFunction.time_mock_broadcast_tensors_torch           867±70ns
[ 75.00%] ··· ...ides.TorchFunction.time_mock_concatenate_duck         1.44±0.1μs
[ 83.33%] ··· ...ides.TorchFunction.time_mock_concatenate_many           86.2±7μs
[ 91.67%] ··· ...des.TorchFunction.time_mock_concatenate_mixed        2.33±0.01μs
[100.00%] ··· ...des.TorchFunction.time_mock_concatenate_torch            902±9ns
```

So performance is as expected for a pure-Python implementation.
That behavior of forwarding the sum() function to a sum() method is
specific to NumPy.
The removed test checked handling of >32 input parameters.
NumPy limits this to 32, with NPY_MAXARGS. PyTorch doesn't have that
limitation.
This way, the ASV benchmarks can be run on master. Individual benchmarks
will fail, but not the ASV run itself.  This is an ASV feature; you
can go back in time and run a benchmark suite on older commits.
@ezyang
Contributor

ezyang commented Sep 4, 2019

cc @cpuhrsch

@rgommers
Collaborator Author

rgommers commented Sep 4, 2019

Hi @cpuhrsch, I just read your NestedTensor RFC 0.0.2 and I guess @ezyang Cc'd you because of this Prototype Dispatch section in that document:

We define an explicit monkey_patch function within torch.nested that requires an explicit user opt-in to use this prototype. The dispatch mechanism is implemented via isinstance and used to overwrite some of the existing torch functions. The explicit opt-in is deemed necessary because a) some torch functions might slightly differ in signature b) the dispatch mechanism is very slow.

I still have to read the other NestedTensor docs and code. It'd be great to have a chat soon to make sure this __torch_function__ is useful for your design. Also I'd like to understand why you need signatures that differ, and perhaps include those as test cases here.

-----

Airspeed Velocity manages building and Python virtualenvs or conda envs by
itself, unless told otherwise (e.g. with `--python=same`).
Contributor

A feature that I've always found extremely irritating about asv XD

Collaborator Author

@rgommers rgommers Sep 4, 2019

Yep, the default is quite annoying. It makes sense for CI or running it on a dedicated server, but not for typical use when you're developing. I've never complained about it, but maybe I should go open an issue now (EDIT: I did do that).

Collaborator Author

@rgommers rgommers Sep 7, 2019

Fixed now in ASV master, asv dev does this now. (airspeed-velocity/asv#872)

@@ -0,0 +1,85 @@
{
// The version of the config file format. Do not change, unless
// you know what you are doing.
Contributor

MY EYEEEESS

@ezyang
Contributor

ezyang commented Sep 4, 2019

cc @apaszke since there are some benchmark bits here

@cpuhrsch
Contributor

cpuhrsch commented Sep 4, 2019

@rgommers: Thanks for reading the RFC!

Some torch functions that are currently implemented for torch.Tensor might require various extensions when used with NestedTensor. For example, we might want torch.narrow(input, dim, start, length) to accept tuples for dim, start, length if the input is a NestedTensor. In essence, if using a NestedTensor the semantics of an operation generalize, which might slightly change the way we want to parse or check the arguments. We could have torch.nested_narrow of course, but that'll quickly lead to a massive API surface for small one-off changes.
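To make the generalization concrete, here is a hedged sketch of how an override could accept tuples while plain inputs keep the scalar form. `NestedList` and `mock_narrow` are made-up stand-ins, not the actual NestedTensor or torch.narrow:

```python
# Hypothetical sketch of generalized narrow() semantics via an override.
# NestedList stores plain Python lists; dim is kept for signature parity
# but ignored in this 1-D illustration.
class NestedList:
    def __init__(self, tensors):
        self._tensors = list(tensors)

    def __torch_function__(self, func, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is mock_narrow:
            return self._narrow(*args[1:], **kwargs)
        return NotImplemented

    def _narrow(self, dim, start, length):
        # Generalized semantics: tuples apply entry-wise to the nested
        # lists, while scalars broadcast to every entry.
        n = len(self._tensors)
        starts = start if isinstance(start, tuple) else (start,) * n
        lengths = length if isinstance(length, tuple) else (length,) * n
        return NestedList(t[s:s + l] for t, s, l in
                          zip(self._tensors, starts, lengths))

def mock_narrow(input, dim, start, length):
    """Stand-in for torch.narrow; checks __torch_function__ before parsing."""
    if hasattr(type(input), '__torch_function__'):
        result = type(input).__torch_function__(
            input, mock_narrow, (input, dim, start, length))
        if result is not NotImplemented:
            return result
    return input[start:start + length]  # plain-sequence fallback

nested = NestedList([[1, 2, 3, 4], [5, 6, 7]])
print(mock_narrow(nested, 0, (1, 0), (2, 2))._tensors)  # [[2, 3], [5, 6]]
print(mock_narrow([1, 2, 3, 4], 0, 1, 2))               # [2, 3]
```

Because the override runs before any argument checking, the tuple-accepting variant never has to exist on the plain-Tensor path.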

Let's get together and talk about this in person whenever you want!

@@ -0,0 +1 @@
from __future__ import absolute_import, division, print_function
Contributor

In the terminal PR, we'll probably ask you to move these benchmarks to https://github.com/pytorch/benchmark since it makes it easier to run benchmarks across versions if they live out of line.

Collaborator Author

Ah, that's where they live - I was wondering why there were so few in this repo. Sounds good to move these at the end.

@ezyang
Contributor

ezyang commented Sep 4, 2019

@cpuhrsch I believe this is fine, because __torch_function__ interposes prior to argument parsing.

# We only collect arguments if they have a unique type, which ensures
# reasonable performance even with a long list of possibly overloaded
# arguments.
if (arg_type not in overloaded_types and
Contributor

I always find it interesting when an O(n) lookup is used over O(1). overloaded_types is probably always small so this should not be a problem.

Contributor

(from reading below: order matters!)
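The ordering point is the key reason a list is used rather than a set: dispatch must try subclasses before their parents, so insertion order is significant. A sketch of the collection loop, following the NumPy-style logic this PR mirrors (simplified, with illustrative names):

```python
# Sketch of the type-collection loop under discussion: a list preserves
# dispatch priority, with subclass arguments inserted ahead of their
# parent classes. Simplified and illustrative, not the PR's exact code.
def get_overloaded_types_and_args(relevant_args):
    overloaded_types = []
    overloaded_args = []
    for arg in relevant_args:
        arg_type = type(arg)
        # Only collect arguments with a unique type; the O(n) membership
        # test on a short list is cheap for the small n seen in practice.
        if (arg_type not in overloaded_types and
                hasattr(arg_type, '__torch_function__')):
            overloaded_types.append(arg_type)
            # Subclasses are dispatched to before superclasses, so insert
            # this arg before any of its parents already collected.
            index = len(overloaded_args)
            for i, old_arg in enumerate(overloaded_args):
                if issubclass(arg_type, type(old_arg)):
                    index = i
                    break
            overloaded_args.insert(index, arg)
    return overloaded_types, overloaded_args

class A:
    def __torch_function__(self, func, args=(), kwargs=None):
        return NotImplemented

class B(A):  # subclass: must be tried before A
    pass

a, b = A(), B()
types, args = get_overloaded_types_and_args([a, b, A()])
print(types)          # A collected first, then B; the duplicate A() is skipped
print(args[0] is b)   # True: the subclass instance was moved ahead of its parent
```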

# exec. This version has the advantage of giving the helper function a
# more interpretable name. Otherwise, the original function does not
# show up at all in many cases, e.g., if it's written in C++ or if the
# dispatcher gets an invalid keyword argument.
Contributor

Hmm. Does overriding __repr__ on the decorator type work?

Collaborator Author

There's a test that Tensor.__repr__ works as expected, and the DiagonalTensor test has a custom __repr__. And overriding:

```
In [1]: import torch

In [2]: torch.unique.__repr__()
Out[2]: '<function unique at 0x7fb935427378>'

In [3]: torch.unique.__repr__ = lambda: 'aahhhh'

In [4]: torch.unique.__repr__()
Out[4]: 'aahhhh'

In [5]: torch.unique is torch.unique._implementation  # checking the decorator was active
Out[5]: False
```

def implement_torch_function(
implementation, public_api, relevant_args, args, kwargs):
"""Implement a function with checks for __torch_function__ overrides.
Contributor

Though not a decorator, intriguingly.

Collaborator Author

torch_function_dispatch is the decorator; this generates the function that the decorator returns, and it is only used from within torch_function_dispatch.
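For context, the decorator/helper split can be sketched like so. This is a simplified `functools.wraps` version of the pattern (the PR uses exec to give the helper a more interpretable name), modeled on NumPy's array_function_dispatch; all function names here are illustrative:

```python
import functools

def implement_torch_function(implementation, public_api,
                             relevant_args, args, kwargs):
    """Check relevant_args for __torch_function__ overrides, else call through."""
    for arg in relevant_args:
        overload = getattr(type(arg), '__torch_function__', None)
        if overload is not None:
            result = overload(arg, public_api, args, kwargs)
            if result is not NotImplemented:
                return result
    return implementation(*args, **kwargs)

def torch_function_dispatch(dispatcher):
    """Decorator making `implementation` overridable; `dispatcher` returns
    the arguments that may carry a __torch_function__ hook."""
    def decorator(implementation):
        @functools.wraps(implementation)
        def public_api(*args, **kwargs):
            relevant_args = dispatcher(*args, **kwargs)
            return implement_torch_function(
                implementation, public_api, relevant_args, args, kwargs)
        public_api._implementation = implementation
        return public_api
    return decorator

# Demo with a stand-in for an overridable concatenation function.
def _cat_dispatcher(tensors):
    return tensors

@torch_function_dispatch(_cat_dispatcher)
def mock_cat(tensors):
    return [x for t in tensors for x in t]

class Always42:
    def __torch_function__(self, func, args, kwargs):
        return 42

print(mock_cat([[1, 2], [3]]))         # [1, 2, 3] -- no overrides present
print(mock_cat([[1, 2], Always42()]))  # 42 -- the override wins
print(mock_cat is mock_cat._implementation)  # False: public_api wraps it
```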

# (directly or with subclasses that do not override __torch_function__).
if (not overloaded_args or types == _TENSOR_ONLY or
all(type(arg).__torch_function__ is _TORCH_FUNCTION
for arg in overloaded_args)):
Contributor

Hmm. I suppose the benchmarks for the C++ will eventually show up, but the short cut doesn't seem all that short-cutty to me. In particular, get_overloaded_types_and_args will treat Tensor as an "overload", so if __torch_function__ is defined on Tensor (which it is), then you will never actually be in a situation where overloaded_args is falsish. (If this code is just meant to be semantics, as opposed to code that will be directly transliterated to C++, you can disregard this comment)

Collaborator Author

You're right, this is not meant for direct translation to C++. There we want to hook in only after the check that inputs are tensor instances. In the torch.functional Python functions there's no such check though, so the inputs are checked upfront.
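The fast-path identity check being discussed can be illustrated in pure Python; everything below is made up for the sketch and is not the eventual C++ hook:

```python
# Illustrative fast-path check: if every overloaded argument still uses
# the default Tensor.__torch_function__, the implementation can be
# called directly and dispatch skipped. All names are made up.
class Tensor:
    def __torch_function__(self, func, args=(), kwargs=None):
        # Default hook; in the real mechanism this means "use the normal path".
        return NotImplemented

_TORCH_FUNCTION = Tensor.__torch_function__

def all_default_overrides(overloaded_args):
    """True when every overloaded arg still uses Tensor's default hook."""
    return all(type(arg).__torch_function__ is _TORCH_FUNCTION
               for arg in overloaded_args)

class MyTensor(Tensor):        # subclass that does NOT override the hook
    pass

class DuckTensor:              # Tensor-like that DOES override it
    def __torch_function__(self, func, args=(), kwargs=None):
        return "overridden"

print(all_default_overrides([Tensor(), MyTensor()]))    # True: fast path
print(all_default_overrides([Tensor(), DuckTensor()]))  # False: must dispatch
```

The `is` comparison works because a subclass that doesn't override the method inherits the very same function object, so the check stays O(number of overloaded args) with no string or MRO inspection.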

@rgommers
Collaborator Author

rgommers commented Sep 5, 2019

@cpuhrsch I believe this is fine, because __torch_function__ interposes prior to argument parsing.

Indeed, there's no fundamental problem there; mismatching signatures can be handled. It sounds, though, like the signatures match but the semantics change, at least in the case of torch.narrow.

Some torch functions that are currently implemented for torch.Tensor might require various extensions when used with NestedTensor. For example, we might want torch.narrow(input, dim, start, length) to accept tuples for dim, start, length if the input is a NestedTensor. In essence, if using a NestedTensor the semantics of an operation generalize which might slightly change the way we want to parse or check the arguments. We could have torch.nested_narrow of course, but that'll quickly yield to a massive API surface for small one-off changes.

Generalized semantics sound like a good thing, as long as the semantics for inputs that also work with torch.Tensor don't change to the point that code written for both can no longer be relied on to be correct (probably obvious, but I won't easily forget the damage that numpy.matrix has done ...).

Let's get together and talk about this in person whenever you want!

Are you on the PyTorch Slack? I'm Ralf Gommers there. Or otherwise [email protected].

@rgommers
Collaborator Author

rgommers commented Sep 5, 2019

The caffe2-py2-devtoolset7-rocmrpm-centos7.5-test CI failure is due to that config using Python < 2.7.9, and exec having a bug there (https://bugs.python.org/issue21591). I have a fix or workaround for that one.

The other failure is:

```
FAIL: test_unique (__main__.TestOperators)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/onnx/test_operators.py", line 729, in test_unique
    opset_version=11)
  File "test/onnx/test_operators.py", line 67, in assertONNX
    self.assertExpected(onnx_model_pbtxt, subname)
  File "/home/rgommers/code/pytorch/test/common_utils.py", line 901, in assertExpected
    self.assertMultiLineEqual(expected, s)
AssertionError: 'ir_v[88 chars]ut: "x"\n    output: "1"\n    output: "2"\n   [1164 chars]n}\n' != 'ir_v[88 chars]ut: "0"\n    output: "1"\n    output: "2"\n   [1164 chars]n}\n'
  ir_version: 4
  producer_name: "pytorch"
  producer_version: "1.2"
  graph {
    node {
-     input: "x"
?             ^
+     input: "0"
?             ^
      output: "1"
      output: "2"
      output: "3"
      output: "4"
      op_type: "Unique"
      attribute {
        name: "axis"
        i: 0
        type: INT
      }
      attribute {
        name: "sorted"
        i: 1
        type: INT
      }
    }
    name: "torch-jit-export"
    input {
-     name: "x"
?            ^
+     name: "0"
?            ^
      type {
        tensor_type {
```
I have the impression that the new version is actually more consistent (if outputs are labelled "1", "2", etc., it makes sense for the input positional parameter to be labelled "0" rather than "x"). It's fixable by regenerating test/onnx/expect/, but I'm not sure if that's desired here?

Docstring and signature of unique are unchanged, as is introspection result:

```
>>> inspect.signature(torch.unique)
<Signature (input, sorted=True, return_inverse=False, return_counts=False, dim=None)>
```

So it looks like something subtle in how the ONNX output is generated.

@ezyang
Contributor

ezyang commented Sep 5, 2019

You can just accept the new output.

"""
def decorator(implementation):
if verify:
verify_matching_signatures(implementation, dispatcher)
Contributor

Nice, I would not have thought to implement this.

Collaborator Author

Thanks. I can't take the credit for that one though, stolen from NumPy :)

This change seems to be due to regenerating the ONNX export now
that unique() was decorated with `torch_function_dispatch`.
The same will need to be done for other expected values once we
add overrides to them.
Most py27 CI builds passed, but one failed with:
```
SyntaxError: unqualified exec is not allowed in function 'decorator' it
is a nested function (_overrides.py, line 231)
```

This is https://bugs.python.org/issue21591, which was fixed in
Python 2.7.9, looks like
`caffe2-py2-devtoolset7-rocmrpm-centos7.5-test` uses an older version.
@pytorchbot pytorchbot added the module: onnx Related to torch.onnx label Sep 5, 2019
@rgommers
Collaborator Author

@pytorchbot rebase this please

@pytorchbot
Collaborator

Sorry, I can't merge this because there are conflicts. To merge this yourself, run the commands below:

```
git fetch origin master
git fetch [email protected]:Quansight/pytorch.git torch_function
git checkout FETCH_HEAD
git merge origin/master
git push [email protected]:Quansight/pytorch.git HEAD:torch_function
```

(To learn more about this bot, see Bot commands.)

@jph00

jph00 commented Sep 26, 2019

@rgommers we're very interested in this work for stuff we're writing for fastai v2 - do you have a sense of how far this is from being merged and available in nightlies?

@ezyang
Contributor

ezyang commented Sep 27, 2019

This branch is not merged yet, so it's not available in nightlies.

@rgommers
Collaborator Author

@jph00 moving this into C++ is getting there; it's taken a little longer than expected because touching the central signature/argument parsing and codegen is tricky. But I think we have something that works now for all functions that go through the tools/autograd/gen_python_functions.py machinery (which is most of what we need).

Timeline wise I hope to update this PR next week and have it merged within 3-4 weeks.

we're very interested in this work for stuff we're writing for fastai v2

That sounds very interesting. I'd love to make sure that what we do covers your needs for fastai v2 and is in time. I'll comment on gh-22402 in more detail.

@rgommers
Collaborator Author

Continued in gh-27064, which is ready for review/testing. So closing this PR.

@rgommers rgommers closed this Oct 24, 2019

Labels

caffe2 module: docs Related to our documentation, both in docs/ and docblocks module: onnx Related to torch.onnx open source


Development

Successfully merging this pull request may close these issues.

Implement __torch_function__ to let Tensor-like objects override torch functions

6 participants