[aotd] coerce_same_metadata_as_tangent with expected_type for e.g.AsyncCollectiveTensor #139095
Conversation
…ncCollectiveTensor [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139095
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit f4d09da with merge base 5f266b5.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
if is_subclass and not is_subclass_meta:
    # Unexpected subclass, during tracing we guessed it was a plain Tensor
    if hasattr(x, "__coerce_same_metadata_as_tangent__"):
        x = x.__coerce_same_metadata_as_tangent__(None, torch.Tensor)
```
hmm. Two things:
(1) Right now this branch is hardcoded for the case where we expected a subclass tangent but got a plain tensor tangent. But we specifically updated the coerce API to be more generic: it can technically allow a subclass to convert to any other subclass type, if it is able to handle it. It seems to me like if we are going with that more general API, we should properly handle that here: don't hardcode torch.Tensor, just directly pass in type(x).
Since we're potentially being BC-breaking (I think only DTensor uses this API, although a few people have forked DTensor out of tree over time), it might also be better to only pass the type argument in when the types are different.
(2) you added this call as a completely new one, on top of the existing call below (x = x.__coerce_same_metadata_as_tangent__(meta.meta)). It seems better to consolidate them into a single call?
…for e.g. AsyncCollectiveTensor"

Based on discussion here: #138731

Introducing the ability for a subclass to implement type conversion to `expected_type`:

```python
def __coerce_same_metadata_as_tangent__(
    self, expected_metadata: Any, expected_type: Optional[Type] = None
):
```

Here `expected_type=None` means the subclass's own class is expected. E.g. for `DTensor` we may find a tangent `AsyncCollectiveTensor` where we expected `Tensor` - in this case the method will be called with `expected_type=Tensor` at runtime.

Adding an implementation to `AsyncCollectiveTensor` that just triggers `wait()`.

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
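As a hedged illustration of the protocol described above, here is a minimal toy sketch (plain-Python stand-ins; `PlainValue` and `WrappedValue` are invented names, not PyTorch types) of how a subclass might implement the hook, including the `expected_type=None` semantics:

```python
from typing import Any, Optional, Type

class PlainValue:
    """Toy stand-in for a plain torch.Tensor."""
    def __init__(self, data):
        self.data = data

class WrappedValue(PlainValue):
    """Toy stand-in for a tensor subclass such as AsyncCollectiveTensor."""
    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):
        # expected_type=None: the subclass's own type is expected; only
        # the metadata would need adjusting.
        if expected_type is None:
            return self
        if expected_type is PlainValue:
            # Unwrap to the plain type (the real AsyncCollectiveTensor
            # would trigger wait() here).
            return PlainValue(self.data)
        # None signals that this coercion is not supported.
        return None

w = WrappedValue(42)
plain = w.__coerce_same_metadata_as_tangent__(None, PlainValue)
```

Returning `None` for an unsupported target type lets the caller decide how to report the failure, which matters for the error-message discussion below in the thread.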
torch/distributed/tensor/_api.py
Outdated
```diff
-def __coerce_same_metadata_as_tangent__(self, flatten_spec):
+def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None):
+    assert expected_type is None
```
Nit: this can be user facing, if the user ends up running their code in such a way that e.g. the expected tangents are DTensors but the actual tangents are plain tensors. So we should make sure the error message if this assert fails is very clear. (include the actual/expected types and metadata)
Or alternatively, we could have this function return None if it is unable to properly coerce, and let AOTAutograd raise the error itself if it sees a None return
Yes, with the current logic, if it returns None we will raise an error with all the details: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1486
I think we should also document the logic of coercing by metadata and type, the meaning of None, etc.
hmm, this code as written (the assert above) still seems like it will regress our error message?
Prior to this PR, if we expected a DTensor tangent but we got a plain tensor tangent, we would get this nice error message: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1505
Seems like you want to return None here?
Yes, better to not assert and just return None here, for a user-friendly error message.
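The return-None convention agreed on here might be sketched as follows (a toy illustration with invented names; the real check lives in AOTAutograd's runtime wrappers, and the exact message text is hypothetical):

```python
def coerce_or_raise(tangent, expected_metadata, expected_type):
    # The hook returns None when it cannot coerce; the caller raises with
    # both the actual and expected types plus the metadata, as suggested
    # in the review, instead of the subclass asserting internally.
    result = tangent.__coerce_same_metadata_as_tangent__(
        expected_metadata, expected_type
    )
    if result is None:
        raise RuntimeError(
            f"Cannot coerce runtime tangent of type {type(tangent).__name__} "
            f"to expected type {expected_type.__name__} "
            f"(expected metadata: {expected_metadata!r})"
        )
    return result

class Box:
    """Toy subclass stand-in that can only unwrap to int."""
    def __init__(self, inner):
        self.inner = inner

    def __coerce_same_metadata_as_tangent__(self, expected_metadata, expected_type=None):
        if expected_type is int:
            return self.inner
        return None
```

This keeps all user-facing error formatting in one place in the caller rather than scattered across subclass implementations.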
```python
raise RuntimeError("Not implemented")

# ...

t = self.trigger_wait()
while isinstance(t, AsyncCollectiveTensor):
```
nit: we shouldn't need this while loop either (I think we will be in a very weird place if there are ever nested AsyncCollectiveTensors, so it seems pointless to defensively program for it)
```python
def maybe_coerce_to_memory_format(t, memory_format):
    if not t.is_contiguous(memory_format=meta.memory_format):
        return t.contiguous(memory_format=meta.memory_format)
    return t
```
superficially from reading this function, it's not clear what the return type is supposed to be? (here you are returning a single tensor t, while lower down you return x, [x])
maybe_coerce_to_memory_format is only for memory_format processing of one argument.
The whole function process_runtime_tangent returns a tuple so that flattening (the 2nd item in the tuple) happens in the same traversal as processing, i.e. only one traversal in total.
So the return type is Tuple[ChangedRawItem, FlattenedChangedItems].
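That single-traversal shape can be sketched in miniature (a toy over plain lists; `ChangedRawItem`/`FlattenedChangedItems` are the commenter's informal type names, and this stand-in processes nothing, it only shows the process-and-flatten-together return contract):

```python
from typing import List, Tuple

def process_runtime_tangent(x) -> Tuple[object, List[object]]:
    # Returns (processed item, flattened leaves of that item) so the
    # caller gets both results from one recursive traversal.
    if isinstance(x, list):
        processed, flat = [], []
        for item in x:
            p, f = process_runtime_tangent(item)
            processed.append(p)
            flat.extend(f)
        return processed, flat
    # Leaf: the processed item and its one-element flattening.
    return x, [x]

out, leaves = process_runtime_tangent([1, [2, 3]])
```

Doing the flattening during processing avoids a second full traversal of the (possibly nested) tangent structure.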
```python
if not x.is_contiguous(memory_format=meta.memory_format):
    x = x.contiguous(memory_format=meta.memory_format)
if is_subclass and not is_subclass_meta:
```
Ok, concretely, I think:
(1) instead of special-casing the is_subclass case, we should just check if the type of the two tensors is different, and if so then unconditionally call coerce(metadata or None, type(actual))
(2) if the coerce() function returns None, the subclass has indicated that it cannot perform the coercion, and so we can raise the old error message (with even more useful information in the error than the subclass could have given, like exactly which tangent in the tangent list we are up to)
Ok. Fused the logic in one coercion and updated PR.
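The concrete plan above (gate on a type mismatch, one consolidated coerce call, raise on None with the tangent's position) might look roughly like this toy sketch (invented names, plain-Python stand-ins; the hypothetical message is much terser than PyTorch's real one):

```python
def process_tangent(x, expected_type, expected_metadata, idx):
    # (1) No special-casing of subclasses: coerce whenever the runtime
    # type differs from the type we traced with.
    if type(x) is expected_type:
        return x
    coerced = x.__coerce_same_metadata_as_tangent__(expected_metadata, expected_type)
    if coerced is None:
        # (2) The subclass declined the coercion; raise, noting exactly
        # which tangent in the tangent list failed.
        raise RuntimeError(f"could not coerce tangent at position {idx}")
    return coerced

class Plain:
    pass

class Wrapped(Plain):
    """Toy subclass that knows how to unwrap itself to Plain."""
    def __coerce_same_metadata_as_tangent__(self, expected_metadata, expected_type=None):
        return Plain() if expected_type is Plain else None
```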
```python
from torch.utils._python_dispatch import return_and_correct_aliasing

class WrapSC(torch.Tensor):
```
nit: up to you, but I find WrapperSubclass a bit clearer than WrapperSC
```python
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
    assert meta is None
    a = inner_tensors["a"]
    if type(a) is torch.Tensor:
```
I guess the purpose of this check is that you only want to run the assertions below if we are at runtime, not trace time.
This isn't very robust though, since it will fail if the inner tensor a is itself another subclass.
You probably want to use is_fake(a), to tell if you are in the middle of tracing?
Honestly, I just copied this from TwoTensor :)
```python
    self, expected_metadata: Any, expected_type: Optional[Type] = None
):
    if expected_type is torch.Tensor:
        return self.a
```
tbh this code as written is a bit confusing: the idea of this function is that it is supposed to enforce that the return type is the same as expected_type. You might want to add an assert type(self.a) == expected_type here?
Thanks, yes, this will be more general to be able to wrap subclasses and coerce to them.
bdhirsh left a comment:
lgtm
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ncCollectiveTensor (pytorch#139095)

Pull Request resolved: pytorch#139095
Approved by: https://github.com/bdhirsh
…ncCollectiveTensor

ghstack-source-id: c38ae0f
Pull Request resolved: pytorch/pytorch#139095
Stack from ghstack (oldest at bottom):
Based on discussion here: #138731
Introducing the ability for a subclass to implement type conversion to `expected_type`.

Here `expected_type=None` means the subclass's own class is expected.

E.g. for `DTensor` we may find a tangent `AsyncCollectiveTensor` where we expected `Tensor` - in this case the method will be called with `expected_type=Tensor` at runtime.

Adding an implementation to `AsyncCollectiveTensor` that just triggers `wait()`.
wait().cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o