Backward on nested jagged tensor vector multiplication throws error on 0 dim (but not last dim) #132695

@mahyarkoy

🐛 Describe the bug

Backward on the multiplication of a nested jagged tensor n of shape (2, j1, 4) with a tensor s of shape (4,) succeeds, but backward on the multiplication with a tensor w of shape (2, 1, 1) fails. This is unintuitive: the forward pass works in both cases, and the two cases are fundamentally the same, in the sense that both s and w are semantically broadcast to the shape of the jagged tensor. See the example below.

A second, minor but possibly related issue: why does ns.sum() report that it is not implemented, while reducing through ns.backward (as shown below) works fine? Is there a recommended way to reduce nested jagged tensors?

import torch
import torch.nested as nested

x = (torch.arange(20).reshape((5,4)) * 1.).requires_grad_()
y = (torch.arange(12).reshape((3,4)) * 10.).requires_grad_()
n = nested.as_nested_tensor([x, y], layout=torch.jagged)
w = torch.tensor([10., 100.], dtype=torch.float).reshape(2,1,1).requires_grad_()
s = torch.tensor([1., 10., 100., 1000.], dtype=torch.float).requires_grad_()

ns = n * s
nw = n * w
# ns.backward(gradient=torch.ones_like(ns)) ### Works fine
nw.backward(gradient=torch.ones_like(nw)) ### Throws the error mentioned below

The error from nw.backward:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[10], line 1
----> 1 nw.backward(gradient=torch.ones_like(nw))

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/_tensor.py:513, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    469 r"""Computes the gradient of current tensor wrt graph leaves.
    470
    471 The graph is differentiated using the chain rule. If the tensor is
   (...)
    510         used to compute the :attr:`tensors`.
    511 """
    512 if has_torch_function_unary(self):
--> 513     return handle_torch_function(
    514         Tensor.backward,
    515         (self,),
    516         self,
    517         gradient=gradient,
    518         retain_graph=retain_graph,
    519         create_graph=create_graph,
    520         inputs=inputs,
    521     )
    522 torch.autograd.backward(
    523     self, gradient, retain_graph, create_graph, inputs=inputs
    524 )

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/overrides.py:1738, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1730     warnings.warn(
   1731         "Defining your `__torch_function__ as a plain method is deprecated and "
   1732         "will be an error in future, please define it as a classmethod.",
   1733         DeprecationWarning,
   1734     )
   1736 # Use `public_api` instead of `implementation` so __torch_function__
   1737 # implementations can do equality/identity comparisons.
-> 1738 result = torch_func_method(public_api, types, args, kwargs)
   1740 if result is not NotImplemented:
   1741     return result

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/nested/_internal/nested_tensor.py:302, in NestedTensor.__torch_function__(cls, func, types, args, kwargs)
    300     pass
    301 with torch._C.DisableTorchFunctionSubclass():
--> 302     return func(*args, **kwargs)

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/_tensor.py:522, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    512 if has_torch_function_unary(self):
    513     return handle_torch_function(
    514         Tensor.backward,
    515         (self,),
   (...)
    520         inputs=inputs,
    521     )
--> 522 torch.autograd.backward(
    523     self, gradient, retain_graph, create_graph, inputs=inputs
    524 )

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/autograd/__init__.py:346, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    341     retain_graph = create_graph
    343 # The reason we repeat the same comment below is that
    344 # some Python versions print out the first line of a multi-line function
    345 # calls in the traceback and some print out the last line
--> 346 _engine_run_backward(
    347     tensors,
    348     grad_tensors_,
    349     retain_graph,
    350     create_graph,
    351     inputs,
    352     allow_unreachable=True,
    353     accumulate_grad=True,
    354 )

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/autograd/graph.py:812, in _engine_run_backward(t_outputs, *args, **kwargs)
    810     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    811 try:
--> 812     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    813         t_outputs, *args, **kwargs
    814     )  # Calls into the C++ engine to run the backward pass
    815 finally:
    816     if attach_logging_hooks:

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/nested/_internal/nested_tensor.py:286, in NestedTensor.__torch_dispatch__(cls, func, types, args, kwargs)
    284 fn = lookup_jagged(func, *args, **kwargs)
    285 if fn is not None:
--> 286     return fn(*args, **kwargs)
    288 raise NotImplementedError(func)

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/nested/_internal/ops.py:188, in register_func.<locals>.wrapper.<locals>.get_inner.<locals>.inner(*args, **kwargs)
    186 def inner(*args, **kwargs):
    187     check_schema(schema_str, func, *args, **kwargs)
--> 188     return func(aten_op, *args, **kwargs)

File ~/miniconda3/envs/sgspy310n/lib/python3.10/site-packages/torch/nested/_internal/ops.py:990, in sum_dim_IntList(func, *args, **kwargs)
    986 else:
    987     if (
    988         reduce_on_non_batch
    989     ):  # invalid reduction cases: (ragged, non-batch), etc.
--> 990         raise RuntimeError(
    991             "sum(): not supported along a ragged and non-batch dimension for NestedTensor"
    992         )
    993     # reduction cases: (ragged)
    994     values_ragged_dim_outer = inp._values.permute(
    995         inp._ragged_idx - 1,  # outer dimension
    996         *range(0, inp._ragged_idx - 1),
    997         *range(inp._ragged_idx, inp.dim() - 1),
    998     )  # shift reduction dimension of values backward to outer dimension

RuntimeError: sum(): not supported along a ragged and non-batch dimension for NestedTensor

Versions

Latest nightly PyTorch from conda:

libjpeg-turbo             2.0.0                h9bf148f_0    pytorch-nightly
pytorch                   2.5.0.dev20240803 py3.10_cuda12.1_cudnn9.1.0_0    pytorch-nightly
pytorch-cuda              12.1                 ha16c6d3_6    pytorch-nightly
pytorch-mutex             1.0                        cuda    pytorch-nightly
torchaudio                2.4.0.dev20240803     py310_cu121    pytorch-nightly
torchtriton               3.0.0+dedb7bdf33           py310    pytorch-nightly
torchvision               0.20.0.dev20240803     py310_cu121    pytorch-nightly

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ

Labels: module: nestedtensor, triaged
