
Conversation

@zhxchen17
Contributor

@zhxchen17 zhxchen17 commented Feb 20, 2025

Package API for torch.compile

Following up PR #145381, we implement a new API for compiling models using the cpp wrapper and for saving/loading the compiled artifacts to disk.

A package is designed to be a per-torch.compile() object that lives with the compilation context. Each time a recompilation happens, it collects the compiled artifacts into a lookup table. When a new set of inputs is passed to the compiled callable, we first perform a lookup in the compile package, before entering the dynamo cache, and match entries by their serialized guards.
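Roughly, the lookup could be pictured like the sketch below (the class and method names here are illustrative only, not the actual ones in this PR):

```
from typing import Any, Callable, Dict, Optional

class _CompilePackageSketch:
    """Illustrative sketch of a per-torch.compile() package: it collects one
    compiled artifact per recompilation and serves lookups, keyed by the
    serialized guards, before the dynamo cache is consulted."""

    def __init__(self) -> None:
        # serialized guards -> compiled artifact (e.g. dynamo code + AOTI kernel)
        self._entries: Dict[str, Callable[..., Any]] = {}

    def record(self, serialized_guards: str, artifact: Callable[..., Any]) -> None:
        # called whenever a (re)compilation produces a new artifact
        self._entries[serialized_guards] = artifact

    def lookup(self, serialized_guards: str) -> Optional[Callable[..., Any]]:
        # a hit here means we can run the stored artifact without re-entering dynamo
        return self._entries.get(serialized_guards)
```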

API names are tentative, but the workflow roughly looks like the following:

```
def f(...): ...

compiled_f = torch.compile(f, package="my_dir/my_model")

compiled_f(*args)

compiled_f.save_package(prefix="/dir1")

...

compiled_f.load_package(prefix="/dir2")
```

Fixes #ISSUE_NUMBER

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Feb 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147528

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures, 3 Unrelated Failures

As of commit 54bc977 with merge base ae29f05:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines 92 to 100
```
        name = next(
            n for n in self.dynamo_code.co_names if n.startswith("__compiled_fn")
        )
        return types.FunctionType(self.dynamo_code, globals={name: self.aoti})(
            *args, **kwargs
        )
```
Contributor

This seems wrong.

Dynamo code might depend on other globals in addition to __compiled_fn0, both in user code and in generated code. I think we need to examine the co_names of the code object.

This will also be incorrect if you compile two functions in the same file since the second will be __compiled_fn1.

Contributor Author

This will also be incorrect if you compile two functions in the same file since the second will be __compiled_fn1

I'm not sure I follow this. If we fullgraph compile, shouldn't there be only one compiled fn mapped to the dynamo code? If we mean globally here, then I expect each compiled object to reference only one compiled function starting with "__compiled_fn"; that's why we filter the co_names on line 93.

I think we need to examine the co_names of the code object.

What do you mean by "examine the co_names"? Should we raise an error if we see extra global names in co_names?

Contributor Author

I guess an example would be helpful, and I'll be happy to add it to the tests.

Contributor

You can fullgraph compile two different functions in the same file (or two different shapes for the same functions).

For other globals, put a tensor in global scope and read from it.
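For example, a minimal repro along those lines might look like this (hypothetical test code, not part of the PR):

```
import torch

SCALE = torch.ones(3)  # global tensor read from inside the compiled region

@torch.compile(fullgraph=True)
def f(x):
    # the dynamo-generated code for f also references the global name "SCALE"
    return x * SCALE

@torch.compile(fullgraph=True)
def g(x):
    # second function compiled in the same file; per the comment above, its
    # compiled graph would be named "__compiled_fn1" rather than "__compiled_fn0"
    return x + 1

f(torch.randn(3))
g(torch.randn(3))
```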

@oulgen
Contributor

oulgen commented Feb 22, 2025

@zhxchen17 @jansel I have a bit of a meta point: can we NOT call the public API a cache? The term is overloaded, and people tend to come back saying "the cache is not working", but there's always ambiguity: is it the dynamo recompilation cache? Is it the inductor/AOTAutograd cache? And now, is it the sticky cache?

We already have torch.compiler.load_cache_artifacts and torch.compiler.save_cache_artifacts, which makes this even more confusing.

I feel like there would be good value in clear disambiguation.

@zhxchen17
Contributor Author

zhxchen17 commented Feb 23, 2025

@zhxchen17 @jansel I have a bit of a meta point: can we NOT call the public API a cache? The term is overloaded, and people tend to come back saying "the cache is not working", but there's always ambiguity: is it the dynamo recompilation cache? Is it the inductor/AOTAutograd cache? And now, is it the sticky cache?

We already have torch.compiler.load_cache_artifacts and torch.compiler.save_cache_artifacts, which makes this even more confusing.

I feel like there would be good value in clear disambiguation.

@oulgen I got your point. Do you like the name "persistent_artifacts" better than sticky_cache? I can switch the name to avoid the confusion here.

@oulgen
Contributor

oulgen commented Feb 27, 2025

@zhxchen17 @jansel I have a bit of a meta point: can we NOT call the public API a cache? The term is overloaded, and people tend to come back saying "the cache is not working", but there's always ambiguity: is it the dynamo recompilation cache? Is it the inductor/AOTAutograd cache? And now, is it the sticky cache?
We already have torch.compiler.load_cache_artifacts and torch.compiler.save_cache_artifacts, which makes this even more confusing.
I feel like there would be good value in clear disambiguation.

@oulgen I got your point. Do you like the name "persistent_artifacts" better than sticky_cache? I can switch the name to avoid the confusion here.

yes, that sounds a lot better, thanks

@zhxchen17 zhxchen17 force-pushed the zhxchen17/sticky_cache/0 branch from bed234e to 3418d0d Compare February 28, 2025 15:54
@zhxchen17 zhxchen17 marked this pull request as draft February 28, 2025 15:55
@zhxchen17 zhxchen17 force-pushed the zhxchen17/sticky_cache/0 branch from 3418d0d to 44ad78e Compare February 28, 2025 16:25

```
if aot_config.sticky_cache is not None:
    if any(
        info.mutation_type != MutationType.NOT_MUTATED
```
Contributor

Let me know if you want more review on the AOTAutograd bits (we might want to think more about how this interacts with the existing AOT warm cache today).

One comment is that input mutations are probably fine for sticky cache, as long as the input mutation is captured inside of the graph (the "bad" case is when the input mutation is forced to run outside of the graph, in an opaque AOTAutograd epilogue). Almost all input mutations are captured in-graph today.
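For example (hypothetical user code), the "fine" case described above is an input mutation that stays inside the compiled graph:

```
import torch

@torch.compile(fullgraph=True)
def step(buf: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    buf.add_(x)       # in-place mutation of an input, captured in-graph
    return buf.sum()

buf = torch.zeros(4)
step(buf, torch.ones(4))
```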

@zhxchen17 zhxchen17 force-pushed the zhxchen17/sticky_cache/0 branch 5 times, most recently from 697d678 to fc4aff3 Compare March 11, 2025 04:40
@zhxchen17 zhxchen17 force-pushed the zhxchen17/sticky_cache/0 branch 2 times, most recently from c73f996 to 542455a Compare March 27, 2025 02:30
@zhxchen17 zhxchen17 force-pushed the zhxchen17/sticky_cache/0 branch 2 times, most recently from 546f276 to 536abe3 Compare March 27, 2025 04:07
@zhxchen17 zhxchen17 changed the title Sticky cache API for torch.compile Package API for torch.compile Mar 27, 2025
@zhxchen17 zhxchen17 force-pushed the zhxchen17/sticky_cache/0 branch 3 times, most recently from 84ba2e9 to eb8b103 Compare March 27, 2025 14:48
@oulgen
Contributor

oulgen commented Mar 27, 2025

@zhxchen17 I think it might be good to discuss the @torch.compile(package="/package/path") API. I suspect having the package name on the compile API is OK for most OSS jobs, but it will make it difficult to apply on mast jobs, so it might be worthwhile moving this to the load/save_package calls.

Contributor

@jamesjwu jamesjwu left a comment

Really cool! Some initial questions about code organization, at least for the higher-level sections I understand, haha.


```
class _GraphCompile:
    """
    Stores the compiled artifacts per compiled FX graph. This includes:
```
Contributor

This is not exactly accurate, is it? It stores the compiled artifacts for N compiled FXGraphs, not just one, unless you mean the input function, which itself is not an FXGraph?


```
    def __init__(self, name: str) -> None:
        self.name = name
        self.forward_aoti: Optional[CompiledAOTI] = None
```
Contributor

In my head, what would make the most sense is that an AOTAutogradCacheEntry contains a forward/backward_aoti. I know this is a prototype so you probably didn't want to change AOTAutogradCacheEntry, but I think what you'd actually want is:

  • Refactor AOTAutogradCacheEntry into an interface with a forward/backward OutputCode + the aot autograd wrappers. You could make it generic on the OutputCode type, i.e. AOTAutogradCacheEntry[CompiledAOTI] vs. AOTAutogradCacheEntry[CompiledFXGraph], etc.

  • Split AOTAutogradCacheEntry into the cache path, where there's a python wrapper + FXGraphCache keys, vs. the package API, which uses AOTI artifacts.

Then here, you would just store self.aot_autograd_output : AOTAutogradCacheEntry[CompiledAOTI] (rough sketch below).
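Something along these lines, as a very rough sketch (the class name and fields below are illustrative assumptions, not the real AOTAutogradCacheEntry definition):

```
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

OutputCodeT = TypeVar("OutputCodeT")  # e.g. CompiledAOTI or a python-wrapper FX graph

@dataclass
class AOTAutogradEntrySketch(Generic[OutputCodeT]):
    # forward/backward output code, generic over the wrapper representation
    forward: OutputCodeT
    backward: Optional[OutputCodeT]  # None for inference-only graphs
    # ... plus the aot autograd wrappers / runtime metadata in the real entry

# The package path would then hold something like
#   self.aot_autograd_output: AOTAutogradEntrySketch[CompiledAOTI]
# while the cache path keeps using a python-wrapper-backed entry.
```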

```
    aot_config = pickle.load(f)

    def _(*args, **kwargs):
        raise RuntimeError("NYI")
```
Contributor

You're never gonna be able to serialize these functions, so they're not NYI, right? They're actually full forward/backward compilers (often inductor libraries). Instead I think you want to ignore these callables (or assume that they are always the same).
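As a rough sketch of that suggestion (the attribute names below are assumptions for illustration, not the exact AOTConfig fields):

```
import pickle
from types import SimpleNamespace

def strip_compilers(config: object) -> dict:
    # Drop the compiler callables before pickling: they are real forward/backward
    # compilers, so on load we would reattach the standard ones instead of
    # trying to deserialize them.
    state = dict(vars(config))
    for key in ("fw_compiler", "bw_compiler", "inference_compiler"):
        state.pop(key, None)
    return state

# Toy stand-in for an AOTConfig-like object, for illustration only.
cfg = SimpleNamespace(num_params_buffers=0, fw_compiler=lambda gm, ex: gm, bw_compiler=None)
payload = pickle.dumps(strip_compilers(cfg))
```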

```
# second half, and serializer should read data from the return
# value of the first half.
# For prototyping we just leave this as a separate class.
class GuardSerializer:
```
Contributor

What do we think about moving the guard serde into its own PR? That way you could test it separately and have it be generic, rather than specific to the package API.

```
            value_len, get_verbose_code_parts(code, guard)
        )

    def HASATTR(self, guard, metadata):
```
Contributor

What's the protocol for making sure people don't add new guards without adding this?

Wouldn't it be simpler to have each Guard carry its own ser/deser method? That way, when people add new Guard types, they would have to implement serde directly, instead of it living in a separate class.
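A minimal sketch of that alternative (hypothetical classes, not the actual dynamo Guard types):

```
from typing import Any, Dict

class SerializableGuard:
    """Each guard type implements its own serde, so adding a new guard type
    without serde support fails loudly at the definition site."""

    def serialize(self) -> Dict[str, Any]:
        raise NotImplementedError

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> "SerializableGuard":
        raise NotImplementedError

class HasAttrGuardSketch(SerializableGuard):
    def __init__(self, attr_source: str) -> None:
        self.attr_source = attr_source

    def serialize(self) -> Dict[str, Any]:
        return {"kind": "HASATTR", "attr_source": self.attr_source}

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> "HasAttrGuardSketch":
        return cls(data["attr_source"])
```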

```
load_paths = glob.glob(os.path.join(load_dir, "*.dynamo_code"))
self.assertEqual(len(load_paths), 0)

compiled_fn.save_package()
```
Contributor

Why not take the path here? Seems a bit weird to have the path be an arg to torch.compile (which won't do anything unless you call save).


```
precompiles = []
for i in range(len(load_paths)):
    precompile = torch._dynamo.compile_package._load_precompile(load_dir, i)
```
Contributor

How is this different from compiled_fn.load_package?

```
with self.assertRaisesRegex(
    RuntimeError, "Compile package is only supported .*fullgraph=True"
):
    torch.compile(f, dynamic=False, package=self.path())
```
Contributor

I'm a bit confused about how package= changes the semantics of torch.compile. Ideally there would be no semantic differences (in which case the path could just be an arg to save/load_package).

Contributor Author

@zhxchen17 zhxchen17 Mar 28, 2025

As discussed yesterday, there will always be corner cases where compile package doesn't work (e.g. unpicklable local classes), so we need something that makes torch.compile() start building the package only when asked to.

That being said, we can also provide a global boolean flag instead. (Might be easier for batch testing anyway.)

```
    return a

with self.assertRaisesRegex(
    RuntimeError, "Compile package is only supported .*fullgraph=True"
```
Contributor

I want to confirm you plan to remove the limitation, since a lot of key users of this have graph breaks.

Contributor Author

@jansel yeah, I intend to remove this limitation very soon


```
with self.assertRaisesRegex(
    NotImplementedError,
    "dynamic shape",
```
Contributor

This is planned, correct?

Contributor Author

yep

```
f = torch.compile(f, fullgraph=True, dynamic=False, package=self.path())
with self.assertRaisesRegex(
    NotImplementedError,
    "backward",
```
Contributor

This is planned, correct?

Contributor Author

yep, planned

```
from torch._dynamo.exc import InternalTorchDynamoError

with self.assertRaisesRegex(
    InternalTorchDynamoError,
```
Contributor

This seems like the wrong exception type.

```
self._test_load(f, args3, expected3, self.path())

# self._test_save(f, (torch.randn(4, 2), torch.randn(4, 3)), None, self.path(), None)
```

Contributor

Can we add a test for trying to load a saved package generated for function foo onto a different function bar?

  1. If the name/signature don't match, we should throw an error on load.
  2. We should also recursively hash the source code and assert the source codes match, with a flag to disable this check (sketched below).
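For the source check in item 2, a rough sketch of what such a helper could look like (hypothetical, not an existing torch API):

```
import hashlib
import inspect

def source_hash(fn) -> str:
    # Hash the function's source so a loaded package can assert it was built
    # for the same code; the real check would also recurse into callees, and
    # a load-time flag could disable it entirely.
    return hashlib.sha256(inspect.getsource(fn).encode("utf-8")).hexdigest()
```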

Comment on lines +2302 to +2303
self.apply_options({"cpp_wrapper": True})
self.apply_options({"aot_inductor.package": True})
Contributor

@jansel jansel Mar 27, 2025

IMO the initial version of this should use the Python wrapper code. I think this will be required (at least in the short term) to support:

  1. Training
  2. Dynamic shapes
  3. Graph breaks

I think the use case of single-graph, static-shape, inference is already well covered by export+AOTI and doesn't really need a new API.

We should be able to reuse our existing caching infra for this.

Contributor

I think if we refactor the package code with AOTAutogradCacheEntry properly, we can combine these designs nicely into a single framework:

  • Caching creates an AOTAutogradCacheEntry with Python wrapper cache entries
  • Precompile creates an AOTAutogradCacheEntry with either CPP wrapper or python wrapper cache entries, + a set of serialized dynamo guards.

The nice thing about this is, if we do it right, we also get regular caching support for cpp wrapper, which we actually don't have today.

Contributor Author

@jansel I'm all for training as well. I just disagree with this:

inference is already well covered by export+AOTI and doesn't really need a new API.

Inference is not well covered by export+AOTI, since the UX is still broken for compiling partial graphs (I wrote a post about this: https://fb.workplace.com/groups/257735836456307/permalink/802058152024070/).

I'd like to still take inference into account, and as @jamesjwu mentioned, we can leverage caching to support both cases.

Contributor

I was talking about "single-graph, static-shape, inference" (all three at the same time) being well covered.

@zhxchen17 zhxchen17 force-pushed the zhxchen17/sticky_cache/0 branch from eb8b103 to 54bc977 Compare March 28, 2025 18:49
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label May 27, 2025
@ezyang
Contributor

ezyang commented Jun 6, 2025

Is this still alive?

@zhxchen17
Contributor Author

Is this still alive?

@ezyang Moving to #155118.

Will close this for now.

@zhxchen17 zhxchen17 closed this Jun 9, 2025
@github-actions github-actions bot deleted the zhxchen17/sticky_cache/0 branch July 13, 2025 02:21
