
Conversation

Contributor

@wconstab wconstab commented Sep 28, 2024

Stack from ghstack (oldest at bottom):

Performs shape inference at runtime using user-provided real tensors.

  • avoids the need for users to precompute shapes, which is difficult and error-prone
  • lets us remove args from the PipelineStage ctor (in a later PR)
  • deprecates the existing inference helper in the PipelineStage constructor for several reasons, chief among them that it's problematic to have to reason about the stage submodule being on the right device for shape inference

The current state as of this PR:

  • Users should not pass any input or output shapes into the PipelineStage ctor, and shape inference will run automatically
  • To override shape inference, they can continue to pass input/output args as before

Currently, this PR does not add a barrier after shape inference, which essentially pipelines shape inference with the subsequent schedule action for that stage. If this complicates debugging, we could add a barrier (it comes at a cost, but only during the first step).
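
For illustration, here is a minimal sketch of the intended usage after this PR; the module, shapes, backend, and process-group setup are placeholders rather than code from this PR:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

class StageModule(nn.Module):
    # placeholder submodule for one pipeline stage
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(16, 16)

    def forward(self, x):
        return self.lin(x)

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank)

# No input_args/output_args passed: shapes are inferred at runtime from real
# tensors on the first schedule step.
stage = PipelineStage(
    StageModule().to(device),
    stage_index=rank,
    num_stages=world,
    device=device,
)
schedule = ScheduleGPipe(stage, n_microbatches=4)

x = torch.randn(8, 16, device=device)  # real input on the first stage drives shape inference
out = schedule.step(x) if stage.is_first else schedule.step()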

Testing:

  • Removed input args from all PP test cases, thus exposing them all to shape inference.
  • Verified visually (via nvidia-smi) that the torchtitan PP 3D test runs shape inference fine without creating extra CUDA contexts.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o

[ghstack-poisoned]

pytorch-bot bot commented Sep 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136912

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 39 Pending

As of commit 577e662 (merge base could not be retrieved, please contact dev infra):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Sep 28, 2024
wconstab added a commit that referenced this pull request Sep 28, 2024
TODO: fix extra context on dev0 (recv_obj_list?)

ghstack-source-id: c594399
Pull Request resolved: #136912
@kwen2501 kwen2501 added release notes: distributed (pipeline) release notes category module: pipelining Pipeline Parallelism labels Sep 28, 2024
@kwen2501 kwen2501 requested review from H-Huang and kwen2501 and removed request for kwen2501 September 28, 2024 18:07
     num_stages: int,
     device: torch.device,
-    input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
+    input_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None,
Collaborator

Shall we take this chance to make the signature right?
It should always be Tuple[torch.Tensor, ...] even if there is only 1 tensor input, not Union.

Contributor Author

Hmm, do we have a pattern to follow here? I thought I remembered other torch methods accepting a single item or a list/tuple of items in some cases. But I would agree with your suggestion. I worry a bit about BC breaking, but I guess that's probably not in widespread use yet.

Contributor Author

I can work on this, but I would prefer to keep it as a separate PR.

"or additionally pass `output_args` to `PipelineStage` to fully override shape inference. "
)
self.inputs_meta = (
(input_args,) if isinstance(input_args, torch.Tensor) else input_args
Collaborator

If the signature is unified, then we won't need this if-else.

        assert num_microbatches is not None, "TODO fix num_microbatches"

        outputs: Tuple[Any, ...] = tuple()
        if self.inputs_meta is None:
Collaborator

Hmm, would this if-check cause mismatched behavior between ranks?

Contributor Author

Yeah, I am making an assumption here that either every rank has shapes provided to the Stage ctor or none of them do, but we don't have enforcement.

I guess we could do a synchronization to agree on whether all ranks are doing shape inference before we do it. But that would lead to another question of when: (a) in init feels bad to me (I never like to do comm ops in init); (b) at step() time we'd get stuck doing it on every step, which is too expensive, or else we'd need the 'if initialized' logic to skip it, which in theory could desync.
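
For concreteness, here is a hedged sketch of what such an agreement check could look like; the helper name, placement, and error message are made up (not something this PR implements), and with a NCCL group the flag tensor would need to live on the local CUDA device:

import torch
import torch.distributed as dist

def _agree_on_shape_inference(inputs_meta, group=None) -> bool:
    # 1 if this rank was not given shapes (i.e. wants shape inference), else 0
    flag = torch.tensor([int(inputs_meta is None)], dtype=torch.int64)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM, group=group)
    world = dist.get_world_size(group)
    if flag.item() not in (0, world):
        raise RuntimeError(
            "Some ranks were given input shapes and others were not; "
            "all ranks must agree on whether to run shape inference."
        )
    return flag.item() == world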

wconstab added a commit that referenced this pull request Oct 9, 2024
TODO: fix extra context on dev0 (recv_obj_list?)

ghstack-source-id: 5c3af43
Pull Request resolved: #136912
@wconstab wconstab added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2024
wconstab added a commit that referenced this pull request Oct 9, 2024
TODO: fix extra context on dev0 (recv_obj_list?)

ghstack-source-id: 00624a6
Pull Request resolved: #136912
Collaborator

@kwen2501 kwen2501 left a comment

LGTM! Thanks! Minor comments.

Comment on lines +1103 to +1110
            if stage.is_first:
                next_stage_args = stage._prepare_forward_infra(
                    self._n_microbatches, args, kwargs
                )
            else:
                next_stage_args = stage._prepare_forward_infra(
                    self._n_microbatches, next_stage_args, kwargs
                )
Collaborator

nit: the assumption is that stages on the same rank are consecutive?

Contributor Author

Hmm, good question. Let me clarify with more comments in the code. The assumption should be this (a small illustration follows the list):

  1. the stages on a rank are always sorted, but not necessarily consecutive (e.g. they could be (0, 4, 8) or (3, 4))
  2. IF the stages are consecutive (e.g. 3, 4) then they should communicate via next_stage_args
  3. ELSE the stages should have None values for next_stage_args and should communicate via send/recv
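
A small hedged illustration of that assumption; the stage-to-rank layouts below are made up:

# Example stage->rank mappings; both are sorted per rank, only the second has
# consecutive stages on a rank.
looped_placement = {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 1, 6: 2, 7: 3}  # rank 0 owns (0, 4)
packed_placement = {0: 0, 1: 0, 2: 1, 3: 1}                          # rank 1 owns (2, 3)

def prev_stage_is_local(stage_index_to_group_rank, stage_index, group_rank):
    # True when the previous stage lives on the same rank, so shapes can be handed
    # off via next_stage_args; otherwise shape info travels via send/recv.
    return (
        stage_index > 0
        and stage_index_to_group_rank[stage_index - 1] == group_rank
    )

print(prev_stage_is_local(looped_placement, 4, group_rank=0))  # False -> send/recv
print(prev_stage_is_local(packed_placement, 3, group_rank=1))  # True  -> next_stage_args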

        assert args is not None, "Args may be an empty tuple but not None"
        if (
            self.is_first
            or self.stage_index_to_group_rank[self.stage_index - 1] == self.group_rank
Collaborator

nit: comment?

        # pass their shape info via return value and function args rather than send/recv.
        if (
            self.is_last
            or self.stage_index_to_group_rank[self.stage_index + 1] == self.group_rank
Collaborator

nit: comment?

Member

@H-Huang H-Huang left a comment

Awesome change!

Just a thought, if we wanted to support skip connections we could leverage this and create a more general metadata format that includes the shapes from previous ranks as well as previous stages.

            stage_idx,
            n_stages,
            self.device,
            input_args=input_args,
Member

Are there any tests left with input_args included so that path remains tested?

Contributor Author

No, I thought about splitting it up that way, but since I marked it as deprecated I decided to go all in on covering the shape-inference path instead. I can add one back in though; it's a good point.

Contributor Author

I decided to add coverage for the case where we pass both input and output args, since that's supported as an override to disable shape inference. I left the input_args-only case untested. I wonder if we can just move quickly to stop supporting that path? Do we need a deprecation cycle, or should we just flip it? I can put up another PR for this change and we can discuss there.
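
For reference, a hedged sketch of that override path, reusing the placeholder StageModule/rank/world/device names from the sketch in the PR description above (this is not the test code added in this PR, and the per-microbatch shapes are assumptions):

import torch
from torch.distributed.pipelining import PipelineStage

# Passing BOTH input_args and output_args fully overrides runtime shape inference.
example_input = torch.empty(2, 16, device=device)   # one microbatch's input shape (assumed)
example_output = torch.empty(2, 16, device=device)  # one microbatch's output shape (assumed)
stage = PipelineStage(
    StageModule().to(device),
    stage_index=rank,
    num_stages=world,
    device=device,
    input_args=example_input,
    output_args=example_output,
)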


# communicate meta outputs, not real outputs, for two reasons
# 1 - it's faster (esp. since obj coll pickles tensor data!)
# 2 - avoid activating a cuda context for the src rank when unpickling on the recv end!
Member

Just curious, can you elaborate on the second point? Why would a cuda context be needed for the src rank if we were to send tensor data?

Contributor Author

I didn't actually confirm this, but my suspicion is that if we pickle a tensor on 'cuda:0', serialize it into bytes, and send it to rank 1, then when rank 1 unpickles it, it would reconstruct the original metadata and create a tensor on 'cuda:0', which is not what we want (and I assume would create a new context).
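
A quick hedged illustration of that suspicion, single-process and not code from this PR, just showing the default device round-trip and the meta-tensor alternative:

import io
import torch

t = torch.randn(4, device="cuda:0")
buf = io.BytesIO()
torch.save(t, buf)      # serialization pickles the tensor's data and its device
buf.seek(0)
t2 = torch.load(buf)    # by default the tensor lands back on 'cuda:0', which would
print(t2.device)        # initialize a cuda:0 context in the loading process

meta = torch.empty_like(t, device="meta")  # metadata only: shape/dtype/stride, no data, no cuda context
print(meta.shape, meta.dtype, meta.device)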

@wconstab
Contributor Author

> Just a thought, if we wanted to support skip connections we could leverage this and create a more general metadata format that includes the shapes from previous ranks as well as previous stages.

Not sure I follow how the metadata format would work. Do we want to support skip connections? It does seem a bit complicated for the overall stack to support this.

@wconstab
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 4, 5, linux.g5.4xlarge.nvidia.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.
