8 changes: 8 additions & 0 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -465,6 +465,14 @@ Tensor FunctionalInverses::narrow_inverse(const at::Tensor & base, const at::Ten
}
}

Tensor FunctionalInverses::view_as_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor & other) {
  if (inverse_return_mode != InverseReturnMode::NeverView) {
    return mutated_view.view_as(base);
  } else {
    return mutated_view.view_as(base).clone();
  }
}

Tensor FunctionalInverses::slice_inverse_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor & src, int64_t dim, std::optional<c10::SymInt> start, std::optional<c10::SymInt> end, c10::SymInt step) {
// slice_inverse() inverse is just slice()
if (inverse_return_mode == InverseReturnMode::NeverView) {
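For intuition, here is a minimal Python sketch of the relationship this inverse encodes, using plain dense tensors rather than the functionalization machinery itself: the forward op is `mutated_view = base.view_as(other)`, so a base-shaped result is recovered with `mutated_view.view_as(base)`, cloned when a fresh tensor rather than a view must be returned.

```python
import torch

base = torch.zeros(2, 3)
other = torch.zeros(3, 2)

view = base.view_as(other)   # forward: a view of base with other's shape
view.add_(1)                 # mutation observed through the view

# Inverse direction: reinterpret the mutated view through base's shape,
# mirroring mutated_view.view_as(base) in the C++ inverse above.
recovered = view.view_as(base)
assert torch.equal(recovered, torch.ones(2, 3))
```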
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6475,6 +6475,9 @@
  variants: method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeImplicitAutograd: view_as
    NestedTensorCPU, NestedTensorCUDA: view_as_nested

- func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
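The yaml change only adds nested-tensor kernels; dense tensors keep the existing CompositeImplicitAutograd path. A quick sketch of that unchanged dense behavior, where view_as returns a true aliasing view:

```python
import torch

x = torch.arange(6.0)
y = x.view_as(torch.empty(2, 3))  # dense path: a real view sharing x's storage
y[0, 0] = 42.0
assert x[0].item() == 42.0        # the mutation is visible through x
```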
5 changes: 5 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -1090,4 +1090,9 @@ Tensor cat_nested(const ITensorListRef& tensors, int64_t dim) {
  return cat_nested_impl(materialized, at::legacy_cat_wrap_dim(dim, materialized));
}

Tensor view_as_nested(const Tensor& self, const Tensor& other) {
  TORCH_INTERNAL_ASSERT(false, "view_as(): only implemented for jagged layout nested tensors");
  return Tensor();
}

} // namespace at::native
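Since the C++ kernel is only a stub, a hedged sketch of the expected behavior: view_as between strided-layout nested tensors should surface the assert message above, while jagged-layout NJTs are handled by the Python-side registration later in this PR.

```python
import torch

a = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])  # strided layout
b = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])
try:
    a.view_as(b)
except RuntimeError as err:
    # expected to mention: view_as(): only implemented for jagged layout nested tensors
    print(err)
```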
1 change: 1 addition & 0 deletions aten/src/ATen/templates/FunctionalInverses.h
@@ -27,6 +27,7 @@ struct FunctionalInverses {
// https://github.com/pytorch/pytorch/blob/main/torchgen/model.py#L2583-L2585
static at::Tensor chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim);
static at::Tensor narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length);
static at::Tensor view_as_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor & other);

};
}
10 changes: 10 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1932,6 +1932,16 @@
  self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
  result: at::view_as_complex(self_t)

- name: view_as(Tensor(a) self, Tensor other) -> Tensor(a)
  dispatch:
    Default:
Collaborator: I don't think you need this one at all ;)

      # the default case will use the CompositeImplicitAutograd impl
      self: not_implemented("view_as")
      other: non_differentiable
    AutogradNestedTensor:
      self: grad.view_as(self)
Collaborator: This will save the full self? You most likely want only .size() or some lightweight thing here?

Contributor Author: Yep, that's right, sadly. This PR is an inefficient workaround for our current lack of factory function support with shapes that have nested ints. I'm not sure of another way to address this without that support.

      other: non_differentiable

- name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
  condition: non_differentiable
  self: where(condition, grad, 0)
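A hedged sketch of what the AutogradNestedTensor entry provides, assuming the jagged forward path registered below is in place: the incoming gradient is mapped back onto self's nested shape via grad.view_as(self).

```python
import torch

a = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged, requires_grad=True
)
b = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)

out = a.view_as(b)
out.values().sum().backward()
print(a.grad.shape)  # same nested shape as a, per the grad.view_as(self) formula
```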
19 changes: 19 additions & 0 deletions torch/nested/_internal/ops.py
@@ -1458,6 +1458,25 @@ def get_inner_size(inner_idx):
    return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))


@register_jagged_func([torch.ops.aten.view_as.default], "self: jt, other: jt")
def view_as_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    other = new_kwargs.pop("other")

    error_message = f"view_as(): Cannot view NJT of shape {inp.shape} as shape {other.shape}"

    # verify input is viewable as other's shape
    if inp._ragged_idx != other._ragged_idx:
        raise RuntimeError(error_message)
    torch._assert_async(torch.all(inp._offsets == other._offsets), error_message)
Contributor Author: should compare CPU offsets if they're available


    return NestedTensor(func(inp._values, other._values), **extract_kwargs(other))


@register_jagged_func(
    torch.ops.aten.native_layer_norm.default,
    "input: jt_all, normalized_shape: any, weight: any?, bias: any?, eps: any",
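A hedged usage sketch of the jagged path registered above: when two NJTs share the same ragged structure, view_as views the dense values buffer and reuses other's nested metadata; mismatched offsets are expected to trip the checks.

```python
import torch

a = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)
b = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)
out = a.view_as(b)
print(out.shape)  # matches b.shape

c = torch.nested.nested_tensor(
    [torch.randn(1, 4), torch.randn(4, 4)], layout=torch.jagged
)
try:
    a.view_as(c)  # different offsets -> expected to fail the offsets check
except RuntimeError as err:
    print(err)
```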