RFC-0001: Add method `__torch_function__` RFC (#3)
Conversation
Force-pushed d0ca416 → 235eb50
Force-pushed 235eb50 → 0e33cf2
I would have expected a discussion of backwards compatibility on this proposal, both with …
RFC 0001.md (outdated)

```python
class SubTensor(torch.Tensor):
    def __torch_tensor__(self, func, types, args, kwargs):
        ...
```
You don't say what the old type of `__torch_tensor__` was, so I can't tell what the difference is.
@ngoldbaum do you remember why we didn't line up the type exactly with NumPy's type in the beginning?
IIRC that was me. Not 100% sure, but I believe I removed (or never added) `types` because it was extra complexity to handle subclasses, and Tensor subclasses aren't really a thing today.
No reason AFAIK. That API got hammered out before I started working on the feature, so I don't know why the API is different. It wouldn't be terribly hard to add `types` to the signature if we wanted to do that.
I think it's safe to say we should do this sooner rather than later, even if other parts of this RFC change in their design.
As I wrote over in pytorch/pytorch#30730 (comment), my motivation for the `types` argument in NumPy's `__array_function__` was never about subclasses (which I agree are generally awful). My concern was making it as easy and idiomatic as possible for implementers of `__array_function__` to defer to unrecognized types that might also implement an operation. This is generally a good thing for the ecosystem, but people writing special methods tend not to bother, at least for builtin numeric protocols like `__add__` :).
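The deferral idiom described here can be sketched in plain Python, independent of NumPy or PyTorch. Everything below is a hypothetical illustration of the mechanism (`DuckTensor`, `__duck_function__`, `duck_add`, and `dispatch` are invented names, not library APIs): an implementation inspects the `types` tuple and returns `NotImplemented` for any type it does not recognize, so the dispatcher can try that type's own implementation next.

```python
# Hypothetical sketch of the ``types``-based deferral mechanism behind
# __array_function__ / __torch_function__. None of these names are real
# library APIs.

class DuckTensor:
    def __init__(self, data):
        self.data = list(data)

    @classmethod
    def __duck_function__(cls, func, types, args, kwargs):
        # Idiomatic deferral: if any participating type is unrecognized,
        # return NotImplemented so that that type gets a chance instead.
        if not all(issubclass(t, DuckTensor) for t in types):
            return NotImplemented
        return func.__wrapped_impl__(*args, **kwargs)


def dispatch(func, *args, **kwargs):
    # Stand-in for the library-side dispatcher: collect the distinct
    # types of protocol-aware arguments and try each implementation
    # until one does not return NotImplemented.
    types = tuple({type(a) for a in args if hasattr(type(a), "__duck_function__")})
    for t in types:
        result = t.__duck_function__(func, types, args, kwargs)
        if result is not NotImplemented:
            return result
    raise TypeError("no implementation found for {}".format(func.__name__))


def duck_add(a, b):
    return dispatch(duck_add, a, b)


# Library-internal implementation that the protocol forwards to.
duck_add.__wrapped_impl__ = lambda a, b: DuckTensor(
    x + y for x, y in zip(a.data, b.data)
)
```

This cheap "bow out when a foreign type participates" check is exactly what the `types` argument is meant to make easy and idiomatic.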
> was never about subclasses (which I agree are generally awful)

I kind of disagree. `numpy.ndarray` subclasses leave a bad taste in the mouth because of the history of `numpy.matrix` and other badly written subclasses. But in principle they make sense. Also, for PyTorch there's a lot of interest, and the use cases are solid, and letting those users write a whole `torch.Tensor`-like object is definitely not feasible in most cases.
(although doing so would break any downstream users who are out there)

Procedural note: let's put a more descriptive filename on RFCs, so we can easily tell what's what :)
RFC 0001.md (outdated)

> PyTorch `master` pointed to commit hash `957a07ffbd13d8a805f4d718e0282efc5d2bff85` at the time of writing. Any classes implementing `__torch_function__` based on the usage in this commit hash will break completely, due to the differing signature of the protocol. However, as a release hasn't been made with `__torch_function__` in it, this is a minor-impact issue.
>
> ### With NumPy
>
> As we are using a different protocol compared to NumPy (`__torch_function__` vs `__array_function__`), there is no difference to the usage for those using NumPy. We propose to delay the issue of allowing the usage of Torch tensors with NumPy functions to a separate RFC.
Sorry, I wasn't clear in my earlier comment. My question is not about allowing PyTorch tensors to be used with NumPy functions (although this is an interesting question to pose), but: if I am a NumPy user who is familiar with the `__array_function__` API, and I come to PyTorch expecting `__torch_function__` to work the same way, will I be surprised in any way?
I don't think there are more than a handful of users who have built up an intuition here, since it's so new, nor do I think including methods will be particularly surprising.
Force-pushed b41e525 → 83d42a6
RFC 0001 — `__torch_function__` for methods of the `torch.Tensor` class.md (outdated; resolved)
Force-pushed b3536f7 → 9de4faa
Umm, the new filename is better, but can we avoid putting hard-to-quote characters in the filename? :> Preferably no spaces either.
Force-pushed d3d8f15 → fa73d2b
Force-pushed fa73d2b → 686e4cf
This is very interesting, thanks for the RFC. Am I right to assume this could go a long way in helping enable something like a …
@Balandat I cannot comment more broadly, but certainly, the code example you show in the issue description should be possible, yes. With my understanding of the problem, it should all be possible.
As discussed in pytorch/pytorch#34369.
Summary: This PR adds the `types` argument to `__torch_function__` as per RFC 0001: pytorch/rfcs#3

Pull Request resolved: #34303
Differential Revision: D20474992
Pulled By: ezyang
fbshipit-source-id: cdd40b3b38f3bda4ece8812a629f5db87e919d01
Summary: This is according to pytorch/rfcs#3.

Pull Request resolved: #34369
Differential Revision: D20963929
Pulled By: ezyang
fbshipit-source-id: e618af6fd36e1dfaeda617162314ad5840f55358
Summary: According to pytorch/rfcs#3. From the goals in the RFC:

1. Support subclassing `torch.Tensor` in Python (done here)
2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them (done here)
3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor` subclasses (done in #30730)
4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods (done here)
5. Propagate subclass instances correctly also with operators, using views/slices/indexing/etc. (done here)
6. Preserve subclass attributes when using methods or views/slices/indexing (done here)
7. A way to insert code that operates on both functions and methods uniformly, so we can write a single function that overrides all operators (done here)
8. The ability to give external libraries a way to also define functions/methods that follow the `__torch_function__` protocol (will be addressed in a separate PR)

This PR makes the following changes:

1. Adds the `self` argument to the arg parser.
2. Dispatches on `self` as well if `self` is not `nullptr`.
3. Adds a `torch._C.DisableTorchFunction` context manager to disable `__torch_function__`.
4. Adds `torch::torch_function_enabled()` and `torch._C._torch_function_enabled()` to check the state of `__torch_function__`.
5. Dispatches all `torch._C.TensorBase` and `torch.Tensor` methods via `__torch_function__`.

TODO:

- [x] Sequence methods
- [x] Docs
- [x] Tests

Closes #28361. Benchmarks in #37091 (comment).

Pull Request resolved: #37091
Reviewed By: ngimel
Differential Revision: D22765678
Pulled By: ezyang
fbshipit-source-id: 53f8aa17ddb8b1108c0997f6a7aa13cb5be73de0
@hameerabbasi now that the code landed, can you make final updates to this RFC to adjust to reality? Then let's merge it.
* Document that `__torch_function__` may get methods passed to it even for non-subclasses.
* Document the `__getattr__` idiom.
* Add the double inheritance hierarchy diagram to the docs.
* Explain how to have a fallback route for things that you don't explicitly override for subclasses.
* Explain how to override single methods vs. have a global hook.
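The "single method vs. global hook" distinction can be sketched with a short example. This is a sketch assuming a reasonably recent PyTorch release; `LoggingTensor` is a hypothetical name, and the `super().__torch_function__` delegation follows the documented subclassing idiom:

```python
import torch

class LoggingTensor(torch.Tensor):
    # Global hook: every torch function *and* Tensor method invoked on a
    # LoggingTensor instance is routed through this single classmethod.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # One could log or inspect ``func`` here before delegating. The
        # default implementation preserves the subclass on the output.
        return super().__torch_function__(func, types, args, kwargs)

t = LoggingTensor([1.0, 2.0])
out = torch.add(t, t)   # function call goes through the hook
out2 = t.sum()          # method call goes through the hook too
```

Overriding a single method, by contrast, is ordinary subclassing: define e.g. `sum` directly on the subclass, and the global hook serves as the fallback for everything not explicitly overridden.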
Okay, let's finally get this in! Thanks @hameerabbasi and @ezyang
Has this RFC been implemented? Related: pytorch/pytorch#52265
Yup!