adding TiledFusedLogitsLoss #7437
Conversation
@tjruwase, could you please review? It should be easy to review since it's a fork of an existing
```python
with torch.enable_grad():
    args = (self, x_shard, y_shard)
    if mask is not None:
        args.append(mask_shards[i])
```
@stas00 this needs to be `args = args + (mask_shards[i],)`
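A minimal, runnable sketch of why the suggested change is needed, using stand-in string values rather than the real shard tensors: Python tuples are immutable, so `.append()` raises `AttributeError`, while concatenation builds a new tuple.

```python
# Illustrative only: stand-ins for (self, x_shard, y_shard) and mask_shards[i].
args = ("self", "x_shard", "y_shard")

try:
    args.append("mask_shard")  # AttributeError: 'tuple' object has no attribute 'append'
except AttributeError:
    args = args + ("mask_shard",)  # correct: concatenation returns a new tuple

assert args == ("self", "x_shard", "y_shard", "mask_shard")
```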
Thank you, Aurick. Applying it here: #7459
This PR adds `TiledFusedLogitsLoss` for an efficient fused logits+loss computation. This version pre-calculates grads in `forward`, avoiding recomputation in the backward pass (similar to the Liger-Kernel implementation).

--------

Signed-off-by: Stas Bekman <[email protected]>
Co-authored-by: Aurick Qiao <[email protected]>
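For readers skimming the thread, here is a minimal sketch of the technique the description refers to: compute the logits+loss shard by shard ("tiled"), and pre-calculate the input gradients inside `forward`, so that `backward` only rescales the saved gradients instead of recomputing the logits. This is an illustration under assumed shapes (a plain linear projection followed by cross-entropy), not the actual DeepSpeed implementation; the class and variable names in it are hypothetical.

```python
# Hypothetical sketch (not the DeepSpeed implementation): tiled fused
# logits+loss with gradients pre-calculated in forward.
import torch
import torch.nn.functional as F


class TiledFusedLogitsLossSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, y, num_shards):
        x_shards = x.chunk(num_shards, dim=0)
        y_shards = y.chunk(num_shards, dim=0)
        total_loss = x.new_zeros(())
        x_grads = []
        w_grad = torch.zeros_like(weight)
        for x_shard, y_shard in zip(x_shards, y_shards):
            x_shard = x_shard.detach().requires_grad_(True)
            w = weight.detach().requires_grad_(True)
            with torch.enable_grad():
                logits = x_shard @ w.t()  # one tile of the fused projection
                loss = F.cross_entropy(logits, y_shard, reduction="sum")
                # key idea: grads are computed here, in forward, so the
                # logits never need to be rematerialized in backward
                gx, gw = torch.autograd.grad(loss, (x_shard, w))
            total_loss += loss.detach()
            x_grads.append(gx)
            w_grad += gw
        n = y.numel()
        ctx.save_for_backward(torch.cat(x_grads) / n, w_grad / n)
        return total_loss / n

    @staticmethod
    def backward(ctx, grad_output):
        x_grad, w_grad = ctx.saved_tensors
        # no recomputation: just scale the grads saved during forward
        return grad_output * x_grad, grad_output * w_grad, None, None


# toy usage: batch of 8 tokens, hidden=16, vocab=32, 4 tiles
x = torch.randn(8, 16, requires_grad=True)
weight = torch.randn(32, 16, requires_grad=True)
y = torch.randint(0, 32, (8,))
loss = TiledFusedLogitsLossSketch.apply(x, weight, y, 4)
loss.backward()
print(loss.item(), x.grad.shape, weight.grad.shape)
```

The trade-off is the one the description implies: forward stores ready-made gradients (extra memory for the saved `x` and `weight` grads) in exchange for skipping the logits recomputation in backward.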