
Conversation

@IvanKobzarev (Contributor) commented Oct 28, 2024

Stack from ghstack (oldest at bottom):

Reason:
Currently we perform multiple traversals over the tangents at runtime:

  • To check that their types and structure match what we guessed at tracing time
  • To coerce metadata
  • To coerce memory_format
  • To unwrap tensor subclasses (unwrap_tensor_subclass)

All of these traverse the tree of subclasses via __tensor_flatten__ calls.

Change:
Do everything in a single runtime traversal (including flattening).

Implementation details:

Add memory_format information to SubclassCreationMeta; for plain tensors, keep not only the (int) unwrapped_index but the memory_format as well.

Preparing memory_format is optional (controlled by with_memory_format=True).

Additionally, remove the unused subclass_utils.create_metadata_for_subclass, which has no usages inside torch and would otherwise need to be updated for the new logic.
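For illustration, a minimal sketch of the fused single-pass idea (PlainTensorMeta and SubclassMeta below are simplified stand-ins, not the exact AOTAutograd structures, and the coercion is reduced to a memory_format check):

from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

import torch


@dataclass
class PlainTensorMeta:
    unwrapped_idx: int
    memory_format: Optional[torch.memory_format] = None


@dataclass
class SubclassMeta:
    expected_type: type
    inner_metas: List[Union[PlainTensorMeta, "SubclassMeta"]] = field(default_factory=list)


def process_runtime_tangent_sketch(x, meta) -> Tuple[object, List[torch.Tensor]]:
    # Single pass: verify the type guessed at trace time, coerce memory_format,
    # and flatten, all in one traversal.
    if not isinstance(x, torch.Tensor):
        return x, []
    if isinstance(meta, PlainTensorMeta):
        if meta.memory_format is not None and not x.is_contiguous(
            memory_format=meta.memory_format
        ):
            x = x.contiguous(memory_format=meta.memory_format)
        return x, [x]
    assert type(x) is meta.expected_type, (
        f"runtime tangent type {type(x)} differs from traced {meta.expected_type}"
    )
    # Recurse into the inner tensors returned by __tensor_flatten__.
    attrs, _ctx = x.__tensor_flatten__()
    leaves: List[torch.Tensor] = []
    for attr, inner_meta in zip(attrs, meta.inner_metas):
        _, inner_leaves = process_runtime_tangent_sketch(getattr(x, attr), inner_meta)
        leaves.extend(inner_leaves)
    # The real implementation rebuilds the subclass if any inner tensor changed;
    # this sketch just returns the original wrapper plus its flattened leaves.
    return x, leaves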

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot bot commented Oct 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139068

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (6 Unrelated Failures)

As of commit f3db59e with merge base 87f1990:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

IvanKobzarev added a commit that referenced this pull request Oct 28, 2024
ghstack-source-id: 1fcf78c
Pull Request resolved: #139068
@IvanKobzarev IvanKobzarev added the topic: not user facing and ciflow/periodic labels Oct 28, 2024

@bdhirsh (Contributor) commented Oct 28, 2024

I see some test failures?


@IvanKobzarev (Contributor Author) commented Oct 29, 2024

I see some test failures?

Yes, checking. Just missing parentheses for skipIfTorchDynamo() :)
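(For context: skipIfTorchDynamo in torch.testing._internal.common_utils is a decorator factory, so it has to be called; the test name below is made up for illustration.)

from torch.testing._internal.common_utils import skipIfTorchDynamo

@skipIfTorchDynamo()  # correct: call the factory, which returns the actual decorator
def test_runtime_tangents_sketch(self):
    ...

# @skipIfTorchDynamo  # wrong: the test function itself becomes the factory's msg argument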


# Not checking equality of ref and x as Exception is expected

# Partially addresses https://github.com/pytorch/pytorch/issues/106457
@skipIfTorchDynamo()
Contributor:

it sounds like prior to this PR, this test would work properly under dynamo, but now it does not. Why?

Contributor:

Hmm, if the answer is that dynamo blows up when trying to run directly on the new custom schema objects that we branch on at runtime, then I agree a skip here seems fine (it is unnecessary to get dynamo working on that). But I'd like a comment next to this @skip explaining exactly what we are not supporting in dynamo.

Contributor Author:

There is an error in symbolic shapes guard verbose printing that appeared after the tangents processing change:

https://gist.github.com/IvanKobzarev/339f6b0b1465de56731cb6d6d14f2a9f

@unittest.skipIf(
not torch.distributed.is_available(), "test requires torch distributed"
)
@skipIfTorchDynamo()
Contributor:

nit: test_dtensor_compile.py is probably a better fit for this test:

(1) it's testing AsyncCollectiveTensor, which is more of a distributed concept

(2) then we won't need to worry about the skipIfTorchDynamo logic, since the tests in that file won't involve dynamo running on the AOTAutograd code.

Contributor Author:

Agree, moved to test_dtensor_compile

*,
count_symints: bool = True,
with_memory_format: bool = False,
) -> List[Union[int, SubclassCreationMeta]]:
Contributor:

Can you help me understand why we want to sometimes not include memory_format when creating subclass meta? If there is a good reason for doing it sometimes and not others, a comment explaining exactly when it is / is not necessary would be nice.

Contributor Author:

My main goal was to avoid the overhead of deducing memory_format.
This could be especially painful when called during tracing on FakeTensors with symbolic shapes: in my experience, memory format checks produce hairy symbolic shape guards on strides (divisibility, equal to 1, remainder equals 0, etc.).

We use create_subclass_meta for inputs and outputs (in collect_metadata_analysis), and I have not seen any usage of memory_format for inputs.

If we need memory_format for inputs and outputs too, we can make it non-optional.

Contributor:

oh that's fair - we don't need the memory format info for inputs. Can you just mention that in a comment?
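To make the trade-off concrete, here is a hypothetical sketch of an opt-in memory_format capture with the kind of explanatory comment requested above (not the actual create_subclass_meta code):

from typing import Optional

import torch


def maybe_suggest_memory_format(
    t: torch.Tensor, with_memory_format: bool
) -> Optional[torch.memory_format]:
    # memory_format is only needed for tangents; it is not used for inputs, and
    # deducing it on FakeTensors with symbolic shapes can introduce hairy stride
    # guards (divisibility, equal to 1, remainder equals 0), so it is opt-in.
    if not with_memory_format:
        return None
    if t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last
    return torch.contiguous_format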

(
AOTDispatchAutograd.coerce_runtime_tangent(
flat_processed_tangents = list(
itertools.chain.from_iterable(
@bdhirsh (Contributor) commented Oct 30, 2024:

have you had a chance to benchmark if the runtime overhead here nets out to being faster/slower than the original code? (I'd imagine that merging the looping over tangents into a single loop would be faster, although I'm also not sure how fast itertools.chain.from_iterable is).

Contributor Author:

I measured itertools.chain.from_iterable vs. sequential list.extend(); itertools.chain.from_iterable was insignificantly faster (< 1%).
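(A rough illustration of that comparison, not the exact benchmark; absolute numbers will vary with machine and list sizes.)

import itertools
import timeit

nested = [[i, i + 1, i + 2] for i in range(1000)]

def flatten_chain():
    return list(itertools.chain.from_iterable(nested))

def flatten_extend():
    out = []
    for leaves in nested:
        out.extend(leaves)
    return out

print("chain.from_iterable:", timeit.timeit(flatten_chain, number=10_000))
print("list.extend:        ", timeit.timeit(flatten_extend, number=10_000))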

Contributor Author:

Using an updated version of profiling PR #136478:

Processing runtime tangents for a recursive TwoTensor did not change (the difference is within measurement std):

average before: 76610ns
average after: 76800ns

This of course depends on how expensive the __tensor_flatten__ call is for the subclass; for TwoTensor it is cheap :)

def process_runtime_tangent(x, meta: Union[PlainTensorMeta, SubclassCreationMeta]):
if not isinstance(x, torch.Tensor):
return x
return x, [x]
Contributor:

I'm still trying to understand what the purpose of the second return argument of this function is. What do we need it for? (it looks like it's dropped in the outer-most call to process_runtime_tangents)

Contributor Author:

The current logic on tangents is:

tangents = all_args[TB, TE]
traverse_tangents_tree_to_check_type(tangents)
all_args = [traverse_subclass_tangents_coerce_metadata(all_args[i]) where i in [TB, TE]]
all_args = [traverse_subclass_tangents_coerce_memory_format(all_args[i]) where i in [TB, TE]]
all_args = traverse_subclass_unwrap(all_args)

We fuse all of the traversals that check/update into process_runtime_tangents, and we also fuse traverse_subclass_unwrap into it, doing the flattening at the same time as the checks/updates. The second element of the returned tuple holds the updated, flattened leaves for each tangent.

As a result we end up with only one subclass-tree traversal over the runtime tangents, using the second element of the tuple as the result of the unwrap:

processed_tangents = process_runtime_tangents(all_args[TB, TE])
processed_tangents_leaves = list(itertools.chain.from_iterable(pt[1] for pt in processed_tangents))
all_args = traverse_subclass_unwrap(all_args[:TB]) + processed_tangents_leaves + traverse_subclass_unwrap(all_args[TE+1:])

@bdhirsh (Contributor) left a comment:

looks mostly good - left a few comments


@pytorch-bot pytorch-bot bot added the oncall: distributed label Oct 31, 2024

return x
if is_traceable_wrapper_subclass(x):
runtime_meta = x.__tensor_flatten__()[1]
Contributor:

nit: I see we're calling __tensor_flatten__() twice, to get the metadata here and the inner keys later. If you think we can easily get away with a single call that seems better, but if not that's ok

Contributor Author:

Thanks. Yes, originally I thought we should call __tensor_flatten__ one more time after a potential coercion (e.g. a subclass type change). I will add a check: if x is unchanged, we do not need the extra __tensor_flatten__; only if a coercion happened do we call __tensor_flatten__ again.
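(A hedged sketch of that single-call idea; coerce below is a stand-in for the actual coercion helper, not the real AOTAutograd signature.)

import torch

def flatten_with_optional_recoerce(x: torch.Tensor, expected_meta, coerce):
    # Call __tensor_flatten__ once up front; only re-flatten if the coercion
    # replaced x with a different subclass instance.
    attrs, runtime_meta = x.__tensor_flatten__()
    coerced = coerce(x, runtime_meta, expected_meta)
    if coerced is not x:
        attrs, runtime_meta = coerced.__tensor_flatten__()
    return coerced, [getattr(coerced, attr) for attr in attrs]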


@IvanKobzarev (Contributor Author) commented:
@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Oct 31, 2024
@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Pull Request resolved: pytorch#139068
Approved by: https://github.com/bdhirsh
@github-actions github-actions bot deleted the gh/IvanKobzarev/80/head branch December 1, 2024 02:21
Esquains pushed a commit to Esquains/study1 that referenced this pull request Dec 15, 2024

Labels

ciflow/inductor
ciflow/periodic (Trigger jobs ran periodically on master (periodic.yml) on the PR)
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
release notes: AO frontend
topic: not user facing (topic category)


4 participants