-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Closed
Labels
module: nestedtensor (NestedTensor tag; see issue #25032), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🚀 The feature, motivation and pitch
Implement embedding_dense_backward for nested jagged tensors. We can land this in torch/nested/_internal/ops.py using the register_jagged_func API.
Alternatives
No response
Additional context
Here's how I have this partially implemented in my project:
# jagged.py
import torch
from torch.nested._internal.ops import check_ragged_dim_same, normalize_function, register_jagged_func
@register_jagged_func(
torch.ops.aten.embedding_dense_backward.default,
"self: jt, grad_output: jt, num_weights: any, padding_idx: any, scale_grad_by_freq: any",
)
def embedding_dense_backward(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
indices: NestedTensor = new_kwargs.pop("indices")
num_weights: int = new_kwargs.pop("num_weights")
grad_output: NestedTensor = new_kwargs.pop("grad_output")
if new_kwargs["padding_idx"] != -1 or new_kwargs["scale_grad_by_freq"]:
raise NotImplementedError("Haven't done this yet")
check_ragged_dim_same(func, indices, "self", grad_output, "grad_output")
out = torch.zeros(num_weights, grad_output.size(-1))
src = grad_output._values
indices = indices._values.long().unsqueeze(1).expand(-1, src.size(1))
out.scatter_add_(dim=0, index=indices, src=src)
return outAnd here's the unit test I wrote for this:
# test_jagged.py
import torch
from torch.nested._internal.nested_tensor import jagged_from_list
def test_Embedding_parity() -> None:
    """Check that embedding weight gradients through a jagged NestedTensor match the dense path.

    Importing the local ``jagged`` module registers the jagged
    ``embedding_dense_backward`` implementation as an import side effect.
    """
    # ``import .jagged`` is a SyntaxError; relative imports need the ``from`` form.
    # NOTE(review): assumes this test lives in the same package as jagged.py — adjust if not.
    from . import jagged  # noqa: F401

    # Index 0 appears three times and index 1 once, so those rows get grads 3 and 1.
    expected = (
        torch.tensor([3, 1, 0, 0, 0], dtype=torch.get_default_dtype())
        .unsqueeze(1)
        .repeat(1, 6)
    )
    emb = torch.nn.Embedding(5, 6)

    # Dense reference gradient.
    dense_ix = torch.tensor([0, 0, 0, 1], dtype=torch.int32)
    emb(dense_ix).sum().backward()
    g1 = emb.weight.grad
    emb.weight.grad = None

    # Same indices wrapped as a single-component jagged NestedTensor.
    nt_ix, _ = jagged_from_list([dense_ix], offsets=None, dtype=torch.int32)
    emb(nt_ix).values().sum().backward()
    g2 = emb.weight.grad
    emb.weight.grad = None

    torch.testing.assert_close(g1, expected)
    torch.testing.assert_close(g2, expected)
    torch.testing.assert_close(g1, g2)

    # Genuinely ragged input (components of different lengths) must match the
    # dense op applied to the concatenated indices.
    parts = [
        torch.tensor([0, 0, 0], dtype=torch.int32),
        torch.tensor([1, 4], dtype=torch.int32),
    ]
    emb(torch.cat(parts)).sum().backward()
    g_dense_ragged = emb.weight.grad
    emb.weight.grad = None
    ragged_ix, _ = jagged_from_list(parts, offsets=None, dtype=torch.int32)
    emb(ragged_ix).values().sum().backward()
    torch.testing.assert_close(emb.weight.grad, g_dense_ragged)
# cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ
davidberard98
Metadata
Metadata
Assignees
Labels
module: nestedtensor (NestedTensor tag; see issue #25032), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)