Description
Currently, NJT is restricted to representing shapes of the form (B, J*, D_0, ..., D_N), where only J* is allowed to be ragged. For this shape, J* is always ragged wrt the batch dim B, and there is no ambiguity about this. On the other hand, consider the shape (B, D, J*, E_0, ..., E_N). Is J* ragged wrt B only, or wrt both B and D (i.e. with a different ragged length per (B, D) pair)? This information alone does not say.
As an exception to the restriction above, note that NJT today provides barebones transpose() support on the ragged dimension to fit with SDPA's API, where inputs are commonly of shape (B, num_heads, seq*, embedding_dim) and seq* is ragged. For this case, we know seq* should be ragged with respect to B. So this establishes a precedent for the above example where J* is ragged wrt B only.
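To make the precedent concrete, the sketch below models only the shape bookkeeping of transpose() on an NJT: a tuple shape with the ragged dim marked "J*" plus an integer ragged index. The helper transpose_shape is hypothetical and is not PyTorch's implementation; refusing to move the batch dim is an assumption of this sketch, motivated by the ragged dim being defined relative to B.

```python
# Hypothetical helper: shape bookkeeping for transpose() on an NJT-like shape.
# The ragged dim is marked with the string "J*"; no tensor data is involved.

def transpose_shape(shape, ragged_idx, dim0, dim1):
    """Return (new_shape, new_ragged_idx) after swapping dim0 and dim1."""
    ndim = len(shape)
    dim0 %= ndim  # normalize negative dims
    dim1 %= ndim
    if 0 in (dim0, dim1):
        # Assumption of this sketch: the batch dim anchors the raggedness,
        # so it is not allowed to move.
        raise ValueError("transposing the batch dim is not supported")
    new_shape = list(shape)
    new_shape[dim0], new_shape[dim1] = new_shape[dim1], new_shape[dim0]
    # The ragged index follows the ragged dim to its new position.
    if ragged_idx == dim0:
        new_ragged_idx = dim1
    elif ragged_idx == dim1:
        new_ragged_idx = dim0
    else:
        new_ragged_idx = ragged_idx
    return tuple(new_shape), new_ragged_idx


# SDPA-style layout: (B, seq*, num_heads, embedding_dim) -> (B, num_heads, seq*, embedding_dim)
shape, idx = transpose_shape((4, "J*", 8, 64), ragged_idx=1, dim0=1, dim1=2)
print(shape, idx)  # (4, 8, 'J*', 64) 2
```

After the transpose, seq* sits at index 2 but is still ragged wrt B only, which is exactly the case this issue proposes to generalize.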
Q: Is it possible to expand the set of supported shapes for the jagged dimension incrementally? i.e.
1. Allow for a single ragged dim to be wrt B only and appear anywhere in the shape, continuing the precedent set by our minimal transpose() support on the ragged dim
2. Later on, allow for the single ragged dim to be wrt all previous dims
I think this sort of incremental support is possible, with the ambiguity being resolved by considering the cardinality of the metadata associated with the nested int. For example:
- (B, D, j0) with j0 wrt B only: associated offsets are of shape B + 1, associated lengths are of shape B
- (B, D, j0) with j0 wrt B and D: associated offsets are of shape B*D + 1, associated lengths are of shape B*D
That is, we can maintain some metadata (or query the size of the associated offsets / lengths) to resolve the ambiguity in how a given shape should be interpreted.
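A minimal sketch of this disambiguation rule, assuming offsets are stored as a flat sequence whose size is (number of ragged rows) + 1. The helper ragged_wrt_dims is hypothetical; it reports how many leading dims the ragged dim is ragged wrt by matching the offsets size against products of the leading dims.

```python
from math import prod


def ragged_wrt_dims(shape, ragged_idx, num_offsets):
    """Hypothetical: infer which leading dims a ragged dim is ragged wrt.

    shape:       outer shape with the ragged dim marked as "J*"
    ragged_idx:  position of the ragged dim in `shape`
    num_offsets: size of the associated offsets (number of ragged rows + 1)

    Returns k such that the ragged dim is ragged wrt shape[:k].
    """
    # Try k = 1 (batch only) first, then progressively more leading dims.
    for k in range(1, ragged_idx + 1):
        if prod(shape[:k]) + 1 == num_offsets:
            return k
    raise ValueError("offsets size matches no prefix of the shape")


B, D = 3, 5
print(ragged_wrt_dims((B, D, "J*"), 2, B + 1))      # 1 -> ragged wrt B only
print(ragged_wrt_dims((B, D, "J*"), 2, B * D + 1))  # 2 -> ragged wrt B and D
```

When D == 1 the two interpretations coincide (B*D + 1 == B + 1), so the sketch simply reports the batch-only reading.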
This issue takes the opinionated stance that (1) should be adopted, deferring the more flexible (2) for now, as there are existing kernels within e.g. fbgemm that operate on shapes like those from (1). That is, we should relax the "ragged next to batch dim" restriction and allow any single non-batch dimension to be ragged with respect to the batch dim. Concretely, this involves:
- Defining a ragged dim to always be ragged with respect to the batch dimension, no matter where it appears in the shape
- Relaxing the shape checks during NJT construction
- Fully taking _ragged_idx into account throughout our op implementations
An NJT's values component is of dimension nt.dim() - 1 and contains a packed dimension of size sum(J*), i.e. the sum of the ragged lengths. For example, an NJT of shape (B, J*, D) has a values component of shape (sum(J*), D), where values.shape[0] is the packed dimension.
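The (B, J*, D) case can be illustrated with plain Python lists standing in for tensors. pack_jagged is a hypothetical helper, not an NJT API: it concatenates the per-batch components along the packed dim and records the running sum of ragged lengths as offsets.

```python
def pack_jagged(components):
    """Hypothetical: pack B components of shape (J_i, D) into values + offsets.

    Nested Python lists stand in for tensors. Returns values of shape
    (sum(J_i), D) and offsets with B + 1 entries.
    """
    values, offsets = [], [0]
    for comp in components:
        values.extend(comp)                      # concatenate along packed dim 0
        offsets.append(offsets[-1] + len(comp))  # running sum of ragged lengths
    return values, offsets


# B = 3 components with ragged lengths J* = 2, 3, 1 and D = 2
comps = [
    [[1, 1], [2, 2]],
    [[3, 3], [4, 4], [5, 5]],
    [[6, 6]],
]
values, offsets = pack_jagged(comps)
print(len(values), len(values[0]))  # 6 2  -> values shape (sum(J*), D)
print(offsets)                      # [0, 2, 5, 6]  -> B + 1 entries

# Component i is recovered by slicing the packed dim with consecutive offsets.
print(values[offsets[1]:offsets[2]])  # [[3, 3], [4, 4], [5, 5]]
```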
In generalizing the shapes NJT can represent, the location of this packed dimension within the values component should correspond directly to the NJT's _ragged_idx: since values has no batch dim, the packed dimension sits at index _ragged_idx - 1. For example, to represent shape (B, D, J*) (_ragged_idx = 2), the values component should be of shape (D, sum(J*)). This allows for generality across non-batch dimensions. Note that the packed dimension is always "packed over the batch", no matter where it appears. Thus, the offsets component should always be of size B + 1.
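Under the layout proposed here, a relaxed construction check and the values-to-NJT shape mapping can be sketched as below. njt_outer_shape is a hypothetical helper that works on shapes only; it assumes offsets are always over the batch (size B + 1) and that the packed dim sits at _ragged_idx - 1 within values.

```python
def njt_outer_shape(values_shape, offsets, ragged_idx):
    """Hypothetical relaxed construction check: map a values shape to an NJT shape.

    values_shape: shape of the values component (nt.dim() - 1 dims)
    offsets:      B + 1 non-decreasing ints starting at 0 (always over the batch)
    ragged_idx:   position of the ragged dim in the outer NJT shape
    Returns the outer shape with the ragged dim marked as "J*".
    """
    nt_dim = len(values_shape) + 1
    # Relaxed check: any non-batch dim may be ragged, not only dim 1.
    if not 1 <= ragged_idx < nt_dim:
        raise ValueError(f"ragged_idx must be in [1, {nt_dim - 1}], got {ragged_idx}")
    if not offsets or offsets[0] != 0:
        raise ValueError("offsets must be non-empty and start at 0")
    if any(a > b for a, b in zip(offsets, offsets[1:])):
        raise ValueError("offsets must be non-decreasing")
    # values has no batch dim, so the packed dim sits one index to the left.
    packed_dim = ragged_idx - 1
    if offsets[-1] != values_shape[packed_dim]:
        raise ValueError("values.shape[ragged_idx - 1] must equal sum(J*)")
    outer = list(values_shape)
    outer[packed_dim] = "J*"
    return (len(offsets) - 1, *outer)  # B = len(offsets) - 1


offsets = [0, 2, 5, 6]  # B = 3, ragged lengths J* = 2, 3, 1
print(njt_outer_shape((6, 4), offsets, ragged_idx=1))  # (3, 'J*', 4)
print(njt_outer_shape((4, 6), offsets, ragged_idx=2))  # (3, 4, 'J*')
```

The same offsets describe both layouts; only the position of the packed dim in values, and hence _ragged_idx, differs.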
This also opens us up to future work involving multiple ragged dims, where each is ragged wrt the batch dim.
cc @cpuhrsch @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ