NJT Embedding backward #138352

@schmidt-jake

Description

🚀 The feature, motivation and pitch

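Implement embedding_dense_backward for nested jagged tensors (NJTs), so that gradients flow back to the nn.Embedding weight when the embedding's input is an NJT. We can land this in torch/nested/_internal/ops.py using the register_jagged_func API.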
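For reference, the dense op being ported computes the weight gradient as a scatter-add: row k of the result is the sum of every grad_output row whose lookup index was k. The following is a minimal sketch on plain dense tensors, assuming the defaults padding_idx=-1 and scale_grad_by_freq=False. It checks ATen's kernel against Tensor.index_add_:

```python
import torch

torch.manual_seed(0)
num_weights, dim = 5, 6
indices = torch.tensor([0, 0, 0, 1])             # four lookups
grad_output = torch.randn(indices.numel(), dim)  # one gradient row per lookup

# ATen's dense kernel: (grad_output, indices, num_weights, padding_idx, scale_grad_by_freq)
ref = torch.ops.aten.embedding_dense_backward.default(grad_output, indices, num_weights, -1, False)

# Equivalent scatter-add: row k accumulates every grad row whose index is k.
manual = torch.zeros(num_weights, dim, dtype=grad_output.dtype)
manual.index_add_(0, indices, grad_output)

torch.testing.assert_close(ref, manual)
```

An NJT keeps all of its components in one flat _values buffer, so the same scatter-add can run on that buffer directly. That is the idea behind the partial implementation under Additional context.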

Alternatives

No response

Additional context

Here's how I have this partially implemented in my project (padding_idx and scale_grad_by_freq are not handled yet):

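# jagged.py
import torch
from torch.fx.operator_schemas import normalize_function
from torch.nested._internal.nested_tensor import NestedTensor
from torch.nested._internal.ops import check_ragged_dim_same, register_jagged_func

# ATen schema: embedding_dense_backward(grad_output, indices, num_weights, padding_idx, scale_grad_by_freq)
@register_jagged_func(
    torch.ops.aten.embedding_dense_backward.default,
    "grad_output: jt, indices: jt, num_weights: any, padding_idx: any, scale_grad_by_freq: any",
)
def embedding_dense_backward(func, *args, **kwargs):
    # Map positional/keyword args onto the schema's argument names.
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    indices: NestedTensor = new_kwargs.pop("indices")
    num_weights: int = new_kwargs.pop("num_weights")
    grad_output: NestedTensor = new_kwargs.pop("grad_output")
    if new_kwargs["padding_idx"] != -1 or new_kwargs["scale_grad_by_freq"]:
        raise NotImplementedError(
            "embedding_dense_backward: padding_idx and scale_grad_by_freq are not yet supported for NJT"
        )

    # Both NJTs must share the same ragged structure.
    check_ragged_dim_same(func, grad_output, "grad_output", indices, "indices")
    src = grad_output._values  # (total_len, embedding_dim)
    # The weight gradient is dense; allocate it with grad_output's dtype and device.
    out = src.new_zeros(num_weights, src.size(-1))
    # Row k of out sums every grad row whose lookup index was k.
    index = indices._values.long().unsqueeze(1).expand(-1, src.size(1))
    out.scatter_add_(dim=0, index=index, src=src)
    return out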
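The NotImplementedError branch could plausibly be filled in on the same flat buffers. This is a sketch, assuming the NJT result should match the dense kernel applied to the concatenated values, with scale_grad_by_freq counts taken over the whole batch. flat_embedding_dense_backward is a hypothetical helper, not a PyTorch API. It is checked here against the ATen kernel on dense tensors:

```python
import torch


def flat_embedding_dense_backward(
    grad_values: torch.Tensor,   # (total, dim), e.g. grad_output._values of an NJT
    index_values: torch.Tensor,  # (total,), e.g. indices._values of an NJT
    num_weights: int,
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
) -> torch.Tensor:
    """Hypothetical helper: dense embedding backward on flat NJT buffers."""
    idx = index_values.long()
    grad = grad_values
    if scale_grad_by_freq:
        # Divide each grad row by how often its index occurs across the whole batch.
        counts = torch.bincount(idx, minlength=num_weights).to(grad.dtype)
        grad = grad / counts[idx].unsqueeze(1)
    # Lookups that hit padding_idx contribute nothing (-1 matches no valid index).
    keep = idx != padding_idx
    out = grad.new_zeros(num_weights, grad.size(1))
    out.index_add_(0, idx[keep], grad[keep])
    return out


# Compare against ATen's dense kernel for several option combinations.
torch.manual_seed(0)
idx = torch.tensor([0, 0, 0, 1, 3, 3], dtype=torch.int32)
grad = torch.randn(idx.numel(), 4)
for pad, freq in [(-1, False), (-1, True), (3, False), (0, True)]:
    ref = torch.ops.aten.embedding_dense_backward.default(grad, idx, 5, pad, freq)
    torch.testing.assert_close(flat_embedding_dense_backward(grad, idx, 5, pad, freq), ref)
```

Inside the registered function this would be called on grad_output._values and indices._values. Whether counting frequencies across all components is the semantics NJT users expect is an open design question.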

And here's the unit test I wrote for this:

# test_jagged.py
import torch
from torch.nested._internal.nested_tensor import jagged_from_list

def test_Embedding_parity() -> None:
    from . import jagged  # noqa: F401  (importing registers the NJT backward)

    # Summing over indices [0, 0, 0, 1]: row 0 gets 3, row 1 gets 1, the rest 0.
    expected = torch.tensor([3, 1, 0, 0, 0], dtype=torch.get_default_dtype()).unsqueeze(1).repeat(1, 6)

    emb = torch.nn.Embedding(5, 6)
    ix = torch.tensor([0, 0, 0, 1], dtype=torch.int32)

    # Dense reference gradient
    emb(ix).sum().backward()
    g1 = emb.weight.grad
    emb.weight.grad = None

    # Same lookups as a single-component NJT
    nt_ix, _ = jagged_from_list([ix], offsets=None, dtype=torch.int32)
    emb(nt_ix).values().sum().backward()
    g2 = emb.weight.grad

    torch.testing.assert_close(g1, expected)
    torch.testing.assert_close(g2, expected)
    torch.testing.assert_close(g1, g2)
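The test above uses a single component, so it doesn't exercise any raggedness. One property a multi-component test could check is sketched here with plain dense tensors, so no NJT registration is needed. The weight gradient is a sum over lookups, so backpropagating each ragged sequence separately should give the same gradient as one pass over their concatenation, which is what the flat _values buffer holds:

```python
import torch

torch.manual_seed(0)
emb = torch.nn.Embedding(5, 6)
seqs = [torch.tensor([0, 0, 1]), torch.tensor([3, 0])]  # ragged batch, lengths 3 and 2

# Per-sequence backward passes; autograd accumulates into emb.weight.grad.
for s in seqs:
    emb(s).sum().backward()
per_seq = emb.weight.grad.clone()
emb.weight.grad = None

# One pass over the concatenated buffer.
emb(torch.cat(seqs)).sum().backward()
flat = emb.weight.grad

torch.testing.assert_close(per_seq, flat)
```

An NJT version of the test would build the batch from seqs with jagged_from_list and compare its gradient against per_seq.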

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ

Metadata

Assignees: No one assigned

Labels: module: nestedtensor, triaged
