Support view_as() on NJT; allow nested int swapping #139196
Changes from all commits: 17b85ad, 50237e4, d5ebe05, 9ae2192, 01c6f6a, 1a0d24f, a7c868e, 567c5f4
The first changed file is `tools/autograd/derivatives.yaml` (identified by its derivative-formula syntax). The hunk adds a `view_as` entry with per-dispatch-key backward formulas:

```diff
@@ -1932,6 +1932,16 @@
   self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
   result: at::view_as_complex(self_t)
 
+- name: view_as(Tensor(a) self, Tensor other) -> Tensor(a)
+  dispatch:
+    Default:
+      # the default case will use the CompositeImplicitAutograd impl
+      self: not_implemented("view_as")
+      other: non_differentiable
+    AutogradNestedTensor:
+      self: grad.view_as(self)
```
**Collaborator** (inline, on `self: grad.view_as(self)`): This will save the full `self`? You most likely want only `.size()` or some lightweight thing here?

**Author:** Yep, that's right, sadly. This PR is an inefficient workaround for our current lack of factory-function support for shapes that have nested ints. I'm not sure of another way to address this without that support.
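For context on "shapes that have nested ints": an NJT's shape carries a symbolic nested int for the ragged dimension, so saving only `self.size()` would not be enough to rebuild the jagged structure in backward. A minimal sketch (the exact printed form of the nested int is an assumption about the repr):

```python
import torch

# Build a jagged NJT with components of lengths 2 and 4 along dim 1.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
)

# The ragged dim shows up as a nested int, and no factory function today
# accepts a shape containing one -- hence the PR saves the full `self`.
print(nt.shape)  # e.g. torch.Size([2, j1, 3])
```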
The remainder of the hunk:

```diff
+      other: non_differentiable
 
 - name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
   condition: non_differentiable
   self: where(condition, grad, 0)
```
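Together, the two dispatch keys make `view_as` differentiable on NJT, with the gradient simply viewed back as `self`. A hedged sketch of the intended end-to-end behavior, assuming a build that includes this PR:

```python
import torch

values = torch.randn(6, 3, requires_grad=True)
offsets = torch.tensor([0, 2, 6])  # two components: lengths 2 and 4

a = torch.nested.nested_tensor_from_jagged(values, offsets)
b = torch.nested.nested_tensor_from_jagged(torch.randn(6, 3), offsets)

out = a.view_as(b)             # dispatches to the jagged view_as kernel below
out.values().sum().backward()  # backward applies grad.view_as(self)
print(values.grad.shape)       # torch.Size([6, 3])
```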
The second changed file is the jagged-op registry (judging by the `register_jagged_func` decorators, this is `torch/nested/_internal/ops.py`). The hunk adds the forward implementation:
```diff
@@ -1458,6 +1458,25 @@ def get_inner_size(inner_idx):
     return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))
 
 
+@register_jagged_func([torch.ops.aten.view_as.default], "self: jt, other: jt")
+def view_as_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    other = new_kwargs.pop("other")
+
+    error_message = f"view_as(): Cannot view NJT of shape {inp.shape} as shape {other.shape}"
+
+    # verify input is viewable as other's shape
+    if inp._ragged_idx != other._ragged_idx:
+        raise RuntimeError(error_message)
+    torch._assert_async(torch.all(inp._offsets == other._offsets), error_message)
```
**Author** (self-review, inline on the assert): Should compare CPU offsets if they're available.
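A hedged sketch of what that follow-up could look like; `_cached_cpu_offsets` is a hypothetical attribute name for illustration, not actual NJT internals:

```python
import torch

def offsets_equal_on_host(inp, other):
    # Hypothetical: if host-side copies of the offsets are cached, compare them
    # eagerly so a mismatch raises immediately and no device sync is queued.
    inp_cpu = getattr(inp, "_cached_cpu_offsets", None)    # hypothetical attr
    other_cpu = getattr(other, "_cached_cpu_offsets", None)
    if inp_cpu is not None and other_cpu is not None:
        return torch.equal(inp_cpu, other_cpu)
    return None  # unknown on host; caller falls back to torch._assert_async
```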
The function then finishes:

```diff
+
+    return NestedTensor(func(inp._values, other._values), **extract_kwargs(other))
```
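Usage-wise, the checks mean `view_as` only succeeds when both NJTs share the same ragged structure. A small sketch of the expected behavior under this PR:

```python
import torch

offsets = torch.tensor([0, 2, 6])
a = torch.nested.nested_tensor_from_jagged(torch.randn(6, 3), offsets)

# Same offsets and ragged dim: viewable.
b = torch.nested.nested_tensor_from_jagged(torch.randn(6, 3), offsets)
out = a.view_as(b)

# Different offsets: the offsets check fails and the error message above is
# raised (asynchronously when the offsets live on CUDA).
c = torch.nested.nested_tensor_from_jagged(torch.randn(6, 3), torch.tensor([0, 3, 6]))
# a.view_as(c)  # view_as(): Cannot view NJT of shape ... as shape ...
```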
The hunk's trailing context is the next, pre-existing registration:

```diff
 
 
 @register_jagged_func(
     torch.ops.aten.native_layer_norm.default,
     "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
```
**Reviewer:** I don't think you need this one at all ;)