
Allow any single non-batch dimension to be ragged for NJT #137512

@jbschlosser

Description

Currently, NJT is restricted to representing shapes of the form (B, J*, D_0, ..., D_N), where only J* is allowed to be ragged. For this shape, J* is always ragged wrt the batch dim B, and there is no ambiguity about this. On the other hand, consider the shape (B, D, J*, E_0, ..., E_N): is J* ragged wrt B only, or wrt both B and D (i.e. a different ragged length per (B, D) pair)? This is ambiguous from the shape alone.
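For reference, a minimal sketch of the currently supported form using torch.nested.nested_tensor with layout=torch.jagged (the sizes are arbitrary, and the exact nested-int symbol in the printed shape may differ):

```python
import torch

# Three variable-length sequences with feature dim D=8; the ragged dim must
# currently sit immediately after the batch dim, giving shape (B, J*, D).
nt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8), torch.randn(2, 8)],
    layout=torch.jagged,
)
print(nt.shape)      # e.g. torch.Size([3, j1, 8]) -- B=3, ragged J*, D=8
print(nt.offsets())  # tensor([0, 3, 8, 10]) -- cumulative ragged lengths, size B + 1
```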

As an exception to the restriction above, note that NJT today provides barebones transpose() support on the ragged dimension to fit with SDPA's API, where inputs are commonly of shape (B, num_heads, seq*, embedding_dim) and seq* is ragged. For this case, we know seq* should be ragged with respect to B, which establishes a precedent for interpreting the above example as J* being ragged wrt B only.
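For illustration, a sketch of that transpose() path following the usual SDPA head-splitting pattern (num_heads and head_dim are arbitrary here):

```python
import torch

B, num_heads, head_dim = 2, 4, 16
# Packed projection output of shape (B, seq*, num_heads * head_dim).
nt = torch.nested.nested_tensor(
    [torch.randn(5, num_heads * head_dim), torch.randn(7, num_heads * head_dim)],
    layout=torch.jagged,
)
# Split heads and move them ahead of the ragged dim: (B, num_heads, seq*, head_dim).
q = nt.unflatten(-1, (num_heads, head_dim)).transpose(1, 2)
print(q.shape)        # e.g. torch.Size([2, 4, j1, 16]) -- seq* still ragged wrt B
print(q._ragged_idx)  # 2: the ragged dim is no longer adjacent to the batch dim
```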

Q: Is it possible to expand the supported set of shapes for the jagged dimension in an incremental way? i.e.

  1. Allow for a single ragged dim to be ragged wrt B only and appear anywhere in the shape, continuing the precedent set by our minimal transpose() support on the ragged dim
  2. Later on, allow for the single ragged dim to be ragged wrt all preceding dims

I think this sort of incremental support is possible, with the ambiguity being resolved by considering the cardinality of the metadata associated with the nested int. For example:

  • (B, D, j0) with j0 wrt B only: associated offsets are of shape B + 1, associated lengths are of shape B
  • (B, D, j0) with j0 wrt B and D: associated offsets are of shape B*D + 1, associated lengths are of shape B*D

That is, we can maintain some metadata (or query size info from the associated offsets / lengths) that resolves the ambiguity of how to interpret a given shape.
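A small plain-tensor illustration of how the metadata cardinality disambiguates the two interpretations (the lengths are made up):

```python
import torch

B, D = 2, 3

# (B, D, j0) with j0 ragged wrt B only: one ragged length per batch element.
lengths_b = torch.tensor([4, 6])                                   # shape (B,)
offsets_b = torch.cat([torch.zeros(1, dtype=torch.int64),
                       lengths_b.cumsum(0)])                       # shape (B + 1,)

# (B, D, j0) with j0 ragged wrt both B and D: one ragged length per (B, D) pair.
lengths_bd = torch.tensor([4, 6, 5, 3, 2, 7])                      # shape (B * D,)
offsets_bd = torch.cat([torch.zeros(1, dtype=torch.int64),
                        lengths_bd.cumsum(0)])                     # shape (B * D + 1,)

# The size of the offsets / lengths metadata is what tells the two apart.
assert offsets_b.numel() == B + 1
assert offsets_bd.numel() == B * D + 1
```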

This issue takes the opinionated stance that (1) should be adopted, deferring the more flexible (2) for now, as there are existing kernels within e.g. fbgemm that operate on shapes like those from (1). That is, we should relax the "ragged next to batch dim" restriction and allow for any single non-batch dimension to be ragged with respect to the batch dim. In detail, this looks like:

  • Defining a ragged dim to always be ragged with respect to the batch dimension, no matter where it appears in the shape
  • Relaxing the shape checks during NJT construction (see the hypothetical sketch after this list)
  • Fully taking _ragged_idx into account throughout our op implementations
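As a hypothetical sketch of what relaxed construction could look like, assuming torch.nested.nested_tensor_from_jagged's jagged_dim argument denotes the ragged dim's position in the NJT's shape (with values laid out as described in the paragraphs below):

```python
import torch

B, D = 3, 4
lengths = torch.tensor([2, 5, 3])                         # per-batch ragged lengths, shape (B,)
offsets = torch.cat([torch.zeros(1, dtype=torch.int64),
                     lengths.cumsum(0)])                  # shape (B + 1,)
values = torch.randn(D, int(lengths.sum()))               # packed over the batch at dim 1

# Hypothetical: build an NJT of shape (B, D, J*) with J* ragged wrt B only.
nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets, jagged_dim=2)
print(nt.shape)        # e.g. torch.Size([3, 4, j1])
print(nt._ragged_idx)  # 2
```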

An NJT's values component is of dimension nt.dim() - 1 and contains a packed dimension of size sum(J*), i.e. the sum of the ragged lengths. For example: an NJT of shape (B, J*, D) has a values component of shape (sum(J*), D), where values.shape[0] is the packed dimension.

In generalizing the shapes that NJT can represent, the location of this packed dimension within the values component should correlate directly with the NJT's _ragged_idx. For example, to represent shape (B, D, J*), the values component should be of shape (D, sum(J*)). This allows for generality across non-batch dimensions. Note that the packed dimension is always "packed over the batch" no matter where it appears in the shape, so the offsets component should always be of size B + 1.
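A plain-tensor sketch contrasting the two layouts (lengths are made up):

```python
import torch

B, D = 3, 4
lengths = torch.tensor([2, 5, 3])                          # ragged lengths per batch element, shape (B,)
offsets = torch.cat([torch.zeros(1, dtype=torch.int64),
                     lengths.cumsum(0)])                   # always shape (B + 1,)
total = int(lengths.sum())                                 # sum(J*) == 10

# NJT shape (B, J*, D), _ragged_idx == 1: packed dim is values dim 0.
values_standard = torch.randn(total, D)                    # shape (sum(J*), D)

# NJT shape (B, D, J*), _ragged_idx == 2: packed dim is values dim 1.
values_generalized = torch.randn(D, total)                 # shape (D, sum(J*))

# In both cases the packing is over the batch, so offsets keeps size B + 1.
```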

This also opens us up to future work involving multiple ragged dims, where each is ragged wrt the batch dim.

cc @cpuhrsch @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ

Metadata

Labels: module: nestedtensor, triaged
