Conversation

@stas00
Collaborator
@stas00 stas00 commented Jul 18, 2025

This PR adds `TiledFusedLogitsLoss` for an efficient fused logits+loss computation. This version pre-calculates the gradients in `forward`, avoiding recomputation in `backward` (similar to the Liger-Kernel implementation).
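For context, here is a minimal, self-contained sketch of the grads-in-forward idea. All names below are illustrative assumptions, not the merged `TiledFusedLogitsLoss` API: each logits tile is run under `enable_grad`, the per-tile gradient w.r.t. the logits is computed immediately and stashed, and `backward` only rescales the stored gradients instead of recomputing logits+loss.

```python
import torch
import torch.nn.functional as F

def sum_cross_entropy(logits, labels):
    # Sum-reduced so per-shard losses and gradients add up exactly.
    return F.cross_entropy(logits, labels, reduction="sum")

class FusedLogitsLossSketch(torch.autograd.Function):
    """Grads-in-forward pattern (illustrative only): compute the loss tile by
    tile under enable_grad, stash dloss/dlogits, and make backward a rescale."""

    @staticmethod
    def forward(ctx, loss_fn, logits, labels, shards):
        grads = torch.zeros_like(logits)
        total_loss = logits.new_zeros(())
        offset = 0
        for x_s, y_s in zip(torch.chunk(logits, shards), torch.chunk(labels, shards)):
            x_s = x_s.detach().requires_grad_(True)
            with torch.enable_grad():
                loss = loss_fn(x_s, y_s)
            # Compute the per-tile gradient right away and store it.
            (g,) = torch.autograd.grad(loss, (x_s,))
            grads[offset:offset + x_s.shape[0]] = g
            total_loss += loss.detach()
            offset += x_s.shape[0]
        ctx.save_for_backward(grads)
        return total_loss / offset  # mean loss over all rows

    @staticmethod
    def backward(ctx, grad_output):
        (grads,) = ctx.saved_tensors
        # No recomputation here: gradients were already produced in forward.
        return None, grads * (grad_output / grads.shape[0]), None, None

# Usage sketch:
# logits = torch.randn(128, 32000, requires_grad=True)
# labels = torch.randint(0, 32000, (128,))
# loss = FusedLogitsLossSketch.apply(sum_cross_entropy, logits, labels, 4)
# loss.backward()
```

Because the full `[rows, vocab]` gradient never has to be rematerialized in `backward`, peak memory stays bounded by one tile's activations plus the stored gradient buffer.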

Signed-off-by: Stas Bekman <[email protected]>
@stas00 stas00 marked this pull request as ready for review July 29, 2025 17:08
@stas00 stas00 enabled auto-merge (squash) July 29, 2025 18:20
@stas00
Collaborator Author

stas00 commented Jul 30, 2025

@tjruwase, could you please review?

It should be easy to review since it's a fork of the existing `SequenceTiledCompute` autograd class, just specialized for logits+loss, which makes it easier to understand and invoke; i.e. nothing new.

@stas00 stas00 merged commit 3292e07 into master Jul 30, 2025
9 checks passed
@stas00 stas00 deleted the stas/TiledFusedLogitsLoss branch July 30, 2025 18:15
```python
with torch.enable_grad():
    args = (self, x_shard, y_shard)
    if mask is not None:
        args.append(mask_shards[i])
```
Aurick Qiao
Contributor
@stas00 this needs to be `args = args + (mask_shards[i],)`, since `args` is a tuple and tuples don't have `append`.
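With that fix applied, the hunk would read (sketch):

```python
with torch.enable_grad():
    args = (self, x_shard, y_shard)
    if mask is not None:
        # tuples are immutable, so rebuild the tuple instead of calling append
        args = args + (mask_shards[i],)
```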

@stas00
Collaborator Author
Thank you, Aurick - applying it here #7459

qimcis pushed a commit to qimcis/DeepSpeed that referenced this pull request Jul 31, 2025
LYMDLUT pushed a commit to LYMDLUT/DeepSpeed that referenced this pull request Aug 20, 2025
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025