[pipelining] Shape Inference #136912
Conversation
```diff
  num_stages: int,
  device: torch.device,
- input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
+ input_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None,
```
Shall we take this chance to make the signature right?
It should always be `Tuple[torch.Tensor, ...]`, even if there is only one tensor input, not a `Union`.
Hmm, do we have a pattern to follow here? I thought I remembered other torch methods accepting either a single item or a list/tuple of items in some cases. But I would agree with your suggestion. I worry a bit about BC breaking, but I guess that's probably not in widespread use yet.
I can work on this, but I would prefer to keep it as a separate PR.
| "or additionally pass `output_args` to `PipelineStage` to fully override shape inference. " | ||
| ) | ||
| self.inputs_meta = ( | ||
| (input_args,) if isinstance(input_args, torch.Tensor) else input_args |
If the signature is unified, then we don't need this if/else.
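A minimal sketch of the tuple-only contract suggested above (the class and parameter names are stand-ins for illustration, not the merged code):

```python
from typing import Optional, Tuple

import torch


class _StageSketch:
    """Hypothetical stand-in for PipelineStage, showing a tuple-only input_args."""

    def __init__(
        self,
        num_stages: int,
        device: torch.device,
        input_args: Optional[Tuple[torch.Tensor, ...]] = None,
    ) -> None:
        # With a tuple-only contract, callers wrap single tensors themselves,
        # so no isinstance(input_args, torch.Tensor) normalization is needed.
        self.inputs_meta = input_args
```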
```python
assert num_microbatches is not None, "TODO fix num_microbatches"

outputs: Tuple[Any, ...] = tuple()
if self.inputs_meta is None:
```
Hmm, would this `if` cause mismatched behavior between ranks?
Yeah, I am making the assumption here that every rank either has shapes or doesn't have shapes provided to the Stage ctor, but we don't have enforcement.
I guess we could do a synchronization to agree on whether all ranks are doing shape inference before we do it. But that would lead to another question of when: (a) in init feels bad to me (I never like to do comm ops in init); (b) at step() time we'd get stuck doing it on every step, which is too expensive, or else we'd need 'if initialized' logic to skip it, which in theory could desync.
kwen2501 left a comment:
LGTM! Thanks! Minor comments.
```python
if stage.is_first:
    next_stage_args = stage._prepare_forward_infra(
        self._n_microbatches, args, kwargs
    )
else:
    next_stage_args = stage._prepare_forward_infra(
        self._n_microbatches, next_stage_args, kwargs
    )
```
nit: is the assumption that stages on the same rank are consecutive?
Hmm, good question. Let me clarify with more comments in the code. The assumption should be this:
- The stages on a rank are always sorted, but not necessarily consecutive (e.g. they could be (0, 4, 8) or (3, 4)).
- IF the stages are consecutive (e.g. 3, 4), then they should communicate via next_stage_args.
- ELSE the stages should have None values for next_stage_args and should communicate via send/recv (see the sketch below).
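A sketch of that assumption applied to the setup loop quoted above (hypothetical shape; only the consecutive vs. non-consecutive branching is the point):

```python
# Assumes self._stages is sorted by stage_index, per the first bullet above.
next_stage_args = None
prev_index = None
for stage in self._stages:
    if stage.is_first:
        # The first stage gets the real args directly from the user.
        next_stage_args = stage._prepare_forward_infra(
            self._n_microbatches, args, kwargs
        )
    elif prev_index is not None and stage.stage_index == prev_index + 1:
        # Consecutive local stages hand shapes over via return value / args.
        next_stage_args = stage._prepare_forward_infra(
            self._n_microbatches, next_stage_args, kwargs
        )
    else:
        # Non-consecutive: shapes must arrive via recv from another rank.
        next_stage_args = stage._prepare_forward_infra(
            self._n_microbatches, None, kwargs
        )
    prev_index = stage.stage_index
```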
```python
assert args is not None, "Args may be an empty tuple but not None"
if (
    self.is_first
    or self.stage_index_to_group_rank[self.stage_index - 1] == self.group_rank
```
nit: comment?
```python
# pass their shape info via return value and function args rather than send/recv.
if (
    self.is_last
    or self.stage_index_to_group_rank[self.stage_index + 1] == self.group_rank
```
nit: comment?
H-Huang left a comment:
Awesome change!
Just a thought: if we wanted to support skip connections, we could leverage this and create a more general metadata format that includes shapes from previous ranks as well as previous stages.
```python
stage_idx,
n_stages,
self.device,
input_args=input_args,
```
Are there any tests left with input_args included so that path remains tested?
No, I thought about splitting it up that way, but since I marked it as deprecated I decided to go all in on covering the shape inference path instead. I can add one back in though, it's a good point.
I decided to add coverage for the case where we pass both input and output, since that's supported as an override to disable shape inference. I left the input_args-only case untested. I wonder if we can just move quickly to stop supporting that path? Do we need a deprecation cycle, or should we just flip it? I can put up another PR for this change and we can discuss there.
```python
# communicate meta outputs, not real outputs, for two reasons:
# 1 - it's faster (esp. since obj coll pickles tensor data!)
# 2 - avoid activating a cuda context for the src rank when unpickling on the recv end!
```
Just curious, can you elaborate on the second point? Why would a CUDA context be needed for the src rank's device if we were to send real tensor data?
I didn't actually confirm this, but my suspicion is that if we pickle a tensor on 'cuda:0', serialize it into bytes, and send it to rank 1, then when rank 1 unpickles it, it would reconstruct the original metadata and create a tensor on 'cuda:0', which is not what we want (and I assume would create a new context).
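To make the idea concrete, here is a sketch of converting outputs to meta tensors before sending them through an object collective (the helper name and exact send mechanism are assumptions; the PR's actual plumbing may differ):

```python
import torch
import torch.distributed as dist


def send_shape_metadata(outputs: tuple, dst: int) -> None:
    # Meta tensors keep shape/dtype/layout but carry no storage and no real
    # device, so unpickling on the receiver cannot recreate a 'cuda:0' tensor
    # (and thus cannot spin up a CUDA context for the sender's device).
    outputs_meta = tuple(t.to("meta") for t in outputs)
    dist.send_object_list([outputs_meta], dst=dst)
```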
Not sure I follow how the metadata format would work. Do we want to support skip connections? It seems a bit complicated for the overall stack to support this.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 1 check: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 4, 5, linux.g5.4xlarge.nvidia.gpu).
Stack from ghstack (oldest at bottom):
Performs shape inference at runtime using user-provided real tensors.
The current state as of this PR:
Currently, this does not add a barrier after shape inference, which essentially pipelines shape inference with the subsequent schedule action for that stage. If this complicates debugging, we could add a barrier (it comes at a cost, but only during the first step).
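A usage sketch of the two paths this PR leaves us with (the constructor shape follows the discussion above and may differ from the final API; `stage_module`, `stage_idx`, `n_stages`, `device`, `bs`, and `hidden` are placeholders assumed to be defined by the caller):

```python
import torch
from torch.distributed.pipelining import PipelineStage

# Inference path: omit input_args; shapes are inferred at runtime from the
# first real microbatch inputs (pipelined with the first schedule actions).
stage = PipelineStage(stage_module, stage_idx, n_stages, device)

# Override path: pass both input_args and output_args to skip shape
# inference entirely (bs/hidden are illustrative sizes).
stage = PipelineStage(
    stage_module,
    stage_idx,
    n_stages,
    device,
    input_args=(torch.empty(bs, hidden, device=device),),
    output_args=(torch.empty(bs, hidden, device=device),),
)
```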
Testing:
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @d4l3k @c-p-i-o