
Conversation

@eternalNight
Contributor

With autocast enabled, a majority of weights are downcast before being used in calculations. Today zero3_compile gathers the FP32 weights before they are downcast. That is sub-optimal because the FP32 weights consume more bandwidth to allgather and take more time to downcast.
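For context, a minimal generic PyTorch illustration of that downcast-before-use pattern (not DeepSpeed-specific; the module and shapes here are made up):

```python
import torch

linear = torch.nn.Linear(1024, 1024, device="cuda")   # weight is stored in FP32
x = torch.randn(8, 1024, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = linear(x)   # autocast downcasts the FP32 weight to BF16 before the matmul

print(linear.weight.dtype, y.dtype)   # torch.float32 torch.bfloat16
```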

To reduce communication and downcast time, fuse allgather and downcast in the dc ops. The target dtype is now passed to allgather_param() and prefetch_params_fused(), which downcast the (partial) weights before launching the allgathers.

This corresponds to issue 1 of #7577.
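A rough sketch of the fused pattern (an illustrative stand-in, not the actual dc op implementation; the function name and the direct use of torch.distributed here are assumptions):

```python
import torch
import torch.distributed as dist

def allgather_param_fused(shard_fp32: torch.Tensor,
                          target_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Downcast the local (partial) weight first so the collective moves
    # BF16 bytes instead of FP32 bytes.
    shard_lp = shard_fp32.to(target_dtype)
    world_size = dist.get_world_size()
    gathered = torch.empty(world_size * shard_lp.numel(),
                           dtype=target_dtype, device=shard_lp.device)
    dist.all_gather_into_tensor(gathered, shard_lp)
    return gathered
```

Gathering in the target dtype halves both the allgather traffic and the size of the gathered buffer relative to FP32.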

Tested with https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 (run with `deepspeed --num_gpus=N this_file.py -c -p -m 23` to collect torch and memory profiles, and with DINOV2_DEPTH = SIGLIP_DEPTH = 3, LLAMA2_DEPTH = 4 for faster compilation) on a 5090, which has limited inter-GPU bandwidth: time per step decreases from 438 ms to 337 ms and peak GPU memory usage from 9.5 GB to 8.5 GB.

Profiles of a single step before this PR:

[profile screenshots]

After this PR:

[profile screenshots]

This PR also reduces peak memory usage because fast_free_schedule() currently arranges all param allgathers and downcasts at the beginning of the graph. While the original FP32 params can be freed early, all FP16/BF16-cast params are kept in GPU memory at the beginning of the backward graph, leading to a higher peak in memory usage.

P.S. Probably due to organization branch rule settings, I can't find anywhere to allow reviewers to modify the branch. So I'll update the branch per reviewers' comments and rebase if needed.

@eternalNight eternalNight force-pushed the eternalNight/fuse_allgather_and_autocast branch from 771a833 to 9701899 Compare September 24, 2025 01:53
@tohtana
Collaborator

tohtana commented Sep 25, 2025

@eternalNight,

Thank you for submitting a great PR! I believe this will bring a significant improvement across many use cases.

I’m happy to merge this change, but I’d like your feedback on one point:
Right now, the approach is to pass the downcast dtype to the allgather. As an alternative, would it be possible to identify a parameter that undergoes a downcast and place the allgather after the downcast?

The advantage of this approach is that it keeps the C++ side (the more primitive components in DeepCompile) simpler. That said, there may be trade-offs compared to your current design.

I’d really appreciate your thoughts on this.

@eternalNight
Contributor Author

> @eternalNight,
>
> Thank you for submitting a great PR! I believe this will bring a significant improvement across many use cases.
>
> I'm happy to merge this change, but I'd like your feedback on one point: Right now, the approach is to pass the downcast dtype to the allgather. As an alternative, would it be possible to identify a parameter that undergoes a downcast and place the allgather after the downcast?
>
> The advantage of this approach is that it keeps the C++ side (the more primitive components in DeepCompile) simpler. That said, there may be trade-offs compared to your current design.

Placing the allgather after the param downcast is probably feasible. The major differences from this version are:

  1. We don't need the additional dtype argument in allgather_param().
  2. We still need the dtype argument in executor->allgatherParam() because the output_buf needs that dtype to determine its size.
  3. The same applies to prefetch_params_fused() and executor->prefetchParamsFused(), but additionally we need to move the downcast ops ahead of the prefetch op.
  4. Type-casting is unnecessary in launchAllGather().
  5. Either the downcast op should run on the ag stream (can we move an existing op to a separate stream?), or we need an additional event to ensure that the allgather does not start until the downcast completes (see the sketch below).

On the C++ side, items 1-4 reduce complexity (a bit) while item 5 increases it. So I'm not sure either approach is really simpler than the other.
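A minimal sketch of the event-based option in item 5, assuming the downcast is issued on the compute stream and the allgather on a separate ag stream (hypothetical code, not part of this PR):

```python
import torch

compute_stream = torch.cuda.current_stream()
ag_stream = torch.cuda.Stream()
downcast_done = torch.cuda.Event()

param_fp32 = torch.randn(1 << 20, device="cuda")
param_bf16 = param_fp32.to(torch.bfloat16)   # downcast issued on the compute stream
downcast_done.record(compute_stream)

with torch.cuda.stream(ag_stream):
    # Device-side wait: the ag stream stalls until the downcast finishes,
    # without blocking the host thread.
    ag_stream.wait_event(downcast_done)
    # The allgather on param_bf16 would be launched here, e.g.
    # dist.all_gather_into_tensor(gathered_bf16, param_bf16)
```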

I’d really appreciate your thoughts on this.

@eternalNight
Contributor Author

BTW, I found today that, with the inductor backend, the loss of that OpenVLA-like test became NaN after a few iterations, even without torch autocast. The same setup works fine with the eager backend.

I suspect it's related to tensor lifecycles, but I need some more time to find out.

@eternalNight eternalNight force-pushed the eternalNight/fuse_allgather_and_autocast branch 2 times, most recently from d8e0d1a to b5a3aba Compare September 26, 2025 01:24
@tohtana
Collaborator

tohtana commented Sep 26, 2025

@eternalNight Thanks for sharing your insights! I was considering whether it might be useful to implement various features as a graph modification, but I think your approach makes sense for this PR.

> BTW, I found today that, with the inductor backend, the loss of that OpenVLA-like test became NaN after a few iterations, even without torch autocast. The same setup works fine with the eager backend.

Did you encounter this issue without the changes in this PR? I've run into similar problems while writing the DeepCompile code. In most cases, Triton code generated by Inductor releases buffer tensors earlier than expected, causing the tensor to be destroyed.
I tested this with several models, including LLaMA and Mixtral, but it's possible I missed a pattern that shows up in a different architecture.

@eternalNight
Contributor Author

> BTW, I found today that, with the inductor backend, the loss of that OpenVLA-like test became NaN after a few iterations, even without torch autocast. The same setup works fine with the eager backend.
>
> Did you encounter this issue without the changes in this PR? I've run into similar problems while writing the DeepCompile code. In most cases, Triton code generated by Inductor releases buffer tensors earlier than expected, causing the tensor to be destroyed. I tested this with several models, including LLaMA and Mixtral, but it's possible I missed a pattern that shows up in a different architecture.

Yes, the issue exists before this PR and without torch autocast.

@eternalNight
Contributor Author

eternalNight commented Sep 26, 2025

@tohtana I think I have found (at least one of) the causes of that loss NaN issue.

A torch op schema can specify whether the output tensor aliases the storage of an input tensor. Take t() as an example: aten::t(Tensor(a) self) -> Tensor(a).

Inductor-generated code explicitly drops a tensor when it's no longer needed. So for allgather_param, the generated code looks like the following (note the del statements):

buf0 = torch.ops.dc.allgather_param.default(primals_53, 139956546525584, 49)
del primals_53
buf1 = buf0
del buf0

Now for release_param(): it is declared as release_param(Tensor a, int graph_id, int id, int n_users) -> Tensor, but the returned tensor is just a, not a copy of a. With the del that follows, a can be freed early. Changing the schema to release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a) makes the situation better, but only a bit: losses are back to normal only when I comment out fast_free_schedule in zero3_compile.py. There is probably yet another cause.
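For illustration, a hedged sketch of how such an alias annotation can be attached to a custom op via torch.library (the demo namespace and Python-side registration are assumptions for this example, not the actual dc registration, which lives in C++):

```python
import torch

# Hypothetical namespace; the real dc ops are registered elsewhere.
lib = torch.library.Library("demo", "DEF")
lib.define("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)")

def release_param_impl(a, graph_id, id, n_users):
    # Returning the input itself matches the Tensor(a) -> Tensor(a) annotation:
    # the output aliases the input's storage, so the compiler must not treat a
    # del of the "returned" tensor as permission to free the underlying buffer.
    return a

lib.impl("release_param", release_param_impl, "CompositeExplicitAutograd")
```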

Anyway, this is a different issue from what this PR is aimed at. I'll open a separate issue for this.

Update: That "another cause" is the previous workaround, i.e., cloning the input tensor in release_param. After cleaning that up, the loss NaN issue disappears. Refer to #7597 for the full code change. This PR has been tested to be compatible with that one.

@eternalNight eternalNight changed the title deepcompile: Fuse allgather and downcast DeepCompile: Fuse allgather and downcast Sep 26, 2025
Collaborator

@tohtana tohtana left a comment

Let's merge this after #7597 is merged. Please let me know when you are okay with merging (I think this PR may have some conflicts with #7597).

@eternalNight eternalNight force-pushed the eternalNight/fuse_allgather_and_autocast branch from b5a3aba to aeaac36 Compare September 29, 2025 02:48
@eternalNight
Contributor Author

> Let's merge this after #7597 is merged. Please let me know when you are okay with merging (I think this PR may have some conflicts with #7597).

@tohtana Rebased and conflicts resolved.

@tohtana tohtana enabled auto-merge (squash) September 29, 2025 02:53
@tohtana tohtana merged commit 4efd7ec into deepspeedai:master Sep 29, 2025
12 checks passed
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025