Description
🐛 Describe the bug
Backward on the product of a nested jagged tensor n of shape (2, j1, 4) with a tensor s of shape (4,) succeeds, but backward on the product with a tensor w of shape (2, 1, 1) fails. This seems unintuitive: the forward pass works in both cases, and the two cases are fundamentally the same, in the sense that both s and w are semantically broadcast to the shape of the jagged tensor. See the example below.
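To make the "same thing" argument concrete, here is a plain-Python model of the forward broadcast for both products, using the repro's data. No PyTorch is involved: the nested lists only stand in for the jagged components and are an illustration, not how PyTorch stores them. In both cases each output component keeps the shape of the corresponding component of n.

```python
# Components of the jagged tensor n, as in the repro: x is (5, 4), y is (3, 4).
x = [[float(4 * r + c) for c in range(4)] for r in range(5)]
y = [[10.0 * (4 * r + c) for c in range(4)] for r in range(3)]
components = [x, y]

s = [1.0, 10.0, 100.0, 1000.0]  # shape (4,): one factor per column
w = [10.0, 100.0]               # shape (2, 1, 1), flattened: one factor per component

# n * s: every row of every component is scaled column-wise by s.
ns = [[[v * sk for v, sk in zip(row, s)] for row in comp] for comp in components]

# n * w: every element of component b is scaled by the scalar w[b].
nw = [[[v * wb for v in row] for row in comp] for comp, wb in zip(components, w)]

# Both results keep n's jagged structure (5 rows, then 3 rows).
print([len(c) for c in ns], [len(c) for c in nw])  # [5, 3] [5, 3]
```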
A second, minor but possibly related issue: I don't understand why ns.sum() raises a not-implemented error, while reducing through ns.backward (as shown below) works fine. Is there a recommended way to reduce nested jagged tensors?
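For context on the reduction question: a jagged-layout nested tensor keeps its components packed along the ragged dimension in a single values buffer, with an offsets vector marking where each component starts and ends (PyTorch exposes these as `ns.values()` and `ns.offsets()`). A full reduction therefore only needs the packed buffer, so something like `ns.values().sum()` may work as a stopgap, although this has not been verified against this nightly. The plain-Python sketch below models the layout; `pack_jagged`, `sum_all` and `sum_per_component` are hypothetical helpers, not PyTorch APIs.

```python
def pack_jagged(components):
    """Pack 2-D components (lists of rows) into a (values, offsets) pair."""
    values, offsets = [], [0]
    for comp in components:
        values.extend(comp)          # rows are concatenated along the ragged dim
        offsets.append(len(values))  # values[offsets[i]:offsets[i + 1]] is component i
    return values, offsets


def sum_all(values):
    """Full reduction: every element appears exactly once in the packed buffer."""
    return sum(sum(row) for row in values)


def sum_per_component(values, offsets):
    """Reduce over the ragged and last dims while keeping the batch dim."""
    return [
        sum(sum(row) for row in values[start:end])
        for start, end in zip(offsets[:-1], offsets[1:])
    ]


# Same data as the repro below.
x = [[float(4 * r + c) for c in range(4)] for r in range(5)]
y = [[10.0 * (4 * r + c) for c in range(4)] for r in range(3)]
values, offsets = pack_jagged([x, y])
print(offsets)                             # [0, 5, 8]
print(sum_all(values))                     # 850.0
print(sum_per_component(values, offsets))  # [190.0, 660.0]
```

The per-component sum keeps the batch dimension and reduces over the ragged and last dimensions, which is the kind of reduction the error message further down reports as unsupported.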
import torch
import torch.nested as nested
x = (torch.arange(20).reshape((5,4)) * 1.).requires_grad_()
y = (torch.arange(12).reshape((3,4)) * 10.).requires_grad_()
n = nested.as_nested_tensor([x, y], layout=torch.jagged)
w = torch.tensor([10., 100.], dtype=torch.float).reshape(2,1,1).requires_grad_()
s = torch.tensor([1., 10., 100., 1000.], dtype=torch.float).requires_grad_()
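ns = n * s  # s: shape (4,), broadcast along the last dimension
nw = n * w  # w: shape (2, 1, 1), one scalar per component
# ns.backward(gradient=torch.ones_like(ns)) ### Works fine
nw.backward(gradient=torch.ones_like(nw)) ### Throws the error mentioned below

The error from nw.backward: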
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[10], line 1
----> 1 nw.backward(gradient=torch.ones_like(nw))
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/_tensor.py:513, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    469 r"""Computes the gradient of current tensor wrt graph leaves.
    470
    471 The graph is differentiated using the chain rule. If the tensor is
(...)
    510 used to compute the :attr:`tensors`.
    511 """
    512 if has_torch_function_unary(self):
--> 513 return handle_torch_function(
    514 Tensor.backward,
    515 (self,),
    516 self,
    517 gradient=gradient,
    518 retain_graph=retain_graph,
    519 create_graph=create_graph,
    520 inputs=inputs,
    521 )
    522 torch.autograd.backward(
    523 self, gradient, retain_graph, create_graph, inputs=inputs
    524 )
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/overrides.py:1738, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1730 warnings.warn(
   1731 "Defining your `__torch_function__ as a plain method is deprecated and "
   1732 "will be an error in future, please define it as a classmethod.",
   1733 DeprecationWarning,
   1734 )
   1736 # Use `public_api` instead of `implementation` so __torch_function__
   1737 # implementations can do equality/identity comparisons.
-> 1738 result = torch_func_method(public_api, types, args, kwargs)
   1740 if result is not NotImplemented:
   1741 return result
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/nested/_internal/nested_tensor.py:302, in NestedTensor.__torch_function__(cls, func, types, args, kwargs)
    300 pass
    301 with torch._C.DisableTorchFunctionSubclass():
--> 302 return func(*args, **kwargs)
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/_tensor.py:522, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    512 if has_torch_function_unary(self):
    513 return handle_torch_function(
    514 Tensor.backward,
    515 (self,),
(...)
    520 inputs=inputs,
    521 )
--> 522 torch.autograd.backward(
    523 self, gradient, retain_graph, create_graph, inputs=inputs
    524 )
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/autograd/__init__.py:346, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    341 retain_graph = create_graph
    343 # The reason we repeat the same comment below is that
    344 # some Python versions print out the first line of a multi-line function
    345 # calls in the traceback and some print out the last line
--> 346 _engine_run_backward(
    347 tensors,
    348 grad_tensors_,
    349 retain_graph,
    350 create_graph,
    351 inputs,
    352 allow_unreachable=True,
    353 accumulate_grad=True,
    354 )
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/autograd/graph.py:812, in _engine_run_backward(t_outputs, *args, **kwargs)
    810 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    811 try:
--> 812 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
    813 t_outputs, *args, **kwargs
    814 ) # Calls into the C++ engine to run the backward pass
    815 finally:
    816 if attach_logging_hooks:
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/nested/_internal/nested_tensor.py:286, in NestedTensor.__torch_dispatch__(cls, func, types, args, kwargs)
    284 fn = lookup_jagged(func, *args, **kwargs)
    285 if fn is not None:
--> 286 return fn(*args, **kwargs)
    288 raise NotImplementedError(func)
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/nested/_internal/ops.py:188, in register_func.<locals>.wrapper.<locals>.get_inner.<locals>.inner(*args, **kwargs)
    186 def inner(*args, **kwargs):
    187 check_schema(schema_str, func, *args, **kwargs)
--> 188 return func(aten_op, *args, **kwargs)
File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/nested/_internal/ops.py:990, in sum_dim_IntList(func, *args, **kwargs)
    986 else:
    987 if (
    988 reduce_on_non_batch
    989 ): # invalid reduction cases: (ragged, non-batch), etc.
--> 990 raise RuntimeError(
    991 "sum(): not supported along a ragged and non-batch dimension for NestedTensor"
    992 )
    993 # reduction cases: (ragged)
    994 values_ragged_dim_outer = inp._values.permute(
    995 inp._ragged_idx - 1, # outer dimension
    996 *range(0, inp._ragged_idx - 1),
    997 *range(inp._ragged_idx, inp.dim() - 1),
    998 ) # shift reduction dimension of values backward to outer dimension
RuntimeError: sum(): not supported along a ragged and non-batch dimension for NestedTensor
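A plausible reading of this traceback (an interpretation, not confirmed from the PyTorch sources): when autograd reduces the broadcast gradient back to each operand's shape, s (shape (4,)) needs a sum over the batch and ragged dimensions, while w (shape (2, 1, 1)) needs a sum over the ragged and last dimensions with the batch dimension kept. The latter is the "ragged and non-batch" case that sum_dim_IntList rejects. The plain-Python sketch below computes both expected gradients for the all-ones upstream gradient used in the repro; `grad_wrt_s` and `grad_wrt_w` are hypothetical helpers, unrelated to any PyTorch API.

```python
# Components of n as in the repro: x is (5, 4), y is (3, 4).
x = [[float(4 * r + c) for c in range(4)] for r in range(5)]
y = [[10.0 * (4 * r + c) for c in range(4)] for r in range(3)]
components = [x, y]
# Upstream gradient, like torch.ones_like(...): same jagged shape as n.
grad_out = [[[1.0] * 4 for _ in comp] for comp in components]


def grad_wrt_s(components, grad_out):
    """Vector-Jacobian product of n * s w.r.t. s: reduce over batch and ragged dims."""
    width = len(components[0][0])
    g = [0.0] * width
    for comp, gcomp in zip(components, grad_out):
        for row, grow in zip(comp, gcomp):
            for k in range(width):
                g[k] += row[k] * grow[k]
    return g


def grad_wrt_w(components, grad_out):
    """Vector-Jacobian product of n * w w.r.t. w[b]: reduce over ragged and last dims."""
    return [
        sum(v * gv for row, grow in zip(comp, gcomp) for v, gv in zip(row, grow))
        for comp, gcomp in zip(components, grad_out)
    ]


print(grad_wrt_s(components, grad_out))  # [160.0, 195.0, 230.0, 265.0]
print(grad_wrt_w(components, grad_out))  # [190.0, 660.0]
```

If nw.backward succeeded, w.grad.flatten() would be expected to hold the second list, just as s.grad from ns.backward would be expected to hold the first.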
Versions
Latest nightly pytorch from conda:
libjpeg-turbo 2.0.0 h9bf148f_0 pytorch-nightly
pytorch 2.5.0.dev20240803 py3.10_cuda12.1_cudnn9.1.0_0 pytorch-nightly
pytorch-cuda 12.1 ha16c6d3_6 pytorch-nightly
pytorch-mutex 1.0 cuda pytorch-nightly
torchaudio 2.4.0.dev20240803 py310_cu121 pytorch-nightly
torchtriton 3.0.0+dedb7bdf33 py310 pytorch-nightly
torchvision 0.20.0.dev20240803 py310_cu121 pytorch-nightly
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ