[aotd] Fuse tangents subclasses runtime traversals #139068
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139068
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (6 unrelated failures) As of commit f3db59e with merge base 87f1990:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I see some test failures?
Yes, checking. Just a missing parenthesis for
# Not checking equality of ref and x as Exception is expected

# Partially addresses https://github.com/pytorch/pytorch/issues/106457
@skipIfTorchDynamo()
it sounds like prior to this PR, this test would work properly under dynamo, but now it does not. Why?
Hmm, if the answer is that dynamo blows up when trying to run directly on the new custom schema objects that we branch on at runtime, then I agree a skip here seems fine (it is unnecessary to get dynamo working on that). But I'd like a comment next to this @skipIfTorchDynamo explaining exactly what we are not supporting in dynamo.
There is an error in the symbolic shapes guard verbose printing that appeared after the tangents processing change:
test/functorch/test_aotdispatch.py
@unittest.skipIf(
    not torch.distributed.is_available(), "test requires torch distributed"
)
@skipIfTorchDynamo()
nit: test_dtensor_compile.py is probably a better fit for this test:
(1) it's testing AsyncCollectiveTensor, which is more of a distributed concept
(2) then we won't need to worry about the skipIfTorchDynamo logic, since the tests in that file won't involve dynamo running on the AOTAutograd code.
Agreed, moved to test_dtensor_compile.
    *,
    count_symints: bool = True,
    with_memory_format: bool = False,
) -> List[Union[int, SubclassCreationMeta]]:
Can you help me understand why we want to sometimes not include memory_format when creating subclass meta? If there is a good reason for doing it sometimes and not others, a comment explaining exactly when it is / is not necessary would be nice.
My main reasoning was to avoid adding the overhead of deducing memory_format.
It could also be especially painful to call during tracing on FakeTensors with symbolic shapes: in my experience, memory format checks produce hairy symbolic-shape checks on strides (divisibility, equal to 1, remainder equals 0, etc.).
We use create_subclass_meta for inputs and outputs (in collect_metadata_analysis), and I have not seen any usage of memory_format for inputs.
If we need memory format for inputs and outputs too, we can make it non-optional.
oh that's fair - we don't need the memory format info for inputs. Can you just mention that in a comment?
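For context, a minimal sketch of the kind of stride-dependent work that deducing/coercing memory_format entails (this is not the AOTAutograd code; coerce_memory_format is a made-up helper, and using torch._prims_common.suggest_memory_format for the deduction is an assumption):

```python
import torch
from torch._prims_common import suggest_memory_format


def coerce_memory_format(t: torch.Tensor, expected: torch.memory_format) -> torch.Tensor:
    # Checking contiguity inspects strides; on FakeTensors with symbolic shapes
    # these checks turn into guards on stride divisibility/equality, which is
    # the overhead the optional with_memory_format flag avoids for inputs.
    if t.is_contiguous(memory_format=expected):
        return t
    return t.contiguous(memory_format=expected)


# Example: a channels_last tangent coerced back to the format guessed at trace time.
x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
expected = suggest_memory_format(torch.randn(2, 3, 4, 5))  # torch.contiguous_format
y = coerce_memory_format(x, expected)
assert y.is_contiguous(memory_format=expected)
```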
-        (
-            AOTDispatchAutograd.coerce_runtime_tangent(
+        flat_processed_tangents = list(
+            itertools.chain.from_iterable(
have you had a chance to benchmark if the runtime overhead here nets out to being faster/slower than the original code? (I'd imagine that merging the looping over tangents into a single loop would be faster, although I'm also not sure how fast itertools.chain.from_iterable is).
I measured itertools.chain.from_iterable vs. a sequential list.extend(); itertools.chain.from_iterable was insignificantly faster (< 1%).
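A rough way to reproduce that micro-comparison (an illustrative sketch, not the harness actually used for the measurement; absolute numbers will vary by machine):

```python
# Flatten a list of lists with itertools.chain.from_iterable vs. a
# sequential list.extend loop, and compare wall time.
import itertools
import timeit

lists = [[object()] * 4 for _ in range(1_000)]


def flatten_chain() -> list:
    return list(itertools.chain.from_iterable(lists))


def flatten_extend() -> list:
    out: list = []
    for sub in lists:
        out.extend(sub)
    return out


print("chain.from_iterable:", timeit.timeit(flatten_chain, number=1_000))
print("list.extend        :", timeit.timeit(flatten_extend, number=1_000))
```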
Using an updated version of the profiling PR #136478, I found that the time to process runtime tangents for a recursive TwoTensor did not change (the difference is within the measurement std):
average before: 76610 ns
average after: 76800 ns
This of course depends on how expensive the __tensor_flatten__ call is for the subclass; for TwoTensor it is cheap :)
 def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta]):
     if not isinstance(x, torch.Tensor):
-        return x
+        return x, [x]
I'm still trying to understand what the purpose of the second return argument of this function is. What do we need it for? (it looks like it's dropped in the outer-most call to process_runtime_tangents)
The current logic on tangents is:

    tangents = all_args[TB:TE]
    traverse_tangents_tree_to_check_type(tangents)
    all_args = [traverse_subclass_tangents_coerce_metadata(all_args[i]) for i in [TB, TE]]
    all_args = [traverse_subclass_tangents_coerce_memory_format(all_args[i]) for i in [TB, TE]]
    all_args = traverse_subclass_unwrap(all_args)

We fuse all the traversals that check/update into process_runtime_tangents, and we also fuse traverse_subclass_unwrap into process_runtime_tangents, doing the flattening at the same time as the checks/updates. The second return value carries the updated flattened leaves for each tangent.
As a result we end up with only one subclass-tree traversal over the runtime tangents, using the second element of the tuple as the result of the unwrap (see the sketch after this snippet):

    processed_tangents = process_runtime_tangents(all_args[TB:TE])
    processed_tangents_leaves = list(itertools.chain.from_iterable(pt[1] for pt in processed_tangents))
    all_args = traverse_subclass_unwrap(all_args[:TB]) + processed_tangents_leaves + traverse_subclass_unwrap(all_args[TE+1:])
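A minimal, hypothetical sketch of what such a fused traversal looks like, simplified relative to the actual process_runtime_tangent in this PR (ExpectedMeta and coerce_leaf are made-up names; only __tensor_flatten__ and is_traceable_wrapper_subclass come from torch). Each node is checked/coerced and its flattened leaves are collected in the same recursion, so the subclass tree is walked only once:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


@dataclass
class ExpectedMeta:
    # What we guessed at trace time for this node of the tangent tree.
    memory_format: torch.memory_format = torch.contiguous_format
    children: Dict[str, "ExpectedMeta"] = field(default_factory=dict)


def coerce_leaf(x: torch.Tensor, meta: ExpectedMeta) -> torch.Tensor:
    # Coerce a plain-tensor tangent to the memory_format guessed at trace time.
    if not x.is_contiguous(memory_format=meta.memory_format):
        x = x.contiguous(memory_format=meta.memory_format)
    return x


def process_runtime_tangent(x: Any, meta: ExpectedMeta) -> Tuple[Any, List[Any]]:
    # One recursion does the check/coerce work AND collects flattened leaves.
    if not isinstance(x, torch.Tensor):
        return x, [x]
    if not is_traceable_wrapper_subclass(x):
        x = coerce_leaf(x, meta)
        return x, [x]
    # Subclass node: recurse into the inner tensors reported by __tensor_flatten__.
    # (Type/structure checks and subclass reconstruction are omitted here.)
    attrs, _ctx = x.__tensor_flatten__()
    leaves: List[Any] = []
    for attr in attrs:
        _, inner_leaves = process_runtime_tangent(getattr(x, attr), meta.children[attr])
        leaves.extend(inner_leaves)
    return x, leaves
```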
bdhirsh left a comment
looks mostly good - left a few comments
        return x
    if is_traceable_wrapper_subclass(x):
        runtime_meta = x.__tensor_flatten__()[1]
nit: I see we're calling __tensor_flatten__() twice, to get the metadata here and the inner keys later. If you think we can easily get away with a single call that seems better, but if not that's ok
Thanks. Yes, originally I thought we should call __tensor_flatten__ one more time after a potential coercion (e.g., a subclass type change). I will add a check: if x is unchanged, we do not need the extra __tensor_flatten__ call; if coercion happened, we call __tensor_flatten__ again.
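An illustrative sketch of that "flatten once, re-flatten only if coercion replaced the tensor" idea (flatten_once and the identity coercion are hypothetical, not the PR's code; TwoTensor is used only as a convenient traceable wrapper subclass for the example):

```python
from typing import Any, Callable, List, Tuple

import torch
from torch.testing._internal.two_tensor import TwoTensor


def flatten_once(
    x: torch.Tensor,
    coerce: Callable[[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, List[str], Any]:
    # Flatten up front, then re-flatten only when coercion actually replaced x.
    attrs, meta = x.__tensor_flatten__()
    coerced = coerce(x)
    if coerced is not x:
        # Coercion (e.g. a subclass type change) invalidates the cached result.
        attrs, meta = coerced.__tensor_flatten__()
    return coerced, attrs, meta


# With an identity coercion the single up-front __tensor_flatten__ is reused.
t = TwoTensor(torch.randn(3), torch.randn(3))
t2, attrs, meta = flatten_once(t, lambda s: s)
assert t2 is t and len(attrs) == 2
```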
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Pull Request resolved: pytorch#139068
Approved by: https://github.com/bdhirsh
ghstack-source-id: a3c1b5d
Pull Request resolved: pytorch/pytorch#139068
Stack from ghstack (oldest at bottom):
Reason:
Currently we have multiple traversals over tangents at runtime:
- To check that types and structure are identical to what we guessed during tracing time
- Coerce metadata
- Coerce memory_format
- Unwrap_tensor_subclass
All of them traverse the tree of subclasses in the tangents via __tensor_flatten__ calls.
Change:
Do everything in one traversal at runtime (including flattening).
Implementation details:
1. Add memory_format information inside SubclassCreationMeta; for plain tensors keep not only the (int) unwrapped_index but the memory_format too.
Preparing the memory_format is optional (controlled by with_memory_format=True).
2. Remove the unused subclass_utils.create_metadata_for_subclass, which has no usages inside torch and would otherwise require updating for the new logic.
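To make the metadata change concrete, here is a rough sketch of the shape this information takes (field names beyond unwrapped_idx and memory_format are hypothetical and do not match the exact torch._functorch definitions):

```python
from dataclasses import dataclass, field
from typing import List, Optional, Union

import torch


@dataclass
class PlainTensorMeta:
    unwrapped_idx: int                                   # position in the flat tangent list
    memory_format: Optional[torch.memory_format] = None  # only filled when requested


@dataclass
class SubclassCreationMeta:
    attrs: List[str] = field(default_factory=list)       # inner-attr names from __tensor_flatten__
    memory_format: Optional[torch.memory_format] = None
    # Per-inner-tensor metadata, mirroring the __tensor_flatten__ order.
    inner: List[Union[PlainTensorMeta, "SubclassCreationMeta"]] = field(default_factory=list)
```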
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o