@wconstab
Contributor

@wconstab wconstab commented Oct 25, 2024

Stack from ghstack (oldest at bottom):

The special case was added during experimentation with batched send/recv
ops. The ops needed to be jointly scheduled or the simulator would
think that each op was unschedulable since each contained a recv that
depended on the other's send. The workaround I added was to let the
scheduler 'peek' one op ahead for unblocking, which let batched ops be
scheduled but also changed the behavior of non-batched ops. It let RECV
ops be simulated one step earlier than the unblocking SEND ops, which
shortened the simulated duration of schedules.

Removing this workaround simplifies the simulator and, more importantly,
makes it much easier to optimize the simulator's runtime by avoiding
copying or extending lists of previous ops on each
iteration. It also restores the output of the simulator for non-batched
ops to a more natural output where RECV must happen at the same time or
later than matching SEND, rather than possibly a step earlier.

For example, for this test:
python test/distributed/pipelining/test_schedule.py -k test_send_recv_test_info0

Before:

Step 0: 0F0      1RECV_F0
Step 1: 0SEND_F0
Step 2: 0F1      1RECV_F1
Step 3: 0SEND_F1 1F0
Step 4: 0RECV_B0 1B0
Step 5: 0B0      1SEND_B0
Step 6:          1F1
Step 7: 0RECV_B1 1B1
Step 8: 0B1      1SEND_B1

After:

         Rank 0   Rank 1
Step 00: 0F0
Step 01: 0SEND_F0 1RECV_F0
Step 02: 0F1
Step 03: 0SEND_F1 1RECV_F1
Step 04:          1F0
Step 05:          1B0
Step 06: 0RECV_B0 1SEND_B0
Step 07: 0B0      1F1
Step 08:          1B1
Step 09: 0RECV_B1 1SEND_B1
Step 10: 0B1
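
The property restored here (a RECV is simulated at the same step as, or later than, its matching SEND) can be checked mechanically against the two tables above. The following is a minimal sketch, not code from this PR: `parse_op`, `matching_send`, and `recv_before_send_violations` are hypothetical helpers, and the sketch assumes that op strings have the form `<stage><KIND><microbatch>` and that a `RECV_F` on stage `s` pairs with a `SEND_F` on stage `s - 1` (and a `RECV_B` with a `SEND_B` on stage `s + 1`).

```python
import re

# Assumed op format: "<stage><KIND><microbatch>", e.g. "0SEND_F1" or "1B0".
OP_RE = re.compile(r"^(\d+)([A-Z_]+)(\d+)$")


def parse_op(text):
    """Split an op string into (stage, kind, microbatch)."""
    match = OP_RE.match(text)
    if match is None:
        raise ValueError(f"unrecognized op: {text!r}")
    stage, kind, microbatch = match.groups()
    return int(stage), kind, int(microbatch)


def matching_send(stage, kind, microbatch):
    """Return the SEND that unblocks a RECV, or None for non-RECV ops."""
    if kind == "RECV_F":
        return (stage - 1, "SEND_F", microbatch)
    if kind == "RECV_B":
        return (stage + 1, "SEND_B", microbatch)
    return None


def recv_before_send_violations(steps):
    """List (recv_op, recv_step, send_step) for RECVs simulated too early.

    `steps` holds one inner list of op strings per simulated step.
    """
    first_step = {}
    for index, ops in enumerate(steps):
        for text in ops:
            first_step.setdefault(parse_op(text), index)

    violations = []
    for op, step in first_step.items():
        send = matching_send(*op)
        if send is None:
            continue
        send_step = first_step.get(send)
        if send_step is None or step < send_step:
            violations.append((op, step, send_step))
    return violations


# The "Before" and "After" tables from the description, transcribed by hand.
BEFORE = [
    ["0F0", "1RECV_F0"], ["0SEND_F0"], ["0F1", "1RECV_F1"],
    ["0SEND_F1", "1F0"], ["0RECV_B0", "1B0"], ["0B0", "1SEND_B0"],
    ["1F1"], ["0RECV_B1", "1B1"], ["0B1", "1SEND_B1"],
]
AFTER = [
    ["0F0"], ["0SEND_F0", "1RECV_F0"], ["0F1"], ["0SEND_F1", "1RECV_F1"],
    ["1F0"], ["1B0"], ["0RECV_B0", "1SEND_B0"], ["0B0", "1F1"],
    ["1B1"], ["0RECV_B1", "1SEND_B1"], ["0B1"],
]

before_violations = recv_before_send_violations(BEFORE)
after_violations = recv_before_send_violations(AFTER)
```

Under these assumptions, every RECV in the "Before" table is flagged because it sits one step ahead of its SEND, while the "After" table produces no violations.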

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the oncall: distributed label Oct 25, 2024
@pytorch-bot

pytorch-bot bot commented Oct 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138928

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit aa83a79 with merge base failed to retrieve merge base, please contact dev infra:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
@wconstab wconstab added the release notes: distributed (pipeline) and module: pipelining labels Oct 30, 2024
@wconstab wconstab requested review from H-Huang and kwen2501 and removed request for H-Huang October 30, 2024 19:44
[ghstack-poisoned]
[ghstack-poisoned]
@wconstab
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Oct 31, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2024
### Separate dI / dW:

PipelineScheduleRuntime now supports execution of merged FULL_BACKWARD
or separate dI / dW operations.

Separating B and W may add execution overhead, or may be suboptimal in
cases where B and W are 'fused', but it is worthwhile when the separation
lets the schedule fill in bubbles and so run more efficiently.  In some
cases, the schedule will still issue B immediately followed by W at
certain points.  In these cases we merge them back into BW ops and
execute them as full backwards rather than executing a B followed by a W.
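
The merge step described above can be illustrated with a small sketch. This is not the code from the PR: the `Op` tuple, the kind strings `"B"`, `"W"`, `"BW"`, and the function `merge_adjacent_bw` are hypothetical names chosen for illustration, and the sketch assumes that only a B immediately followed by a W for the same stage and microbatch is merged.

```python
from typing import List, NamedTuple


class Op(NamedTuple):
    """Hypothetical schedule entry; not the runtime's own action type."""
    stage: int
    kind: str        # "F", "B" (dI), "W" (dW), or "BW" (full backward)
    microbatch: int


def merge_adjacent_bw(ops: List[Op]) -> List[Op]:
    """Collapse each B directly followed by its matching W into one BW op."""
    merged: List[Op] = []
    i = 0
    while i < len(ops):
        cur = ops[i]
        nxt = ops[i + 1] if i + 1 < len(ops) else None
        if (
            cur.kind == "B"
            and nxt is not None
            and nxt.kind == "W"
            and nxt.stage == cur.stage
            and nxt.microbatch == cur.microbatch
        ):
            # B and W are back to back: run them as a single full backward.
            merged.append(Op(cur.stage, "BW", cur.microbatch))
            i += 2
        else:
            # Anything else, including a W deferred to fill a bubble, is kept.
            merged.append(cur)
            i += 1
    return merged


schedule = [
    Op(0, "F", 0), Op(0, "B", 0), Op(0, "W", 0),   # adjacent pair: merged
    Op(0, "B", 1), Op(0, "F", 2), Op(0, "W", 1),   # W deferred: left split
]
merged_schedule = merge_adjacent_bw(schedule)
```

In this example the first B/W pair becomes a single BW op, while the second pair stays split because an F op was scheduled between them.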

### V-schedules:

V-schedules have a special case where the last rank has 2 adjacent
stages.

E.g. if rank 3 had stage 3 and stage 4, then we should implement direct
transfer of stage 3 outputs to stage 4 inputs without a
send/recv.

In the scheduling logic, we must also allow scheduling the
stage 4 forward after running the stage 3 forward, without expecting a
stage 4 RECV_F.

In the runtime, we pass activations between adjacent stages without
using SEND/RECV ops since the stages are on the same rank/process.  We
add new APIs to the PipelineStage abstraction for passing the activations
during both forward and backward.  Currently the implementation directly
modifies the 'recv buffers' the stage is managing, so the
forward/backward execution logic does not need to know the difference.
Pull Request resolved: #131762
Approved by: https://github.com/H-Huang
ghstack dependencies: #138928
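
The same-rank handoff described in the V-schedules section above can be pictured with a toy model. This is only a sketch of the idea: `ToyStage`, `set_local_input`, and `recv_buffers` are hypothetical names and are not the PipelineStage API added by that commit. It assumes each stage keeps a per-microbatch buffer that a RECV_F would normally fill, and that the forward pass reads only from that buffer.

```python
class ToyStage:
    """Hypothetical stand-in for a pipeline stage, for illustration only."""

    def __init__(self, index, fn):
        self.index = index
        self.fn = fn
        # microbatch -> activation; normally populated by a RECV_F op.
        self.recv_buffers = {}

    def set_local_input(self, microbatch, activation):
        # Same-rank handoff: write straight into the buffer a RECV_F would
        # have filled, so forward() needs no special case.
        self.recv_buffers[microbatch] = activation

    def forward(self, microbatch):
        # forward() only consults the buffer and does not care how it was
        # filled (network recv or local handoff).
        return self.fn(self.recv_buffers.pop(microbatch))


# Stage 3 and stage 4 live on the same rank, as in the example above.
stage3 = ToyStage(3, lambda x: x + 1)
stage4 = ToyStage(4, lambda x: x * 2)

stage3.set_local_input(0, 10)          # stands in for a real RECV_F
out3 = stage3.forward(0)
stage4.set_local_input(0, out3)        # direct transfer, no SEND/RECV pair
out4 = stage4.forward(0)
```

Because the handoff reuses the buffer the forward pass already reads from, the execution path is identical whether the activation arrived over the network or from an adjacent local stage.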
pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2024
Used in both simulator and add_send_recv pass, the ready_to_schedule
logic works by looking at all the previously scheduled ops on a rank to
see if any of them 'unblocks' the current op to be scheduled.  For example,
to schedule a FORWARD op, a previous RECV_F op is needed, unless this is
stage 0 or there is a previous stage on the same rank that ran FORWARD
already.

The old implementation iteratively compared the candidate op to the
previous ops.  The new implementation uses set lookups to reduce
complexity.  It also maintains the set of previous ops as ops are
scheduled rather than constructing a set on demand.
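
The set-lookup approach can be sketched as follows. This is a simplified illustration rather than the PR's implementation: the `Op` tuple and this `ready_to_schedule` signature are assumptions, only the FORWARD rule quoted above is modeled, and every other op kind is treated as always ready.

```python
from typing import NamedTuple, Set


class Op(NamedTuple):
    """Hypothetical schedule entry used as a hashable set member."""
    stage: int
    kind: str        # e.g. "F", "RECV_F", "SEND_F"
    microbatch: int


def ready_to_schedule(op: Op, prev_ops: Set[Op]) -> bool:
    """Decide readiness with O(1) set lookups instead of scanning a list.

    `prev_ops` is the set of ops already scheduled on this rank.
    """
    if op.kind != "F":
        return True  # simplification: other op kinds are not modeled here
    if op.stage == 0:
        return True  # stage 0 reads the model input directly
    if Op(op.stage, "RECV_F", op.microbatch) in prev_ops:
        return True  # activation already received from another rank
    # V-schedule case: the previous stage ran FORWARD on this same rank.
    return Op(op.stage - 1, "F", op.microbatch) in prev_ops


# The set is maintained incrementally as ops are scheduled, rather than
# rebuilt from a list each time readiness is queried.
prev_ops: Set[Op] = set()
blocked = ready_to_schedule(Op(1, "F", 0), prev_ops)
prev_ops.add(Op(1, "RECV_F", 0))
unblocked = ready_to_schedule(Op(1, "F", 0), prev_ops)
```

Each readiness query costs a constant number of hash lookups regardless of how many ops were scheduled before it, which is the complexity reduction the commit message describes.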

I did not save benchmark results, but this results in a 10-100x speedup,
which is most noticeable for unit tests with artificially huge schedule
IR.  The largest of these took longer than 20 minutes before (I never let
it finish) but now takes less than 14s.  Most schedules take less than
10ms.

Pull Request resolved: #138924
Approved by: https://github.com/H-Huang
ghstack dependencies: #138928, #131762
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Pull Request resolved: pytorch#138928
Approved by: https://github.com/H-Huang
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Pull Request resolved: pytorch#131762
Approved by: https://github.com/H-Huang
ghstack dependencies: pytorch#138928
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Pull Request resolved: pytorch#138924
Approved by: https://github.com/H-Huang
ghstack dependencies: pytorch#138928, pytorch#131762
@github-actions github-actions bot deleted the gh/wconstab/353/head branch December 1, 2024 02:21

Labels

ciflow/trunk, Merged, module: pipelining, oncall: distributed, release notes: distributed (pipeline)


3 participants