Conversation

@IvanKobzarev
Contributor

@IvanKobzarev IvanKobzarev commented Oct 28, 2024

Stack from ghstack (oldest at bottom):

Based on discussion here: #138731

Introduces the ability for a subclass to implement type conversion to an `expected_type`.

    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):

Here `expected_type=None` means the subclass's own type is expected.

E.g. for `DTensor` we may find an `AsyncCollectiveTensor` tangent where we expected a plain `Tensor`; in this case the method will be called at runtime with `expected_type=Tensor`.

Adds an implementation to `AsyncCollectiveTensor` that just triggers `wait()`.
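A minimal, torch-free sketch of the hook this PR introduces. `PlainTensor` and `AsyncTensor` are hypothetical stand-ins for `torch.Tensor` and `AsyncCollectiveTensor`, not the real classes; only the shape of the protocol is shown:

```python
from typing import Any, Optional, Type


class PlainTensor:
    """Stand-in for torch.Tensor (illustration only)."""


class AsyncTensor(PlainTensor):
    """Stand-in for AsyncCollectiveTensor."""

    def __init__(self, inner: PlainTensor):
        self.inner = inner

    def wait(self) -> PlainTensor:
        # The real AsyncCollectiveTensor blocks on the pending collective here.
        return self.inner

    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):
        if expected_type is PlainTensor:
            return self.wait()  # coerce to the plain type by waiting
        return None  # cannot coerce; the caller raises a detailed error
```

Returning `None` signals that the subclass cannot perform the requested coercion, leaving the error reporting to the caller.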

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Oct 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139095

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f4d09da with merge base 5f266b5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Oct 28, 2024
@IvanKobzarev IvanKobzarev added the topic: not user facing label Oct 28, 2024
@IvanKobzarev IvanKobzarev requested a review from awgu October 28, 2024 18:38
if is_subclass and not is_subclass_meta:
    # Unexpected subclass; during tracing we guessed it was a plain Tensor
    if hasattr(x, "__coerce_same_metadata_as_tangent__"):
        x = x.__coerce_same_metadata_as_tangent__(None, torch.Tensor)
Contributor
hmm. Two things:

(1) Right now this branch is hardcoded for the case where we expected a subclass tangent but got a plain tensor tangent. But we specifically updated the coerce API to be more generic: it can technically allow a subclass to convert to any other subclass type, if they are able to handle it. It seems to me like if we are going with that more general API, we should properly handle that here: don't hardcode torch.Tensor, just directly pass in type(x).

Since we're potentially being BC-breaking (I think only DTensor uses this API, although a few people have forked DTensor out of tree over time), it might also be safer to only pass the type argument in when the types are different.

(2) You added this call as a completely new one, on top of the existing call below (`x = x.__coerce_same_metadata_as_tangent__(meta.meta)`). It seems better to consolidate them into a single call?
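The suggested consolidation might look roughly like the following torch-free sketch. `coerce_runtime_tangent`, `Plain`, and `AsyncCT` are hypothetical names for illustration, not the real AOTAutograd code:

```python
class Plain:
    """Stand-in for torch.Tensor (illustration only)."""


class AsyncCT(Plain):
    """Stand-in for AsyncCollectiveTensor."""

    def __init__(self, inner):
        self.inner = inner

    def __coerce_same_metadata_as_tangent__(self, expected_metadata, expected_type=None):
        if expected_type is Plain:
            return self.inner  # the real ACT would wait() here
        return None  # cannot coerce


def coerce_runtime_tangent(x, expected_type, expected_metadata=None):
    # Hypothetical consolidated helper: a single call site that passes
    # the expected type only when the runtime type actually differs.
    if type(x) is expected_type:
        return x  # metadata-only coercion elided in this sketch
    hook = getattr(x, "__coerce_same_metadata_as_tangent__", None)
    out = hook(expected_metadata, expected_type) if hook is not None else None
    if out is None or type(out) is not expected_type:
        raise RuntimeError(
            f"tangent type mismatch: expected {expected_type.__name__}, "
            f"got {type(x).__name__}, and the subclass could not coerce"
        )
    return out
```

A `None` return from the hook keeps error reporting in the caller, which can include the actual/expected types and metadata in the message.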

IvanKobzarev added a commit that referenced this pull request Oct 28, 2024
…ncCollectiveTensor

ghstack-source-id: 91f570a
Pull Request resolved: #139095

-    def __coerce_same_metadata_as_tangent__(self, flatten_spec):
+    def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None):
+        assert expected_type is None
Contributor
Nit: this can be user facing, if the user ends up running their code in such a way that e.g. the expected tangents are DTensors but the actual tangents are plain tensors. So we should make sure the error message if this assert fails is very clear. (include the actual/expected types and metadata)

Contributor

Or alternatively, we could have this function return None if it is unable to properly coerce, and let AOTAutograd raise the error itself if it sees a None return

Contributor Author

Yes, with current logic if it returns None we will raise an error with all details. https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1486
I think we should also document the logic of coercing by metadata and type, the meaning of None etc.

Contributor

hmm, this code as written (the assert above) still seems like it will regress our error message?

Prior to this PR, if we expected a DTensor tangent but we got a plain tensor tangent, we would get this nice error message: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1505

Seems like you want to return None here?

Contributor Author

Yes, better to not assert and just return None here for user-friendly error message.

raise RuntimeError("Not implemented")

t = self.trigger_wait()
while isinstance(t, AsyncCollectiveTensor):
Contributor

nit: we shouldn't need this while loop either (I think we will be in a very weird place if there are ever nested AsyncCollectiveTensors, so it seems pointless to defensively program for it)

def maybe_coerce_to_memory_format(t, memory_format):
    if not t.is_contiguous(memory_format=meta.memory_format):
        return t.contiguous(memory_format=meta.memory_format)
    return t
Contributor

superficially from reading this function, it's not clear what the return type is supposed to be? (here you are returning a single tensor t, while lower down you return x, [x])

Contributor Author

maybe_coerce_to_memory_format is only for memory_format processing of a single argument.

The whole function process_runtime_tangent returns a tuple so that flattening (the 2nd item of the tuple) happens in the same traversal as processing, i.e. only one traversal in total.
So the return type is Tuple[ChangedRawItem, FlattenedChangedItems].
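The one-traversal shape described above can be sketched generically (torch-free; `x * 2` is a stand-in for the real per-tangent processing, and the names are illustrative, not the actual AOTAutograd code):

```python
from typing import Any, List, Tuple


def process_and_flatten(x: Any) -> Tuple[Any, List[Any]]:
    """Process a nested structure and collect flattened leaves in ONE pass."""
    if isinstance(x, list):
        processed, flat = [], []
        for item in x:
            p, f = process_and_flatten(item)
            processed.append(p)
            flat.extend(f)
        return processed, flat
    # Leaf: apply the stand-in processing and emit it as a flattened leaf.
    processed_leaf = x * 2
    return processed_leaf, [processed_leaf]
```

Returning both the processed structure and its flattened leaves avoids a second traversal just for flattening.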


if not x.is_contiguous(memory_format=meta.memory_format):
    x = x.contiguous(memory_format=meta.memory_format)
if is_subclass and not is_subclass_meta:
Contributor

Ok, concretely, I think:

(1) instead of special-casing the is_subclass case, we should just check if the type of the two tensors is different, and if so then unconditionally call coerce(metadata or None, type(actual))

(2) if the coerce() function returns None, the subclass has indicated that it cannot perform the coercion, and so we can raise the old error message (with even more useful information in the error than the subclass could have given, like exactly which tangent in the tangent list we are up to)

Contributor Author

Ok. Fused the logic in one coercion and updated PR.

IvanKobzarev added a commit that referenced this pull request Oct 31, 2024
…ncCollectiveTensor

ghstack-source-id: e157699
Pull Request resolved: #139095
IvanKobzarev added a commit that referenced this pull request Nov 1, 2024
…ncCollectiveTensor

ghstack-source-id: f3b0273
Pull Request resolved: #139095
@IvanKobzarev IvanKobzarev requested a review from bdhirsh November 1, 2024 14:20
from torch.utils._python_dispatch import return_and_correct_aliasing


class WrapSC(torch.Tensor):
Contributor

nit: up to you, but I find WrapperSubclass a bit clearer than WrapperSC

def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
    assert meta is None
    a = inner_tensors["a"]
    if type(a) is torch.Tensor:
Contributor

I guess the purpose of this check is that you only want to run the assertions below if we are at runtime, not trace time.

This isn't very robust though, since it will fail if the inner tensor a is itself another subclass.

You probably want to use is_fake(a), to tell if you are in the middle of tracing?

Contributor Author

Honestly, I just copied this from TwoTensor :)

    self, expected_metadata: Any, expected_type: Optional[Type] = None
):
    if expected_type is torch.Tensor:
        return self.a
Contributor

tbh this code as written is a bit confusing: the idea of this function is that it is supposed to enforce that the return type is the same as expected_type. You might want to add an assert type(self.a) == expected_type here?

Contributor Author

Thanks, yes, that will be more general: it allows wrapping subclasses and coercing to them.
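A torch-free sketch of the reviewer's suggestion: the hook should only return the inner value when it actually has the requested type, rather than hardcoding a plain tensor. `Wrapper` is a hypothetical stand-in, not the test's real WrapSC class:

```python
from typing import Any, Optional, Type


class Wrapper:
    """Stand-in wrapper subclass holding one inner value `a`."""

    def __init__(self, a):
        self.a = a

    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):
        if expected_type is None:
            return self  # same subclass expected; nothing to change here
        if type(self.a) is expected_type:
            # Works for any inner type, not just a hardcoded plain tensor.
            return self.a
        return None  # cannot produce the requested type
```

Checking `type(self.a) is expected_type` makes the hook general enough to wrap other subclasses and coerce to them.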

IvanKobzarev added a commit that referenced this pull request Nov 6, 2024
…ncCollectiveTensor

ghstack-source-id: c6e0c58
Pull Request resolved: #139095
@IvanKobzarev IvanKobzarev requested a review from bdhirsh November 6, 2024 13:30
@bdhirsh bdhirsh (Contributor) left a comment

lgtm

@IvanKobzarev
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Nov 7, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…ncCollectiveTensor (pytorch#139095)

Pull Request resolved: pytorch#139095
Approved by: https://github.com/bdhirsh
@github-actions github-actions bot deleted the gh/IvanKobzarev/81/head branch December 8, 2024 02:17
Esquains pushed a commit to Esquains/study1 that referenced this pull request Dec 15, 2024

Labels

ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, topic: not user facing

4 participants