Skip to content

Conversation

@yifuwang
Copy link
Collaborator

@yifuwang yifuwang commented Oct 12, 2024

Stack from ghstack (oldest at bottom):

Parallelization strategy: every rank issues independent compute
-> barrier -> p2p copy sequences on two streams. In addition to
computation/communication overlapping, the strategy allows for
computation/computation overlapping, greatly reducing
quantization inefficiency.

Ideally, stream activities would look like this ("b" for
barriers, "cp" for p2p copies):

[rank 0]
stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]

[rank 1]
stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]

Note that the barriers synchronize streams with the same ID
across ranks. They don't synchronize streams on the same rank.

Since the work on both streams is independent, there's no
guarantee that the chunk_producer from stream 0 or stream 1 will
be scheduled first. If there is a scheduling mismatch across
ranks, the barrier forces all ranks to wait for the slowest.

When scheduling mismatches occur among ranks, the stream
activities might look like this (note that p2p copies from
different streams cannot overlap with each other):

[rank 0]
stream 0: [  chunk_producer  ][b        ][ cp ][  chunk_producer ][b       ][ cp ]
stream 1:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]

[rank 1]
stream 0:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
stream 1: [  chunk_producer  ][b        ][ cp ][  chunk_producer  ][b      ][ cp ]

To prevent this, we need to ensure that the chunk_producer on
stream 1 gets scheduled first on every rank. Without access to
the underlying kernels, CUDA offers no API to control the
scheduling order of two independent, overlapping kernels. Our
solution is to issue a small sleep kernel in stream 0. The sleep
duration is insignificant, but having an extra task in stream 0
will almost guarantee that the chunk_producer on stream 1 gets
scheduled first. Once the first chunk_producer is scheduled in
the correct order, there's very little room for the scheduling
order of subsequent kernels to be inconsistent across ranks.

Currently, we perform stream synchronization to ensure scheduling order. The stream synchronization has no bearing on correctness, but prevents inconsistent scheduling orders across ranks.

Without the stream synchronization, ranks may have inconsistent scheduling order, and the barriers cause all ranks to wait for the slowest rank:
image

With stream synchronization, the inconsistent scheduling order issue is addressed, but we lose compute/compute overlapping (this is the state before this PR):
image

With this PR, we get both consistent scheduling order across ranks and compute/compute overlap:
image

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137836

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a6f9232 with merge base dae6007 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Oct 12, 2024
@yifuwang yifuwang requested review from Chillee, lw and weifengpy October 12, 2024 04:39
@yifuwang yifuwang added the topic: not user facing topic category label Oct 12, 2024
```
# Our parallelization strategy issues independent compute ->
# barrier -> p2p copy sequences on two streams. This strategy not
# only allows computation and p2p copy from different streams to
# overlap, it also allows computations in different streams to
# overlap, greatly reducing quantization inefficiency.
#
# Ideally, stream activities would look like this ("b" for
# barriers, "cp" for p2p copies):
#
# [rank 0]
# stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
# stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]
#
# [rank 1]
# stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
# stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]
#
# NOTE: the barriers synchronize streams with the same ID across
# ranks. They don't synchronize streams on the same rank.
#
# Since the work on both streams is independent, there's no
# guarantee that the chunk_producer from stream 0 or stream 1 will
# be scheduled first. If there is a scheduling mismatch across
# ranks, the barrier forces all ranks to wait for the slowest.
#
# When scheduling mismatches occur among ranks, the stream
# activities might look like this (note that p2p copies from
# different streams cannot overlap with each other):
#
# [rank 0]
# stream 0: [  chunk_producer  ][b        ][ cp ][  chunk_producer ][b       ][ cp ]
# stream 1:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
#
# <rank 1>
# stream 0:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
# stream 1: [  chunk_producer  ][b        ][ cp ][  chunk_producer  ][b      ][ cp ]
#
# To prevent this, we need to ensure that the chunk_producer on
# stream 1 runs first on every rank. Without access to the
# underlying kernels, CUDA offers no API to control the scheduling
# order of two independent, overlapping kernels. Our solution is to
# issue a small sleep kernel in stream 0. The sleep duration is
# insignificant. The sleep task itself will drastically increase
# the chance that the chunk_producer on stream 1 gets scheduled
# first. As long as the first chunk_producer is scheduled in the
# correct order, the likelihood of inconsistent scheduling orders
# for subsequent kernels across ranks will be minimal.
```

Currently, we perform stream synchronization to ensure scheduling order. The stream synchronization has no bearing on correctness, but prevents inconsistent scheduling orders across ranks.

Without the stream synchronization, ranks may have inconsistent scheduling order, and the barriers cause all ranks to wait for the slowest rank:
<img width="379" alt="image" src="https://github.com/user-attachments/assets/ffb97e76-7e19-4449-b121-83c32ec3e91d">

With stream synchronization, the inconsistent scheduling order issue is addressed, but we lose compute/compute overlapping (this is the state before this PR):
<img width="378" alt="image" src="https://github.com/user-attachments/assets/4cb76246-625f-4fc1-b49a-823ae46d3f23">

With this PR, we get both consistent scheduling order across ranks and compute/compute overlap:
<img width="327" alt="image" src="https://github.com/user-attachments/assets/51ab1bdc-4f60-46e0-b53c-6d208e2d4888">


cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
```
# Our parallelization strategy issues independent compute ->
# barrier -> p2p copy sequences on two streams. This strategy not
# only allows computation and p2p copy from different streams to
# overlap, it also allows computations in different streams to
# overlap, greatly reducing quantization inefficiency.
#
# Ideally, stream activities would look like this ("b" for
# barriers, "cp" for p2p copies):
#
# [rank 0]
# stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
# stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]
#
# [rank 1]
# stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
# stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]
#
# NOTE: the barriers synchronize streams with the same ID across
# ranks. They don't synchronize streams on the same rank.
#
# Since the work on both streams is independent, there's no
# guarantee that the chunk_producer from stream 0 or stream 1 will
# be scheduled first. If there is a scheduling mismatch across
# ranks, the barrier forces all ranks to wait for the slowest.
#
# When scheduling mismatches occur among ranks, the stream
# activities might look like this (note that p2p copies from
# different streams cannot overlap with each other):
#
# [rank 0]
# stream 0: [  chunk_producer  ][b        ][ cp ][  chunk_producer ][b       ][ cp ]
# stream 1:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
#
# <rank 1>
# stream 0:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
# stream 1: [  chunk_producer  ][b        ][ cp ][  chunk_producer  ][b      ][ cp ]
#
# To prevent this, we need to ensure that the chunk_producer on
# stream 1 runs first on every rank. Without access to the
# underlying kernels, CUDA offers no API to control the scheduling
# order of two independent, overlapping kernels. Our solution is to
# issue a small sleep kernel in stream 0. The sleep duration is
# insignificant. The sleep task itself will drastically increase
# the chance that the chunk_producer on stream 1 gets scheduled
# first. As long as the first chunk_producer is scheduled in the
# correct order, the likelihood of inconsistent scheduling orders
# for subsequent kernels across ranks will be minimal.
```

Currently, we perform stream synchronization to ensure scheduling order. The stream synchronization has no bearing on correctness, but prevents inconsistent scheduling orders across ranks.

Without the stream synchronization, ranks may have inconsistent scheduling order, and the barriers cause all ranks to wait for the slowest rank:
<img width="379" alt="image" src="https://github.com/user-attachments/assets/ffb97e76-7e19-4449-b121-83c32ec3e91d">

With stream synchronization, the inconsistent scheduling order issue is addressed, but we lose compute/compute overlapping (this is the state before this PR):
<img width="378" alt="image" src="https://github.com/user-attachments/assets/4cb76246-625f-4fc1-b49a-823ae46d3f23">

With this PR, we get both consistent scheduling order across ranks and compute/compute overlap:
<img width="327" alt="image" src="https://github.com/user-attachments/assets/51ab1bdc-4f60-46e0-b53c-6d208e2d4888">


cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
yifuwang pushed a commit that referenced this pull request Oct 12, 2024
```
# Our parallelization strategy issues independent compute ->
# barrier -> p2p copy sequences on two streams. This strategy not
# only allows computation and p2p copy from different streams to
# overlap, it also allows computations in different streams to
# overlap, greatly reducing quantization inefficiency.
#
# Ideally, stream activities would look like this ("b" for
# barriers, "cp" for p2p copies):
#
# [rank 0]
# stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
# stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]
#
# [rank 1]
# stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
# stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]
#
# NOTE: the barriers synchronize streams with the same ID across
# ranks. They don't synchronize streams on the same rank.
#
# Since the work on both streams is independent, there's no
# guarantee that the chunk_producer from stream 0 or stream 1 will
# be scheduled first. If there is a scheduling mismatch across
# ranks, the barrier forces all ranks to wait for the slowest.
#
# When scheduling mismatches occur among ranks, the stream
# activities might look like this (note that p2p copies from
# different streams cannot overlap with each other):
#
# [rank 0]
# stream 0: [  chunk_producer  ][b        ][ cp ][  chunk_producer ][b       ][ cp ]
# stream 1:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
#
# [rank 1]
# stream 0:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
# stream 1: [  chunk_producer  ][b        ][ cp ][  chunk_producer  ][b      ][ cp ]
#
# To prevent this, we need to ensure that the chunk_producer on
# stream 1 runs first on every rank. Without access to the
# underlying kernels, CUDA offers no API to control the scheduling
# order of two independent, overlapping kernels. Our solution is to
# issue a small sleep kernel in stream 0. The sleep duration is
# insignificant. The sleep task itself will drastically increase
# the chance that the chunk_producer on stream 1 gets scheduled
# first. As long as the first chunk_producer is scheduled in the
# correct order, the likelihood of inconsistent scheduling orders
# for subsequent kernels across ranks will be minimal.
```

Currently, we perform stream synchronization to ensure scheduling order. The stream synchronization has no bearing on correctness, but prevents inconsistent scheduling orders across ranks.

Without the stream synchronization, ranks may have inconsistent scheduling order, and the barriers cause all ranks to wait for the slowest rank:
<img width="379" alt="image" src="https://github.com/user-attachments/assets/ffb97e76-7e19-4449-b121-83c32ec3e91d">

With stream synchronization, the inconsistent scheduling order issue is addressed, but we lose compute/compute overlapping (this is the state before this PR):
<img width="378" alt="image" src="https://github.com/user-attachments/assets/4cb76246-625f-4fc1-b49a-823ae46d3f23">

With this PR, we get both consistent scheduling order across ranks and compute/compute overlap:
<img width="327" alt="image" src="https://github.com/user-attachments/assets/51ab1bdc-4f60-46e0-b53c-6d208e2d4888">


cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
```
Parallelization strategy: every rank issues independent compute
-> barrier -> p2p copy sequences on two streams. In addition to
computation/communication overlapping, the strategy allows for
computation/computation overlapping, greatly reducing
quantization inefficiency.

Ideally, stream activities would look like this ("b" for
barriers, "cp" for p2p copies):

[rank 0]
stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]

[rank 1]
stream 0:         [  chunk_producer  ][b][ cp ][  chunk_producer ][b][ cp ]
stream 1: [  chunk_producer  ][b][ cp ][  chunk_producer  ][b][ cp ]

Note that the barriers synchronize streams with the same ID
across ranks. They don't synchronize streams on the same rank.

Since the work on both streams is independent, there's no
guarantee that the chunk_producer from stream 0 or stream 1 will
be scheduled first. If there is a scheduling mismatch across
ranks, the barrier forces all ranks to wait for the slowest.

When scheduling mismatches occur among ranks, the stream
activities might look like this (note that p2p copies from
different streams cannot overlap with each other):

[rank 0]
stream 0: [  chunk_producer  ][b        ][ cp ][  chunk_producer ][b       ][ cp ]
stream 1:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]

[rank 1]
stream 0:         [  chunk_producer  ][b]      [ cp ][  chunk_producer  ][b]      [ cp ]
stream 1: [  chunk_producer  ][b        ][ cp ][  chunk_producer  ][b      ][ cp ]

To prevent this, we need to ensure that the chunk_producer on
stream 1 gets scheduled first on every rank. Without access to
the underlying kernels, CUDA offers no API to control the
scheduling order of two independent, overlapping kernels. Our
solution is to issue a small sleep kernel in stream 0. The sleep
duration is insignificant, but having an extra task in stream 0
will almost guarantee that the chunk_producer on stream 1 gets
scheduled first. Once the first chunk_producer is scheduled in
the correct order, there's very little room for the scheduling
order of subsequent kernels to be inconsistent across ranks.
```

Currently, we perform stream synchronization to ensure scheduling order. The stream synchronization has no bearing on correctness, but prevents inconsistent scheduling orders across ranks.

Without the stream synchronization, ranks may have inconsistent scheduling order, and the barriers cause all ranks to wait for the slowest rank:
<img width="379" alt="image" src="https://github.com/user-attachments/assets/ffb97e76-7e19-4449-b121-83c32ec3e91d">

With stream synchronization, the inconsistent scheduling order issue is addressed, but we lose compute/compute overlapping (this is the state before this PR):
<img width="378" alt="image" src="https://github.com/user-attachments/assets/4cb76246-625f-4fc1-b49a-823ae46d3f23">

With this PR, we get both consistent scheduling order across ranks and compute/compute overlap:
<img width="327" alt="image" src="https://github.com/user-attachments/assets/51ab1bdc-4f60-46e0-b53c-6d208e2d4888">


cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Oct 15, 2024
…37850)

```
Parallelization strategy: after each rank copies its shard into its local
p2p buffer, every rank issues independent p2p copy -> shard_consumer
sequences to two streams. In addition to computation/communication
overlapping, the strategy allows for computation/computation overlapping,
greatly reducing quantization inefficiency.

Notation:
- "mv" for the copy to local buffer
- "cp" for p2p copies
- "b" for barriers

Constraints:
- The GPU scheduler may or may not overlap "mv" with the first shard_consumer.
- "cp" from different streams cannot overlap.

Ideal scenario 0 - "mv" overlaps with the first shard_consumer:

stream 0: [ shard_consumer ][ cp ][ shard_consumer ]
stream 1: [ mv ][b][ cp ][ shard_consumer ]

Ideal scenario 1 - "mv" is scheduled before the first shard_consumer:

stream 0:       [ shard_consumer ][ cp ][ shard_consumer ]
stream 1: [ mv ][b][ cp ][ shard_consumer ]

Suboptimal scenario 0 - "mv" is scheduled after the first shard_consumer:

stream 0: [ shard_consumer ]               [ cp ][ shard_consumer ]
stream 1:                   [ mv ][b][ cp ][ shard_consumer ]

Suboptimal scenario 0 - "b" is scheduled after the first shard_consumer:

stream 0:       [ shard_consumer ]         [ cp ][ shard_consumer ]
stream 1: [ mv ]                  [b][ cp ][ shard_consumer ]

We haven't yet figured out a way to ensure "mv" and "b" are either
overlapped with or scheduled before the first shard_consumer. Thus, to
prevent suboptimal scenarios, we are giving up the chance to overlap "mv"
and "b" with the first shard_consumer for now.
```

This PR improves the scheduling for mm kernels with high SM utilization. The GPU scheduler tends to not overlap local DtoD copies with such kernels, which leads to suboptimal scheduling. The following is an example of pipelining PyTorch's cutlass-based, row-wise scaling fp8 kernel:

Before this PR:
<img width="298" alt="image" src="https://github.com/user-attachments/assets/81e0a7f4-18ee-47c6-b258-04fdaca7a6a2">

With this PR:
<img width="253" alt="image" src="https://github.com/user-attachments/assets/982de5a8-da1e-4a8f-b67e-c9c869b0a77f">

Pull Request resolved: #137850
Approved by: https://github.com/weifengpy
ghstack dependencies: #137643, #137738, #137805, #137836
@github-actions github-actions bot deleted the gh/yifuwang/145/head branch November 15, 2024 02:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged oncall: distributed Add this issue/PR to distributed oncall triage queue topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants