[DTensor] Check if tracing for sharding propagation to handle unhashable keys #160798
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160798
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 900950b with merge base 5babb4d.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
        output_sharding = cast(
            OutputSharding, self.propagate_op_sharding(op_info.schema)
        )
    except TypeError:
Your PR description mentions less clear error messages in dynamo; can you explain how that relates to the changes? It seems like this change should just cause more cases to miss the cache, and there would be no associated error.
Relatedly, if we silently miss the cache for unhashable-type errors, we'll not get the signal we need to fix those cases. Do you have a plan for that?
Originally we had landed #156868 to purposely skip the cache during compile, since there was the possibility of symints in the op schema, and the existing has_symints check would miss them if the schema was nested (e.g. an argument is a list of tensors). The goal was to avoid the additional overhead of a nested check and instead just disable the cache on compile.
However, that led to #159590, where dynamo is disabled for autograd with dynamic shapes so _are_we_tracing was not sufficient, and #159601, where the statefulness of the _MaskPartial placement led to an issue with the propagate_op_sharding cache and _gen_transform_infos not being aligned, so the mask materialization would not happen in the correct object.
So the additional cache misses are intentional. If there are errors from the cached path, we should also expect to see them in the non-cached version; otherwise there is a bug in functools, or the function we are caching mutates its arguments, in which case we shouldn't be caching it at all. I do still feel a bit wary about try/except clauses in general, so we could additionally filter the except to only catch errors whose message includes "unhashable type", but that assumes the message will never change in functools.
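For concreteness, a minimal sketch of the try/except-around-the-cache idea being discussed here (illustrative only: the class is a stand-in rather than the actual DTensor sharding propagator, and the string match on functools' error message is exactly the fragile part called out above):

```python
from functools import lru_cache


class ShardingPropagatorSketch:
    def __init__(self) -> None:
        # functools' cache raises TypeError when the key is unhashable,
        # e.g. an op schema whose arguments include a list of tensors.
        self.propagate_op_sharding = lru_cache(None)(
            self.propagate_op_sharding_non_cached
        )

    def propagate_op_sharding_non_cached(self, op_schema):
        ...  # stand-in for the real sharding propagation

    def propagate(self, op_schema):
        try:
            return self.propagate_op_sharding(op_schema)
        except TypeError as e:
            # Only swallow cache-key failures; anything else is a real bug.
            # Assumes functools keeps "unhashable type" in its message.
            if "unhashable type" not in str(e):
                raise
            return self.propagate_op_sharding_non_cached(op_schema)
```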
Discussed further with @bdhirsh, and instead of the try/except, I changed it to go back to the compiled autograd flag, with the addition of the PythonDispatcher being invoked in compile_inner and also checked in _are_we_tracing.
torch/_dynamo/convert_frame.py (outdated)
            )
            stack.enter_context(CompileTimeInstructionCounter.record())
            return _compile_inner(code, one_graph, hooks, transform)
        with enable_python_dispatcher():
Yep, I would start out making this a bit lower blast radius - probably by moving it just around the code in dynamo that performs FakeTensor propagation.
Moved it to https://github.com/pytorch/pytorch/pull/160798/files#diff-9f0663783bcd93e948e0491ef61b48123bdc9977bcc632fd707da578df13bfa1R2147 and it looks like it's resolved the issues. Does this look appropriate?
torch/_dynamo/variables/builder.py (outdated)
            with enable_python_dispatcher():
                example_value = wrap_to_fake_tensor_and_record(
                    value, tx=self.tx, is_tensor=True, source=source
                )
So we'd need to cover every fake tensor tracing callsite with the python dispatcher? If so, those are spread around a bit. You probably want to move the enable inside of wrap_to_fake_tensor_and_record since it is used a lot, and another one is wrap_fake_exception.
Modified it to wrap the entirety of wrap_to_fake_tensor_and_record and the fn call in wrap_fake_exception.
…_we_tracing (#161334) Fixes an issue where the log softmax handler checked the tensor metadata cache without checking for tracing or symints. Probably best to merge this after #160798, but not strictly blocking. Pull Request resolved: #161334 Approved by: https://github.com/xmfan
        self.assertNotEqual(partial_placement1, partial_placement3)

    @with_comms
    def test_compile_embedding_redistribute(self):
nit: we could probably do even better, but right now most of our compile + dtensor related tests live in test_dtensor_compile.py. I would probably vote to put the test there instead of in this file, so we can keep most of the eager-only DTensor test files as being actually eager-only.
also, just confirming - does this test fail if you remove the enable_python_dispatcher bits?
Good point - I initially had correctness issues with the fake process group, but saw that TestDTensorCompileE2E had comms setup for the tests, so I moved it in there now.
This test is actually for the issue brought up in #159601, so it is not impacted by the enable_python_dispatcher changes. Adding the _are_we_tracing() check to redistribute.py is what makes this pass; roughly, that change looks like the sketch below.
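A minimal sketch of the redistribute-side fix just described, assuming the _are_we_tracing helper from torch.distributed._functional_collectives; the transform-info functions here are local stand-ins for the ones in torch/distributed/tensor/_redistribute.py, so only the cache-bypass-when-tracing branch is the point:

```python
from functools import lru_cache

from torch.distributed._functional_collectives import _are_we_tracing


def _gen_transform_infos_non_cached(src_spec, dst_spec):
    # Stand-in for the real transform-info computation.
    return [(src_spec, dst_spec)]


# Cached variant used on the eager hot path.
_gen_transform_infos = lru_cache(None)(_gen_transform_infos_non_cached)


def pick_transform_infos(src_spec, dst_spec):
    # Bypass the cache while tracing so stateful placements such as
    # _MaskPartial are regenerated per trace, keeping the sharding
    # propagation cache and the transform infos aligned.
    if _are_we_tracing():
        return _gen_transform_infos_non_cached(src_spec, dst_spec)
    return _gen_transform_infos(src_spec, dst_spec)
```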
    # If fake mode is turned on, we are almost definitely compiling/tracing.
    if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None:
        return True
    if torch._C._dispatch_tls_is_dispatch_key_included(
I think it was pretty non-obvious that we needed this for the dynamo case that you found. Do you mind adding some comments so someone reading this code later can understand why we need it?
For this kind of situation I like to add two comments that reference one another in the two different regions of code, so someone reading the code has some breadcrumbs:
e.g.
... in the region of dynamo where you enable_python_dispatcher:
# Note [enable_python_dispatcher in dynamo]
# Dynamo disables itself when it runs fake tensor prop, which means that tensor subclasses
# have no way to know (purely based off of global state) if they are currently being run under compile or not.
# We use enable_python_dispatcher mainly to tweak the DispatchKeyState so that subclass authors
# can check it to know if they are running in an eager context or not.
... in the _are_we_tracing function:
# See Note [enable_python_dispatcher in dynamo]
Added the comments, appreciate it!
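With the breadcrumb comments in place, the check ends up looking roughly like this sketch. The dispatch key checked is assumed to be DispatchKey.Python, the key that enable_python_dispatcher turns on; the truncated snippet above does not show the exact argument.

```python
import torch


def _are_we_tracing_sketch() -> bool:
    # If fake mode is turned on, we are almost definitely compiling/tracing.
    if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None:
        return True
    # See Note [enable_python_dispatcher in dynamo]: dynamo disables itself
    # while it runs fake tensor prop, so checking the dispatch key included
    # by enable_python_dispatcher is the only global signal subclasses have
    # that they are running under compile.
    if torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Python):
        return True
    return False
```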
torch/_dynamo/variables/builder.py (outdated)
                    symbolic_context=symbolic_context,
                )
            )
            with enable_python_dispatcher():
are there any functional changes to this file other than shifting everything here under the enable_python_dispatcher block? (if so, it would be easier to review if you can keep the rest of the code unchanged so the PR diff only shows up as an indent)
Yeah, it looks like the linter adds some of these changes. I'll check if there are smaller sections to enable the dispatcher in, as well as clean this up.
I figured it'd actually be much cleaner to use a decorator, so I added wrap_python_dispatch. Not sure if there's an existing naming precedent for such a decorator though, so let me know if it needs to be changed anyway.
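A sketch of what such a decorator could look like; the wrap_python_dispatch name comes from this comment, and later in the thread the scope is narrowed again, so treat this as illustrative rather than the final code:

```python
import functools

from torch._dispatch.python import enable_python_dispatcher


def wrap_python_dispatch(fn):
    """Run fn with the python dispatcher enabled."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with enable_python_dispatcher():
            return fn(*args, **kwargs)

    return wrapper
```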
torch/_dynamo/utils.py (outdated)
     try:
-        return fn()
+        with enable_python_dispatcher():
+            return fn()
One comment I have is that:
(1) this enables enable_python_dispatcher for FakeTensorProp in dynamo
(2) technically, the "bad" case we saw only happened in the dynamo code that converts a (real) graph input tensor into a FakeTensor (the call to fake_mode.from_tensor(...)).
Converting (real tensor) graph inputs into FakeTensors happens once for every input, while FakeTensorProp happens for every single tensor operation in the graph, so (1) is much more common than (2). We know that historically FakeTensorProp takes a meaningful amount of compile time (we introduced caching inside of FakeTensorProp a while ago to help). So there is some risk of enable_python_dispatcher making FakeTensorProp in dynamo (and thus compile times) a bit slower.
@bdhirsh I see, so wrap_fake_exception is used in FakeTensorProp? In that case is wrap_to_fake_tensor_and_record used for FakeTensorProp? I think I'm lacking some context to figure out where the right place for the dispatcher is.
Discussed offline - we found a way to only turn on the python dispatcher for calls to fake_mode.from_tensor, and not during all of fake tensor prop.
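In other words, the final scoping is roughly the following sketch; the real call site lives inside dynamo's input wrapping rather than at module level:

```python
import torch
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
real_input = torch.randn(4, 4)

# The python dispatcher is only active while a (real) graph input is
# converted into a FakeTensor, not during all of fake tensor prop.
with enable_python_dispatcher():
    fake_input = fake_mode.from_tensor(real_input)
```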
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
|
Previous merge failed due to a 505 error when downloading MNIST. Retrying.
|
@pytorchbot merge -r |
|
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
|
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
|
@pytorchbot merge |
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ble keys (pytorch#160798) Fixes pytorch#159590. This is similar to the reverted commit pytorch#156868, except it resolves an issue with two caches becoming misaligned, leading to incorrect objects for stateful placements (i.e. `_MaskPartial`) as in issue pytorch#159601. This adds little to no overhead in eager ([see past benchmarks](pytorch#156868 (comment))). This also handles cases such as pytorch#159590 where dynamo is disabled during tracing, by entering the Python Dispatcher ahead of the sharding propagation during compile. Tests are added/modified to handle these, and the list/tuple inputs with the cat op. Pull Request resolved: pytorch#160798 Approved by: https://github.com/bdhirsh
_are_we_tracing appears to consume over 10% of overall CPU cost in DTensor.detach, using essentially the benchmark from #160580.
|
@swolchok I'll give this a look again today. Is this in cumulative time, or just within
Correct, it's cumulative time. The breakdown looks roughly as follows: (disclaimer: the relevant denominator here might not be 100%, reading profiles is a bit of an art. The top-line number also appears to have moved slightly since the last time I looked.)
|
Unrelated to the above: tagging #163667, which properly wraps up cleanup.
|
@azahed98 btw, I see that #156868 (comment) considered try/except. Are the measurements there for the case where there are or are not symints caused by tracing? I would expect try/except to be the best choice for the non-tracing case, but I haven't actually had to look into the cost of Python try/except blocks when no exception is thrown.
Yeah, those measurements are for the no-exception case. In that case it's roughly comparable to an if statement. If an exception is raised, the try/except adds a couple orders of magnitude of extra overhead, but this would only happen during compile with dynamic shapes, so other than increasing compile time it's not the biggest problem. I'm open to swapping to try/except, but there was some initial pushback on it, so we should discuss further if we want to go that way.
Oh, that's disappointing; I had hoped that was the cost when exceptions were thrown, but that was too optimistic.
|
It could be worth double-checking with more robust profiling. Were the numbers you reported from cProfile?
|
No, I'm using Linux perf with PYTHON_PERF_JIT_SUPPORT and Python 3.13.
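As a rough, self-contained way to sanity-check the "comparable to an if statement" claim for the no-exception path from a few comments up, a quick timeit comparison (independent of the perf-based profiling discussed here) could look like:

```python
import timeit


def with_try(x=1):
    try:
        return x + 1
    except TypeError:
        return None


def with_if(x=1):
    # isinstance check standing in for an explicit hashability/tracing check
    if isinstance(x, int):
        return x + 1
    return None


# Numbers vary by CPU and Python version; the point is the same order of magnitude.
print("try/except (no raise):", timeit.timeit(with_try, number=1_000_000))
print("if check:             ", timeit.timeit(with_if, number=1_000_000))
```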
Fixes #159590
This is similar to the reverted commit #156868, except it resolves an issue with two caches becoming misaligned, leading to incorrect objects for stateful placements (i.e. _MaskPartial) as in issue #159601. This adds little to no overhead in eager (see past benchmarks). This also handles cases such as #159590 where dynamo is disabled during tracing, by entering the Python Dispatcher ahead of the sharding propagation during compile. Tests are added/modified to handle these, and the list/tuple inputs with the cat op.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela @xmfan