Support view_as() on NJT; allow nested int swapping #139196
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139196. Note: links to docs will display an error until the docs builds have been completed.
❌ 13 New Failures as of commit 567c5f4 with merge base 03ec250. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Caused by:
Allows for:

```python
# shape (B, j1, D)
njt1 = ...
# shape (B, j2, D)
njt2 = ...
# njt1's shape: (B, j1, D)
njt2.view_as(njt1)
```

so NJTs with different nested ints but same ragged structure can be operated on together in binary ops, cat, etc.
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
```yaml
- name: view_as(Tensor(a) self, Tensor other) -> Tensor(a)
  dispatch:
    Default:
```
I don't think you need this one at all ;)
```yaml
      self: not_implemented("view_as")
      other: non_differentiable
    AutogradNestedTensor:
      self: grad.view_as(self)
```
This will save the full self? You most likely want only .size() or some lightweight thing here?
Yep, that's right, sadly. This PR is an inefficient workaround for our current lack of factory function support for shapes that contain nested ints. I'm not sure of another way to address this without that support.
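To make the trade-off above concrete, here is a rough conceptual sketch; it is not the actual derivatives.yaml machinery, and `ViewAsSketch` is a made-up name. The backward of view_as has to reshape grad back to self's shape, and without a factory function that accepts a shape containing nested ints, holding on to `self` itself is the only way to recover that shape.

```python
import torch

class ViewAsSketch(torch.autograd.Function):
    """Conceptual stand-in for the `self: grad.view_as(self)` formula."""

    @staticmethod
    def forward(ctx, self_t, other):
        # Saves the full `self` tensor; with dense tensors, just self.size() would do.
        ctx.save_for_backward(self_t)
        return self_t.view_as(other)

    @staticmethod
    def backward(ctx, grad):
        (self_t,) = ctx.saved_tensors
        # Reshape grad back to self's shape; `other` gets no gradient (non_differentiable).
        return grad.view_as(self_t), None
```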
```python
# verify input is viewable as other's shape
if inp._ragged_idx != other._ragged_idx:
    raise RuntimeError(error_message)
torch._assert_async(torch.all(inp._offsets == other._offsets), error_message)
```
should compare CPU offsets if they're available
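A possible shape for that suggestion, sketched under assumptions: `_check_same_ragged_structure` is a hypothetical helper name, and the idea is simply to do a cheap host-side comparison when both offsets tensors already live on CPU, falling back to the async device-side assert otherwise so the CUDA path avoids a blocking sync.

```python
import torch

def _check_same_ragged_structure(inp, other, error_message):
    # Hypothetical helper (not from the PR): verify inp is viewable as other's shape.
    if inp._ragged_idx != other._ragged_idx:
        raise RuntimeError(error_message)
    inp_offsets, other_offsets = inp._offsets, other._offsets
    if inp_offsets.device.type == "cpu" and other_offsets.device.type == "cpu":
        # Both offsets are already on CPU: compare eagerly, no device sync involved.
        if not torch.equal(inp_offsets, other_offsets):
            raise RuntimeError(error_message)
    else:
        # Offsets live on an accelerator: assert asynchronously to avoid a blocking sync.
        torch._assert_async(torch.all(inp_offsets == other_offsets), error_message)
```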
Stack from ghstack (oldest at bottom):
Allows for:

```python
# shape (B, j1, D)
njt1 = ...
# shape (B, j2, D)
njt2 = ...
# njt1's shape: (B, j1, D)
njt2.view_as(njt1)
```

so NJTs with different nested ints but same ragged structure can be operated on together in binary ops, cat, etc.
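For reference, a hedged end-to-end sketch of what this enables; the construction and variable names below are illustrative rather than taken from the PR. Two separately built jagged NJTs get distinct nested ints even when their lengths match, and view_as() re-expresses one in terms of the other's nested int so binary ops see consistent shapes.

```python
import torch

# Same per-row lengths, but each construction mints its own nested int (j1 vs j2).
components = [torch.randn(2, 8), torch.randn(5, 8), torch.randn(3, 8)]
njt1 = torch.nested.nested_tensor(components, layout=torch.jagged)                       # (B, j1, D)
njt2 = torch.nested.nested_tensor([c.clone() for c in components], layout=torch.jagged)  # (B, j2, D)

# With this PR, view_as() swaps njt2's nested int for njt1's...
njt2_view = njt2.view_as(njt1)  # (B, j1, D)
# ...so the two can be used together in binary ops, cat, etc.
out = njt1 + njt2_view
```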