Skip to content

Conversation

@stas00
Copy link
Collaborator

@stas00 stas00 commented Jul 9, 2025

Improved TiledMLP and SequenceTiledCompute for bs>1

This PR:

  • extends the testing utils to add CaptureStd*, CaptureLogger context managers
  • extends the test to run both bs=1 and bs=2
  • use an uneven seqlen to test varlen shards
  • flattens bs+seqlen dim, to avoid problems with grad tensor strides when bs>1 - mlp doesn't care for the bs dimension so using a pretend bs*seqlen seqlen instead and restoring the shape at the end for the grad.

@stas00 stas00 changed the title Stas/tiled mlp bs gt 1 TiledMLP: improve the bs>1 use-case Jul 9, 2025
stas00 added 3 commits July 9, 2025 01:15
Signed-off-by: Stas Bekman <[email protected]>
Signed-off-by: Stas Bekman <[email protected]>
Signed-off-by: Stas Bekman <[email protected]>
@stas00 stas00 changed the title TiledMLP: improve the bs>1 use-case TiledMLP + SequenceTiledCompute: improve the bs>1 use-case Jul 9, 2025
@stas00 stas00 marked this pull request as ready for review July 15, 2025 19:36
@stas00 stas00 merged commit c2bb53f into master Jul 16, 2025
9 checks passed
@stas00 stas00 deleted the stas/tiled-mlp-bs-gt-1 branch July 16, 2025 16:30
@stas00
Copy link
Collaborator Author

stas00 commented Jul 16, 2025

Thank you, Logan!

lpnpcs pushed a commit to lpnpcs/DeepSpeed that referenced this pull request Jul 30, 2025
…ai#7422)

Improved TiledMLP and SequenceTiledCompute for bs>1

This PR:
- extends the testing utils to add `CaptureStd*`, `CaptureLogger`
context managers
- extends the test to run both bs=1 and bs=2
- use an uneven seqlen to test varlen shards
- flattens bs+seqlen dim, to avoid problems with grad tensor strides
when bs>1 - mlp doesn't care for the bs dimension so using a pretend
`bs*seqlen` seqlen instead and restoring the shape at the end for the
grad.

---------

Signed-off-by: Stas Bekman <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
LYMDLUT pushed a commit to LYMDLUT/DeepSpeed that referenced this pull request Aug 20, 2025
…ai#7422)

Improved TiledMLP and SequenceTiledCompute for bs>1

This PR:
- extends the testing utils to add `CaptureStd*`, `CaptureLogger`
context managers
- extends the test to run both bs=1 and bs=2
- use an uneven seqlen to test varlen shards
- flattens bs+seqlen dim, to avoid problems with grad tensor strides
when bs>1 - mlp doesn't care for the bs dimension so using a pretend
`bs*seqlen` seqlen instead and restoring the shape at the end for the
grad.

---------

Signed-off-by: Stas Bekman <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Signed-off-by: lym <[email protected]>
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025
…ai#7422)

Improved TiledMLP and SequenceTiledCompute for bs>1

This PR:
- extends the testing utils to add `CaptureStd*`, `CaptureLogger`
context managers
- extends the test to run both bs=1 and bs=2
- use an uneven seqlen to test varlen shards
- flattens bs+seqlen dim, to avoid problems with grad tensor strides
when bs>1 - mlp doesn't care for the bs dimension so using a pretend
`bs*seqlen` seqlen instead and restoring the shape at the end for the
grad.

---------

Signed-off-by: Stas Bekman <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants