
Conversation

@lpnpcs (Contributor) commented Jun 19, 2025

I found that when using DeepSpeed ZeRO-2 for my training task, the loss becomes 0 at the third step with a grad_norm of 1.414. This issue doesn't occur with ZeRO-3. The same issue is reported in #7188. After a series of experiments, I identified the cause: there is a synchronization problem in the double ipg_buffer swapping. The issue was resolved after making the modifications in this PR.

before
![image](https://github.com/user-attachments/assets/981d0829-e15f-4899-ae2c-4eca16ef138d)

after
![image](https://github.com/user-attachments/assets/8b6b8403-d5df-4aa8-b573-195b9ee1fdfb)
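
For readers following along, here is a minimal sketch of the race and of the device-level synchronization the fix introduces. The names `ipg_buffer`, `ipg_index`, and the helper function are illustrative stand-ins, not the exact ZeRO-2 code in this PR:

```python
from deepspeed.accelerator import get_accelerator

# Illustrative only: with double buffering, the next bucket starts filling
# ipg_buffer[1 - ipg_index] while the asynchronous reduction of
# ipg_buffer[ipg_index] may still be in flight on the reduction stream.
# Without a synchronization point, the new writes can race with the pending
# reduce and corrupt the gradients being averaged.

def swap_ipg_buffer_with_device_sync(ipg_index: int) -> int:
    # Coarse, device-level fix: wait for all outstanding device work
    # (including the in-flight reduction) before reusing the other buffer.
    get_accelerator().synchronize()
    return 1 - ipg_index
```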

@tjruwase (Contributor) commented
@lpnpcs, thanks for contributing this fix. I am a bit concerned about the perf impact of synchronizing the device. Are you able to measure the perf before/after the fix? This will help guide whether to pursue finer-grained synchronization on streams instead of the device.

@lpnpcs (Contributor, Author) commented Jun 20, 2025

> @lpnpcs, thanks for contributing this fix. I am a bit concerned about the perf impact of synchronizing the device. Are you able to measure the perf before/after the fix? This will help guide whether to pursue finer-grained synchronization on streams instead of the device.

Thank you for your review. I conducted the following experiments to illustrate the impact on performance.

I trained the Qwen2.5-VL-7B model on 8 A100 GPUs with 1,000 samples for 3 epochs. Below is the performance in each case.

1. Original code: (screenshot)
2. Device-level synchronization: (screenshot)
3. Stream-level synchronization: (screenshot)

Overall, adding synchronization makes the code slightly slower than the original, but it avoids the bug. Stream-level synchronization shows some improvement over device-level synchronization. Since stream-level synchronization is more precise and also resolves the issue, I updated the code to use it.
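
For reference, the difference between cases 2 and 3 boils down to the following (a sketch, assuming a CUDA device; the stream here is a stand-in for the ZeRO reduction stream):

```python
import torch

# Stand-in for the ZeRO reduction stream; the name is illustrative.
reduction_stream = torch.cuda.Stream()

# Case 2: device-level synchronization. Blocks the host until every stream on
# the device has drained - safe, but it also stalls work unrelated to the
# gradient reduction.
torch.cuda.synchronize()

# Case 3: stream-level synchronization. Only makes the current (default)
# stream wait for the work already enqueued on the reduction stream; the host
# is not blocked and other streams keep running, so less overlap is lost.
torch.cuda.current_stream().wait_stream(reduction_stream)
```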

@hwchen2017 (Contributor) commented
Hi @lpnpcs, can you share your full repo - including source code, dataset, and launch script? I’d be happy to help investigate further if I can reproduce the issue.

@lpnpcs (Contributor, Author) commented Jun 23, 2025

> Hi @lpnpcs, can you share your full repo - including source code, dataset, and launch script? I’d be happy to help investigate further if I can reproduce the issue.

Sorry, our dataset is somewhat sensitive. Everything except the dataset is public; we used LLaMA-Factory to fine-tune Qwen2.5-VL-7B. One characteristic of my dataset is that it contains a few pieces of dirty data, which can make the grad_norm extremely large. At that point the gradient becomes NaN, and DeepSpeed then processes it as -1. However, using ZeRO-3, or disabling overlap_comm or contiguous_gradients, resolves the issue.
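
As an aside, a small check like the sketch below (the helper is purely illustrative, not a DeepSpeed API) can help confirm whether dirty samples are producing non-finite gradients before DeepSpeed reports the grad_norm as -1:

```python
import torch

def has_nonfinite_grads(model: torch.nn.Module) -> bool:
    """Return True if any parameter gradient contains NaN or Inf."""
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            print(f"non-finite gradient in {name}")
            return True
    return False

# Example: call this right after the backward pass to flag the batches whose
# dirty samples blow up the gradient norm.
```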

@hwchen2017 (Contributor) commented
Can you enable overlap_comm and run your code with the CUDA sanitizer, and share the output if any?
TORCH_CUDA_SANITIZER=1 python your_code.py

FYI: https://pytorch.org/docs/stable/cuda._sanitizer.html
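
If setting the environment variable through a launcher is inconvenient, the same docs page describes a programmatic switch; a minimal sketch:

```python
# Same effect as TORCH_CUDA_SANITIZER=1, but enabled from inside the script.
# Must be called before any CUDA work is issued.
import torch.cuda._sanitizer as csan

csan.enable_cuda_sanitizer()
```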

@sfc-gh-truwase (Collaborator) commented
@lpnpcs, please fix conflict. Thanks!

@jhwei commented Jul 18, 2025

I found a similar issue recently while training Qwen2.5-VL-7B too. My solution is pretty similar.

I think the key issue was identified in issues like #5545 and #5606: the default stream should wait for the reduction stream.

PR #5606 claimed to have fixed this issue, but it only does so in some cases. I think the default stream should wait for the reduction stream at the end of reduce_ipg_grads, because the later code in reduce_independent_p_g_buckets_and_remove_grads modifies the buffer at the same time.

This PR fixes the issue as well, since it synchronizes after reduce_ipg_grads. I think the reduction stream does not need to wait for the default stream (the current stream) at this point, because the default stream does not modify the buffer at the same time.
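
To make the ordering concrete, here is a minimal sketch of the dependency described above; the function name is a stand-in, not the actual ZeRO-2 implementation:

```python
import torch

def reduce_ipg_grads_sketch(reduction_stream: torch.cuda.Stream) -> None:
    # ... the asynchronous reduce of the current ipg buffer has been enqueued
    # on reduction_stream above this point ...

    # Proposed ordering: before returning, make the default (current) stream
    # wait for the reduction stream, because the caller
    # (reduce_independent_p_g_buckets_and_remove_grads) immediately starts
    # refilling the buffer on the default stream.
    torch.cuda.current_stream().wait_stream(reduction_stream)

    # The reverse edge (reduction stream waiting on the default stream) is not
    # needed here, since the default stream does not touch the buffer while
    # the reduction is in flight.
```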

@loadams (Collaborator) commented Jul 28, 2025

@lpnpcs - would you be able to resolve the conflicts on this and we can get it merged?

@lpnpcs (Contributor, Author) commented Jul 29, 2025

> @lpnpcs - would you be able to resolve the conflicts on this and we can get it merged?

Done!

@loadams (Collaborator) commented Jul 29, 2025

> @lpnpcs - would you be able to resolve the conflicts on this and we can get it merged?
>
> Done!

Thanks, @lpnpcs - could you resolve the formatting fixes as well then we can merge?

@lpnpcs (Contributor, Author) commented Jul 30, 2025

> @lpnpcs - would you be able to resolve the conflicts on this and we can get it merged?
>
> Done!
>
> Thanks, @lpnpcs - could you resolve the formatting fixes as well then we can merge?

Done.

loadams enabled auto-merge (squash) August 4, 2025 18:20
loadams merged commit f897b67 into deepspeedai:master on Aug 4, 2025
9 checks passed
@GuCarpenter commented
In this way, we have a synchronization point both before and after reduction, which means overlap_comm no longer works. Is this the proper solution?
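
For context, overlap_comm is the ZeRO knob at issue here; a typical stage-2 config that hits this code path looks roughly like the following (values are only an example):

```python
# Example DeepSpeed config dict (values illustrative). With overlap_comm
# enabled, gradient reduction runs on a separate stream and overlaps with the
# backward pass - exactly the overlap that extra synchronization points
# before/after reduction can eat into.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}
```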

LYMDLUT pushed a commit to LYMDLUT/DeepSpeed that referenced this pull request Aug 20, 2025
@sfc-gh-truwase (Collaborator) commented
@anyinlover thanks for raising concerns about overlap_comm perf. Do you have any numbers that show degradation?

@lpnpcs and @jhwei I wonder if you have any thoughts on this? Did you measure with overlap_comm in your experiments?

mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025