Conversation

@anijain2305 anijain2305 requested a review from zou3519 as a code owner March 13, 2025 00:12

pytorch-bot bot commented Mar 13, 2025

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149087

✅ You can merge normally! (3 Unrelated Failures)

As of commit 4386a6c with merge base ce54c43:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

anijain2305 added a commit that referenced this pull request Mar 13, 2025
ghstack-source-id: 9c558f4
Pull Request resolved: #149087
@anijain2305 anijain2305 added the topic: not user facing label Mar 18, 2025
anijain2305 added a commit that referenced this pull request Mar 18, 2025
ghstack-source-id: d8cff23
Pull Request resolved: #149087
registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {}


def register_hop_fake(hop, fn=None):

Contributor:

nit: call this just register_fake? the fqn already has HOP in it (torch._higher_order_ops.register_fake)
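
For context, a minimal sketch of what such a registration helper could look like, using the suggested register_fake name (an illustration, not the exact code in the diff; the dict is keyed on HigherOrderOperator here, whereas the snippet above annotates it as OpOverload):

from typing import Callable, Optional

import torch

# Maps a HigherOrderOperator to its fake (meta) implementation for FakeTensorMode.
registered_hop_fake_fns: dict[torch._ops.HigherOrderOperator, Callable] = {}


def register_fake(hop: torch._ops.HigherOrderOperator, fn: Optional[Callable] = None):
    # Register fn as the fake implementation of hop. Usable directly,
    # register_fake(hop, fn), or as a decorator, @register_fake(hop).
    def register(inner_fn: Callable) -> Callable:
        registered_hop_fake_fns[hop] = inner_fn
        return inner_fn

    if fn is None:
        return register  # decorator form
    return register(fn)  # direct-call form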


def __hash__(self):
    return id(self.subgraph)

Contributor:

nit: Can you add some sort of repr so we know what we're dealing with when we're debugging?

Contributor Author:

Added.
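
For illustration, here is roughly what a wrapper with an identity-based hash and a debugging repr could look like; the class name SubgraphWrapper and its attributes are hypothetical stand-ins for the wrapper in the diff:

import torch


class SubgraphWrapper:
    # Hypothetical stand-in for the subgraph wrapper discussed above.

    def __init__(self, subgraph: torch.fx.GraphModule):
        self.subgraph = subgraph

    def __hash__(self):
        # Two wrappers hash the same only if they wrap the very same subgraph
        # object, which is what the cache key relies on.
        return id(self.subgraph)

    def __repr__(self):
        # A short, readable summary for debugging cache hits and misses.
        num_nodes = len(list(self.subgraph.graph.nodes))
        return f"SubgraphWrapper(subgraph={id(self.subgraph):#x}, nodes={num_nodes})"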

# For debugging / testing: Validate that the output synthesized
# from the cache matches the output created by normal dispatch.
self._crosscheck_cache_output(output, func, types, args, kwargs)
with disable_fake_tensor_cache(self):

Contributor:

btw, when is self.cache_crosscheck_enabled set to True? I assume it is set to True somewhere in testing which is why this change was necessary

Contributor Author:

Yes, in test_fake_tensor.py with

torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
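
As a sketch of how a test might exercise this (assuming the companion fake_tensor_cache_enabled flag and the standard torch._dynamo.config.patch helper; the flags must be set before the mode is constructed, since they are read in __init__):

import torch
import torch._dynamo
from torch._subclasses.fake_tensor import FakeTensorMode

with torch._dynamo.config.patch(
    fake_tensor_cache_enabled=True,
    fake_tensor_cache_crosscheck_enabled=True,
):
    with FakeTensorMode():
        x = torch.randn(4, 4)  # a FakeTensor under the active mode
        y1 = x + x
        y2 = x + x  # second call can hit the dispatch cache and gets crosschecked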

# caching implementation, e.g., data dependent ops or ops that modify
# the inputs.
from torch._higher_order_ops.utils import registered_hop_fake_fns

Contributor:

type of func is now wrong, need to update it to be a union

Contributor Author:

So, I tried this earlier and it causes a bunch of mypy failures, because mypy then expects func to have tags and other attributes all over the place, which are not present on HigherOrderOperator.

Contributor:

yeah don't worry about it

Comment on lines +1459 to +1460
isinstance(func, torch._ops.HigherOrderOperator)
and func in registered_hop_fake_fns

Contributor:

I'm not sure this is entirely correct: if the result of the HOP has data-dependent output shape or dynamic output shape, then we need to bail out (

if torch.Tag.data_dependent_output in func.tags:
    raise _BypassDispatchCache("data dependent output")
if torch.Tag.dynamic_output_shape in func.tags:
    raise _BypassDispatchCache("dynamic output shape")
)

Contributor Author:

Now, I remember why I needed a validator function. Ah.

Contributor Author:

Adding new logic that goes through the subgraph nodes and checks each of them.
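
A rough sketch of that kind of subgraph validation; names are illustrative and the bypass exception is a local stand-in for the real one:

import torch
from torch.fx import GraphModule


class _BypassDispatchCache(Exception):
    # Stand-in for the bypass exception used by the FakeTensor dispatch cache.
    pass


def _validate_subgraph_for_caching(gm: GraphModule) -> None:
    # Walk every op invoked by the subgraph and refuse to cache if any of them
    # is ineligible (the same tags the per-op path already checks).
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        tags = getattr(node.target, "tags", ())
        if torch.Tag.data_dependent_output in tags:
            raise _BypassDispatchCache("data dependent output in subgraph")
        if torch.Tag.dynamic_output_shape in tags:
            raise _BypassDispatchCache("dynamic output shape in subgraph")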

Comment on lines 1536 to 1542
elif isinstance(arg, torch.fx.GraphModule):
    # This is used for invoke_subgraph where id(graph_module) allows
    # us to cache fake outputs
    result.append(type(arg))
    result.append(id(arg))
elif isinstance(arg, FunctionalizeCtxWrapper):
    result.append(hash(arg))

Contributor:

umm, does id(arg) assume that the GraphModule stays alive forever? (What if the GraphModule gets deallocated and another one gets allocated in its stead?)

We might need some cache invalidation mechanism via weakref.finalize

Contributor Author:

Doing it here - #149667
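
For reference, the kind of weakref.finalize-based invalidation being discussed could look roughly like this; the cache and key shape here are hypothetical, and the actual change is in the linked PR:

import weakref

import torch.fx

# Hypothetical cache keyed by tuples that embed id(graph_module), standing in
# for the real FakeTensor dispatch cache.
_fake_output_cache: dict[tuple, object] = {}


def _register_invalidation(gm: torch.fx.GraphModule) -> None:
    gm_id = id(gm)

    def _evict(dead_id=gm_id):
        # Runs once the GraphModule has been garbage collected: drop any entry
        # whose key mentions its id, so a new object that happens to reuse the
        # same id cannot alias a stale cache entry.
        for key in [k for k in _fake_output_cache if dead_id in k]:
            del _fake_output_cache[key]

    # The callback captures only gm_id, and finalize holds gm weakly,
    # so registering invalidation does not extend the GraphModule's lifetime.
    weakref.finalize(gm, _evict)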

@zou3519 zou3519 (Contributor) left a comment:

I think we forgot to discuss the recursive case (we did talk about it half a year ago, and I am remembering it now): what happens if invoke_subgraph has a subgraph where there is an operator that is not eligible for FakeTensor caching?

We shouldn't allow that invoke_subgraph to be cached. There's probably an efficient strategy for checking this: during FakeTensorProp for invoke_subgraph, run FakeTensorProp on the subgraph first, and only if that didn't hit any ineligible operators do we say the invoke_subgraph can be cached.
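
A sketch of that recursive eligibility check, descending into nested subgraphs reached via get_attr nodes (an illustration of the strategy, not the PR's implementation):

import torch
from torch.fx import GraphModule


def _subgraph_is_cacheable(gm: GraphModule) -> bool:
    # Bottom-up check: an invoke_subgraph is cacheable only if every op it
    # (transitively) contains is itself eligible for FakeTensor caching.
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            try:
                nested = gm.get_submodule(node.target)
            except AttributeError:
                continue  # a parameter/buffer, not a nested graph
            if isinstance(nested, GraphModule) and not _subgraph_is_cacheable(nested):
                return False
        elif node.op == "call_function":
            tags = getattr(node.target, "tags", ())
            if (
                torch.Tag.data_dependent_output in tags
                or torch.Tag.dynamic_output_shape in tags
            ):
                return False
    return True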

@anijain2305 anijain2305 requested a review from zou3519 March 20, 2025 21:13

@zou3519 zou3519 (Contributor) left a comment:

tests failing

@zou3519 zou3519 (Contributor) left a comment:

LGTM

@pytorchmergebot (Collaborator) commented:

Starting merge as part of PR stack under #150036

Divigroup-RAP pushed a commit to Divigroup-RAP/PYTORCH that referenced this pull request Apr 22, 2025
@github-actions github-actions bot deleted the gh/anijain2305/700/head branch May 2, 2025 02:16