
Conversation

@azahed98
Contributor

@azahed98 azahed98 commented Aug 16, 2025

Fixes #159590

This is similar to the reverted commit #156868, except it resolves an issue with two caches becoming misaligned, leading to incorrect objects for stateful placements (i.e. _MaskPartial) as in issue #159601. This adds little to no overhead in eager (see past benchmarks).

This also handles cases such as #159590, where dynamo is disabled during tracing, by entering the Python Dispatcher ahead of the sharding propagation during compile. Tests are added/modified to cover these cases, as well as list/tuple inputs with the cat op.
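
For context, a sketch of the gating described here. propagate_op_sharding and _are_we_tracing appear in the excerpts later in this thread; propagate_op_sharding_non_cached is assumed as the uncached counterpart, so treat the exact names and call site as illustrative:

    # Sketch only: skip the lru_cache-backed fast path while tracing, since a
    # traced schema can contain SymInts (unhashable) and the cached result can
    # point at stale stateful placements such as _MaskPartial.
    if _are_we_tracing():
        output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
    else:
        output_sharding = cast(
            OutputSharding, self.propagate_op_sharding(op_info.schema)
        )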

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @xmfan

@pytorch-bot

pytorch-bot bot commented Aug 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160798

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 900950b with merge base 5babb4d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Aug 16, 2025
@azahed98 azahed98 added the release notes: distributed (dtensor) label and removed the oncall: distributed and ciflow/inductor labels Aug 16, 2025
@pytorch-bot pytorch-bot bot added the ciflow/inductor and oncall: distributed labels Aug 18, 2025
@azahed98 azahed98 marked this pull request as ready for review August 18, 2025 07:01
output_sharding = cast(
OutputSharding, self.propagate_op_sharding(op_info.schema)
)
except TypeError:
Contributor

Your PR description mentions less clear error messages in dynamo; can you explain how that relates to the changes? It seems like this change should just cause more cases to miss the cache, and there would be no associated error.

Relatedly, if we silently miss the cache for unhashable-type errors, we won't get the signal we need to fix those cases. Do you have a plan for that?

Contributor Author

Originally we landed #156868 to purposely skip the cache during compile, since there was the possibility of symints in the op schema, and the existing has_symints check would miss them if the schema was nested (e.g. an argument is a list of tensors). The goal was to avoid the additional overhead of doing a nested check and instead just disable the cache during compile.

However, that led to #159590, where dynamo is disabled for autograd with dynamic shapes so _are_we_tracing was not sufficient, and #159601, where the statefulness of the _MaskPartial placement led to an issue with the propagate_op_sharding cache and _gen_transform_infos not being aligned, so the mask materialization would not happen in the correct object.

So the additional cache misses are intentional. If there are errors from the caches, we should also expect to see them in the non-cached version. Otherwise there is an error in functools, or the function we are caching mutates its arguments, in which case we shouldn't be using caching for it. I do still feel a bit wary about try/except clauses in general, so we could additionally filter the except clause to only catch errors whose message includes "unhashable type", but that assumes the message will never change in functools.
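
For reference, the failure being caught here is standard functools behavior: lru_cache hashes its arguments, so any unhashable value in the key (a list, or a symbolic SymInt per the linked issue) raises TypeError before the wrapped function runs. A minimal, self-contained illustration, not the DTensor code itself:

    import functools

    @functools.lru_cache(None)
    def cached_identity(key):
        return key

    cached_identity((1, 2))        # hashable args: cached normally
    try:
        cached_identity(([1], 2))  # a tuple containing a list is unhashable
    except TypeError as e:
        print(e)                   # unhashable type: 'list'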

Contributor Author

Discussed further with @bdhirsh and, instead of the try/except, I changed it to go back to the compiled autograd flag, with the addition of the PythonDispatcher being entered in compile_inner and a corresponding check in _are_we_tracing.

@azahed98 azahed98 force-pushed the feat/sharding_prop_compile branch from aa44dcd to bddd899 on August 19, 2025 20:24
@azahed98 azahed98 force-pushed the feat/sharding_prop_compile branch from bddd899 to 262c424 on August 19, 2025 20:27
)
stack.enter_context(CompileTimeInstructionCounter.record())
return _compile_inner(code, one_graph, hooks, transform)
with enable_python_dispatcher():
Contributor

Yep, I would start out making this a bit lower blast radius - probably by moving it to just around the code in dynamo that performs FakeTensor propagation.

@azahed98 azahed98 changed the title from "[DTensor] Use try/except for sharding propagation to handle unhashable keys" to "[DTensor] Check if tracing for sharding propagation to handle unhashable keys" Aug 22, 2025
Comment on lines 2147 to 2150
with enable_python_dispatcher():
example_value = wrap_to_fake_tensor_and_record(
value, tx=self.tx, is_tensor=True, source=source
)
Member

So we'd need to cover every fake tensor tracing call site with the python dispatcher? If so, those are spread around a bit. You probably want to move the enable inside of wrap_to_fake_tensor_and_record, as it is used a lot, and another one is wrap_fake_exception.

Contributor Author

Modified it to wrap the entirety of wrap_to_fake_tensor_and_record and the fn call in wrap_fake_exception.
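
Roughly the shape of that change, as a sketch rather than the exact diff (the real dynamo helper also translates fake-tensor exceptions, which is omitted here):

    from torch._dispatch.python import enable_python_dispatcher

    def wrap_fake_exception(fn):
        # Sketch: run the fake-tensor-producing callable under the python
        # dispatcher so subclass code (e.g. DTensor) can tell it is being traced.
        with enable_python_dispatcher():
            return fn()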

@azahed98 azahed98 force-pushed the feat/sharding_prop_compile branch from 9a3ebcf to 6cb0630 on August 25, 2025 23:56
pytorchmergebot pushed a commit that referenced this pull request Aug 26, 2025
…_we_tracing (#161334)

Fixes an issue where the log softmax handler checked the tensor metadata cache without checking for tracing or symints.

Probably best to merge this after #160798, but not strictly blocking.

Pull Request resolved: #161334
Approved by: https://github.com/xmfan
self.assertNotEqual(partial_placement1, partial_placement3)

@with_comms
def test_compile_embedding_redistribute(self):
Contributor

nit: we could probably do even better, but right now most of our compile + dtensor related tests live in test_dtensor_compile.py. I would probably vote to put the test there instead of in this file, so we can keep most of the eager-only DTensor test files as being actually eager-only.

Contributor

also, just confirming - does this test fail if you remove the enable_python_dispatcher bits?

Contributor Author

Good point - I initially had correctness issues with the fake process group, but saw that TestDTensorCompileE2E had comms set up for the tests, so I moved it there now.

This test is actually for the issue brought up in #159601, so it is not impacted by the enable_python_dispatcher changes. Adding the _are_we_tracing() check to redistribute.py is what makes this pass.
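
A sketch of the redistribute.py gating being described, assuming a cached/uncached helper pair named _gen_transform_infos / _gen_transform_infos_non_cached; the argument names here are placeholders:

    # Sketch only: while tracing, bypass the cached transform-info lookup so a
    # stateful placement like _MaskPartial is not shared across unrelated calls
    # and the mask materializes on the object the sharding propagation returned.
    if _are_we_tracing():
        transform_infos = _gen_transform_infos_non_cached(src_spec, dst_spec)
    else:
        transform_infos = _gen_transform_infos(src_spec, dst_spec)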

# If fake mode is turned on, we are almost definitely compiling/tracing.
if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None:
return True
if torch._C._dispatch_tls_is_dispatch_key_included(
Contributor

I think it was pretty non-obvious that we needed this for the dynamo case that you found. Do you mind adding some comments so someone reading this code later can understand why we need it?

For this kind of situation I like to add two comments that reference one another in the two different regions of code, so someone reading the code has some breadcrumbs:

Contributor

e.g.

....in the region of dynamo where you enable_python_dispatcher
# Note [enable_python_dispatcher in dynamo]
# Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses
# have no way to know (purely based off of global state) if they are currently being run under compile or not.
# we use enable_python_dispatcher mainly to tweak the DispatchKeyState so that subclass authors
# can check it to know if they are running in an eager context or not

... in the _are_we_tracing function
# See Note [enable_python_dispatcher in dynamo]

Contributor Author

Added the comments, appreciate it!

symbolic_context=symbolic_context,
)
)
with enable_python_dispatcher():
Contributor

are there any functional changes to this file other than shifting everything here under the enable_python_dispatcher block? (if so, it would be easier to review if you can keep the rest of the code unchanged so the PR diff only shows up as an indent)

Contributor Author

Yeah, it looks like the linter adds some of these changes. I'll check if there are smaller sections in which to enable the dispatcher, as well as clean this up.

Contributor Author

I figured it'd actually be much cleaner to use a decorator, so I added wrap_python_dispatch. Not sure if there's an existing naming precedent for such a decorator though, so let me know if it needs to be changed up anyway.

@azahed98 azahed98 force-pushed the feat/sharding_prop_compile branch from acb96a7 to 13f3905 on August 28, 2025 04:53
try:
return fn()
with enable_python_dispatcher():
return fn()
Contributor

One comment I have is that:

(1) this enables enable_python_dispatcher for FakeTensorProp in dynamo

(2) technically, the "bad" case we saw only happened in the dynamo code that converted a graph input tensor into a real tensor. (the call to fake_mode.from_tensor(...))

Converting (real tensor) graph inputs into FakeTensors happens once for every input, while FakeTensorProp happens for every single tensor operation in the graph, so (1) is much more common than (2). We know that historically, FakeTensorProp takes a meaningful amount of compile time (we introduced caching inside of FakeTensorProp a while ago to help). So there is some risk of enable_python_dispatcher making FakeTensorProp in dynamo (and thus compile times) a bit slower.

Contributor Author

@bdhirsh I see, so wrap_fake_exception is used in FakeTensorProp? In that case is wrap_to_fake_tensor_and_record used for FakeTensorProp? I think I'm lacking some context to figure out where the right place for the dispatcher is.

Contributor

Discussed offline - we found a way to only turn on the python dispatcher for calls to fake_mode.from_tensor, and not during all of fake tensor prop.
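
Roughly what that narrower scoping looks like as a sketch; the variable names mirror the excerpts above, and FakeTensorMode.from_tensor does accept source/symbolic_context keywords, but the exact call site is illustrative:

    from torch._dispatch.python import enable_python_dispatcher

    # Sketch: only the conversion of real graph inputs into FakeTensors runs
    # under the python dispatcher, not every op during fake tensor propagation.
    with enable_python_dispatcher():
        example_value = fake_mode.from_tensor(
            value, source=source, symbolic_context=symbolic_context
        )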

@azahed98 azahed98 force-pushed the feat/sharding_prop_compile branch from 2b96f30 to ea8f054 on September 5, 2025 05:45
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job

Failing merge rule: Core Maintainers

@azahed98
Contributor Author

azahed98 commented Sep 9, 2025

Previous merge failed due to a 505 error when downloading MNIST. Retrying.

@azahed98
Contributor Author

azahed98 commented Sep 9, 2025

@pytorchbot merge -r

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Tried to rebase and push PR #160798, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@azahed98
Contributor Author

azahed98 commented Sep 9, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…_we_tracing (pytorch#161334)

Pull Request resolved: pytorch#161334
Approved by: https://github.com/xmfan
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…ble keys (pytorch#160798)

Pull Request resolved: pytorch#160798
Approved by: https://github.com/bdhirsh
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
…ble keys (pytorch#160798)

Pull Request resolved: pytorch#160798
Approved by: https://github.com/bdhirsh
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
…ble keys (pytorch#160798)

Pull Request resolved: pytorch#160798
Approved by: https://github.com/bdhirsh
@swolchok
Contributor

This adds little to no overhead in eager (#156868 (comment)).

_are_we_tracing appears to consume over 10% of overall CPU cost in DTensor.detach, using essentially the benchmark from #160580

@azahed98
Contributor Author

@swolchok I'll give this a look again today. Is this in cumulative time, or just within _are_we_tracing and not deeper down the stack? The logic in _are_we_tracing itself is really simple (we could probably micro-optimize by avoiding if statements, but that's relatively minor), so I'm guessing most of this time is spent in checking the dispatch modes.

@swolchok
Contributor

swolchok commented Sep 23, 2025

@swolchok I'll give this a look again today. Is this in cumulative time, or just within _are_we_tracing and not deeper down the stack? The logic in _are_we_tracing itself is really simple (we could probably micro-optimize by avoiding if statements, but that's relatively minor), so I'm guessing most of this time is spent in checking the dispatch modes.

Correct, it's cumulative time. The breakdown looks roughly as follows (disclaimer: the relevant denominator here might not be 100%; reading profiles is a bit of an art. The top-line number also appears to have moved slightly since the last time I looked):

8.6% _are_we_tracing
  0.7% torch::autograd::get_dispatch_mode
  0.7% some other dispatch thing that is mostly pybind11 overhead
  0.44% PyObject_GetAttr
  5.90% get_proxy_mode
    0.5% torch::autograd::get_dispatch_mode again
    4.36% _get_dispatch_mode_pre_dispatch
      2.78% pybind enum comparison being slow (partially fixable)
      0.24% _ModeStackStateForPreDispatch.get
      0.16% mode_stack_state_for_pre_dispatch

@azahed98
Contributor Author

azahed98 commented Sep 23, 2025

Unrelated to the above: tagging #163667, which properly wraps up the cleanup of has_symints, which this PR removed the need for. Commenting here for future reference so I know where to clean up unused fields.

@swolchok
Contributor

@azahed98 BTW, I see that #156868 (comment) considered try/except. Are the measurements there for the case where there are or are not symints caused by tracing? I would expect try/except to be the best choice for the non-tracing case, but I haven't actually had to look into the cost of Python try/except blocks when no exception is thrown.

@azahed98
Contributor Author

are the measurements there for the case where there are or are not symints caused by tracing?

Yeah, those measurements are for the no-exception case. In that case it's roughly comparable to an if statement. If an exception is raised, the try/except adds a couple orders of magnitude of extra overhead, but this would only happen during compile with dynamic shapes, so other than increasing compile time it's not the biggest problem.

I'm open to swapping to try/except, but there was some initial pushback on it, so we should discuss further if we want to go that way.
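
For anyone who wants to reproduce the comparison, a rough standalone micro-benchmark of the no-exception vs. exception-raised paths (numbers will vary by Python version and machine; this is illustrative only, not the DTensor code path):

    import timeit

    setup = "d = {'hit': 1}"
    no_exc = "try:\n    d['hit']\nexcept KeyError:\n    pass"
    raised = "try:\n    d['miss']\nexcept KeyError:\n    pass"

    # Compare try/except when no exception is raised vs. raised on every iteration.
    print("no exception:", timeit.timeit(no_exc, setup=setup, number=1_000_000))
    print("exception raised:", timeit.timeit(raised, setup=setup, number=1_000_000))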

@swolchok
Contributor

Yeah those measurements are for the no exception case. In that case it's roughly comparable to an if statement. If an exception is raised, the try/except adds a couple orders of magnitude extra overhead,

Oh, that's disappointing; I had hoped that was the cost when exceptions were thrown, but that was too optimistic.

@azahed98
Contributor Author

It could be worth double-checking with more robust profiling. Were the numbers you reported from cProfile?

@swolchok
Contributor

No, I'm using Linux perf with PYTHON_PERF_JIT_SUPPORT and python 3.13.

dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
…ble keys (pytorch#160798)

Pull Request resolved: pytorch#160798
Approved by: https://github.com/bdhirsh

Labels

ciflow/inductor, ciflow/trunk, Merged, module: compiled autograd, compiled_autograd, module: dynamo, oncall: distributed, release notes: distributed (dtensor)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DTensor Compile w/ Dynamic Shapes Autograd - Unhashable SymInt in sharding propagation when inputs have requires_grad=True

6 participants