Conversation

@zdevito zdevito commented Mar 26, 2024

Stack from ghstack (oldest at bottom):

Adding an event cache has two goals:
(1) lower the overhead of issuing collectives;
(2) remove cudaEventDestroy from the watchdog thread.
If CUDA gets stuck because of NCCL, cudaEventDestroy might hang. That has traditionally gotten the watchdog thread stuck, causing us to rely on a separate thread to make sure the watchdog thread makes progress. With this change, we probably do not need that thread anymore, but we can first check whether we still see any stack traces suggesting a heartbeat timeout after we land this change.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @rohan-varma
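
For illustration, here is a minimal sketch of the event-cache idea (hypothetical names, written against the plain CUDA runtime API rather than at::cuda::CUDAEvent; the actual implementation lives in ProcessGroupNCCL.cpp): events are handed out from per-kind free lists and returned to them by the shared_ptr deleter, so no thread ever calls cudaEventDestroy.

#include <cuda_runtime.h>
#include <memory>
#include <mutex>
#include <vector>

// Hypothetical, simplified event cache: "freed" events go back onto a free
// list instead of being destroyed, so no thread blocks in cudaEventDestroy.
class EventCache {
 public:
  // Hands out a cached event (or creates one on a miss). The shared_ptr's
  // deleter returns the event to the cache rather than destroying it.
  std::shared_ptr<cudaEvent_t> create(bool timing) {
    cudaEvent_t event = nullptr;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      auto& bucket = events_[timing ? 1 : 0];
      if (!bucket.empty()) {
        event = bucket.back();
        bucket.pop_back();
      }
    }
    if (event == nullptr) {
      cudaEventCreateWithFlags(
          &event, timing ? cudaEventDefault : cudaEventDisableTiming);
    }
    auto recycle = [this, timing](cudaEvent_t* e) {
      std::lock_guard<std::mutex> lock(mutex_);
      events_[timing ? 1 : 0].push_back(*e);  // recycle, never cudaEventDestroy
      delete e;
    };
    return std::shared_ptr<cudaEvent_t>(new cudaEvent_t(event), recycle);
  }

 private:
  std::mutex mutex_;
  // Two free lists: [0] = events created without timing, [1] = with timing.
  std::vector<cudaEvent_t> events_[2];
};

A cache miss still pays for one cudaEventCreateWithFlags, but in steady state issuing a collective only pops and pushes a pointer under a mutex.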


pytorch-bot bot commented Mar 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/122732

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c2b2115 with merge base 29132c2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels Mar 26, 2024
zdevito added a commit that referenced this pull request Mar 26, 2024

ghstack-source-id: 1d05496
Pull Request resolved: #122732
@zdevito zdevito requested a review from shuqiangzhang March 26, 2024 22:07
@wconstab wconstab (Contributor) left a comment

This LGTM. Do you have some testing/instrumentation to confirm it actually works as intended?

return resultFuture;
}

class CUDAEventCache {
@shuqiangzhang shuqiangzhang (Contributor) commented Mar 26, 2024

By convention, can we move the class declaration to the .hpp file so that we can access the class anywhere in the cpp file?

@zdevito zdevito (author)

ProcessGroupNCCL.hpp is included in 5 compilation units, whereas code in ProcessGroupNCCL.cpp is only in one, so in cases where the code is only used locally I tend to avoid the header file to keep compile times down.

at::cuda::CUDAEvent* event = nullptr;
{
  std::lock_guard<std::mutex> lock(deviceCache.mutex);
  auto& events = deviceCache.events[timing ? 1 : 0];
@shuqiangzhang shuqiangzhang (Contributor) commented Mar 26, 2024

Maybe I am reading it wrong, but is deviceCache.events[timing ? 1 : 0] a vector of event*, or just a single event*?

@zdevito zdevito (author)

deviceCache.events is two vectors: a list of unused events without timing, and a list of unused events with timing. Since timing isn't strictly a global property (you can force timing on for a particular process group), we need to handle the situation where we fulfill both kinds of request.

@shuqiangzhang shuqiangzhang (Contributor) commented Mar 28, 2024

Oh, I got it now. Maybe it is just me, but the confusing part was std::vector<at::cuda::CUDAEvent*> events[2]: it uses a combination of vector and array semantics. Maybe a vector of vectors would be better.
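
For illustration only (a hypothetical refactor sketched from this discussion, not code from the PR, and assuming PyTorch's ATen header for at::cuda::CUDAEvent): the same two-bucket layout can keep its fixed size while naming the buckets, which avoids the raw events[timing ? 1 : 0] subscript that caused the confusion.

#include <ATen/cuda/CUDAEvent.h>
#include <array>
#include <vector>

// Hypothetical alternative layout for the two buckets discussed above:
// same data, but the fixed "exactly two kinds" invariant is made explicit
// with std::array and named indices instead of a raw C array subscript.
struct EventBuckets {
  enum Kind { kNoTiming = 0, kTiming = 1 };

  std::vector<at::cuda::CUDAEvent*>& bucket(bool timing) {
    return buckets_[timing ? kTiming : kNoTiming];
  }

 private:
  std::array<std::vector<at::cuda::CUDAEvent*>, 2> buckets_;
};

A vector of vectors would also work, but std::array keeps the "exactly two kinds" invariant in the type.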


class CUDAEventCache {
 public:
  CUDAEventCache() : caches_(at::cuda::device_count()) {}
Collaborator commented:

ProcessGroupNCCL now supports only a single device per thread. Would that help make the implementation here even simpler?

}
ncclEndEvent_ = std::make_shared<at::cuda::CUDAEvent>(
    enableTiming ? cudaEventDefault : cudaEventDisableTiming);
ncclEndEvent_ = CUDAEventCache::get().create(device.index(), enableTiming);
Collaborator commented:

Any idea how big a performance difference there will be between disableTiming and enableTiming?

@zdevito zdevito (author)

Here is one analysis I've seen:

https://github.com/harrism/cuda_event_benchmark

Using the items_per_second column, recording an event with timing takes about 2.5 us, whereas without timing it is about 0.25 us. Similarly, caching the events is roughly 10x cheaper than calling cudaEventCreate/cudaEventDestroy each time.
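
A rough way to reproduce that kind of comparison locally (a self-contained sketch against the CUDA runtime API; the numbers above come from the linked benchmark, not from this code):

#include <chrono>
#include <cstdio>
#include <cuda_runtime.h>

// Times cudaEventRecord on one stream with a given event-creation flag and
// returns the average cost per record in microseconds.
static double timeRecord(unsigned int flags, int iters) {
  cudaEvent_t ev;
  cudaEventCreateWithFlags(&ev, flags);
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) {
    cudaEventRecord(ev, stream);  // re-recording the same event is legal
  }
  cudaStreamSynchronize(stream);
  auto end = std::chrono::steady_clock::now();
  cudaStreamDestroy(stream);
  cudaEventDestroy(ev);
  return std::chrono::duration<double, std::micro>(end - start).count() / iters;
}

int main() {
  const int iters = 100000;
  std::printf("record with timing:    %.3f us\n",
              timeRecord(cudaEventDefault, iters));
  std::printf("record without timing: %.3f us\n",
              timeRecord(cudaEventDisableTiming, iters));
  return 0;
}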

@zdevito zdevito commented Mar 27, 2024

I tested this locally with prints to ensure I see events getting reused, but I am looking for ideas on how to test it more thoroughly.

@pritamdamania87 pritamdamania87 (Contributor) commented:

Quoting the PR description:

    (2) remove cudaEventDestroy from the watchdog thread. If CUDA gets stuck because of NCCL, cudaEventDestroy might hang. That has traditionally gotten the watchdog thread stuck, causing us to rely on a separate thread to make sure the watchdog thread makes progress. With this change, we probably do not need that thread anymore, but we can first check whether we still see any stack traces suggesting a heartbeat timeout after we land this change.

This isn't a reliable or general way to solve the problem mentioned in #101463. As mentioned in #101463, the watchdog thread also calls things like isCompleted(): https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1647 which eventually calls things like cudaGetLastError, which can get stuck too. There are probably many CUDA calls in the watchdog thread that are not easy to find and track down. That is why a separate minimalistic thread can avoid such situations more reliably.
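
For context, the separate minimalistic thread works roughly like this (a generic sketch, not the actual ProcessGroupNCCL heartbeat code): the watchdog bumps an atomic counter on every iteration, and a monitor thread that never touches CUDA checks that the counter keeps advancing.

#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <thread>

std::atomic<uint64_t> heartbeat{0};
std::atomic<bool> stopRequested{false};

// Watchdog loop: may block inside CUDA calls (event queries, error checks, ...).
void watchdogLoop() {
  while (!stopRequested.load()) {
    heartbeat.fetch_add(1);  // proof of progress
    // ... poll outstanding work, query events, check for NCCL errors ...
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
}

// Monitor loop: makes no CUDA calls at all, so it cannot get stuck even when
// the watchdog is wedged inside the driver; it only compares counter values.
void monitorLoop(std::chrono::seconds timeout) {
  uint64_t last = heartbeat.load();
  while (!stopRequested.load()) {
    std::this_thread::sleep_for(timeout);
    uint64_t now = heartbeat.load();
    if (now == last) {
      std::fprintf(stderr, "watchdog made no progress for %lld s\n",
                   static_cast<long long>(timeout.count()));
      // real code would dump debug info and/or abort the process here
    }
    last = now;
  }
}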

@kwen2501 kwen2501 (Collaborator) left a comment

LGTM!


github-actions bot commented Jun 3, 2024

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 3, 2024
@github-actions github-actions bot closed this Jul 3, 2024
@github-actions github-actions bot deleted the gh/zdevito/260/head branch August 2, 2024 01:56
fduwjj added a commit that referenced this pull request Aug 19, 2024
zdevito added a cache for CUDAEvent in #122732, and we want to productionize it behind a flag in this PR.

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]