Package API for torch.compile #147528
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147528
Note: Links to docs will display an error until the docs builds have been completed.
❌ 8 New Failures, 3 Unrelated Failures as of commit 54bc977 with merge base ae29f05.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 48dfe0c to bed234e.
torch/_dynamo/sticky_cache.py (outdated)
```python
name = next(
    n for n in self.dynamo_code.co_names if n.startswith("__compiled_fn")
)
return types.FunctionType(self.dynamo_code, globals={name: self.aoti})(
    *args, **kwargs
)
```
This seems wrong.
Dynamo code might depend on other globals in addition to __compiled_fn0, both in user code and in generated code. I think we need to examine the co_names of the code object.
This will also be incorrect if you compile two functions in the same file since the second will be __compiled_fn1.
> This will also be incorrect if you compile two functions in the same file since the second will be __compiled_fn1

I'm not sure I follow this. If we fullgraph compile, shouldn't there be only one compiled fn mapped to the dynamo code? If you mean globally here, then I expect each compiled object to reference only one compiled function whose name starts with "__compiled_fn"; that's why we filter the co_names on line 93.

> I think we need to examine the co_names of the code object.

What do you mean by "examine the co_names"? Should we raise an error if we see extra global names in co_names?
I guess an example would be helpful, and I will be happy to add it to the tests.
You can fullgraph compile two different functions in the same file (or two different shapes for the same function).
For other globals, put a tensor in global scope and read from it.
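A minimal sketch of both cases (a hypothetical example, not code from this PR), assuming standard torch.compile behavior:

```python
import torch

SCALE = torch.randn(3)  # a tensor living in module-global scope

@torch.compile(fullgraph=True)
def f(x):
    # reads a global tensor, so the generated dynamo code references
    # more globals than just the compiled-graph entry point
    return x * SCALE

@torch.compile(fullgraph=True)
def g(x):
    # a second compiled function in the same file; its graph entry point
    # will typically get a different __compiled_fn<N> name
    return x + 1

f(torch.randn(3))
g(torch.randn(3))
```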
@zhxchen17 @jansel I have a bit of a meta point: can we NOT call the public API a cache? We already have term overload where people tend to come back saying "cache is not working", and there's always ambiguity: is it the dynamo recompilation cache? Is it the inductor/AOTAutograd cache? And now, is it the sticky cache? I feel like there will be good value in clear disambiguation.
@oulgen I get your point. Do you like the name "persistent_artifacts" better than sticky_cache? I can switch the name to avoid the confusion here.
Yes, that sounds a lot better, thanks.
Force-pushed from bed234e to 3418d0d.
Force-pushed from 3418d0d to 44ad78e.
torch/_functorch/aot_autograd.py (outdated)
```python
if aot_config.sticky_cache is not None:
    if any(
        info.mutation_type != MutationType.NOT_MUTATED
```
let me know if you want more review on the AOTAutograd bits (we might want to think more about how this interacts with the existing AOT warm cache today).
One comment is that input mutations are probably fine for sticky cache, as long as the input mutation is captured inside of the graph (the "bad" case is if the input mutation is forced to run outside of the graph, in an opaque AOTAutograd epilogue). Almost all input mutations at this point are captured in-graph today.
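For illustration, a hypothetical sketch of an input mutation that AOTAutograd can keep inside the graph (assuming default torch.compile settings):

```python
import torch

@torch.compile(fullgraph=True)
def step(x):
    x.add_(1)      # input mutation; AOTAutograd typically captures this in the graph
    return x * 2

buf = torch.randn(4)
out = step(buf)    # buf has been mutated in place
```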
Force-pushed from 697d678 to fc4aff3.
Force-pushed from c73f996 to 542455a.
Force-pushed from 546f276 to 536abe3.
Force-pushed from 84ba2e9 to eb8b103.
@zhxchen17 I think it might be good to discuss the …
Really cool! Some initial questions about code organization, at least in the higher level sections I understand haha
```python
class _GraphCompile:
    """
    Stores the compiled artifacts per compiled FX graph. This includes:
```
This is not exactly accurate, is it? It stores the compiled artifacts for N compiled FXGraphs, not just one, unless you mean the input function, which itself is not an FXGraph?
```python
def __init__(self, name: str) -> None:
    self.name = name
    self.forward_aoti: Optional[CompiledAOTI] = None
```
In my head, what would make most sense is that an AOTAutogradCacheEntry contains a forward/backward_aoti. I know this is a prototype, so you probably didn't want to change AOTAutogradCacheEntry, but I think what you'd actually want is:
- Refactor AOTAutogradCacheEntry into an interface with a forward/backward OutputCode + the AOT autograd wrappers. You could make it generic on the OutputCode type, i.e. AOTAutogradCacheEntry[CompiledAOTI] vs. AOTAutogradCacheEntry[CompiledFXGraph], etc.
- Split AOTAutogradCacheEntry into the cache path, where there's a Python wrapper + FXGraphCache keys, vs. the package API, which uses AOTI artifacts.
Then here, you would just store self.aot_autograd_output: AOTAutogradCacheEntry[CompiledAOTI].
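A rough sketch of what such a generic entry could look like (the class and field names here are illustrative assumptions, not the actual AOTAutogradCacheEntry definition):

```python
from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar

# OutputCodeT would be CompiledAOTI for the package/precompile path and the
# Python-wrapper FX graph type for the regular caching path.
OutputCodeT = TypeVar("OutputCodeT")

@dataclass
class GenericAOTAutogradCacheEntry(Generic[OutputCodeT]):
    forward: OutputCodeT             # compiled forward artifact
    backward: Optional[OutputCodeT]  # compiled backward artifact, if training
    # AOT autograd runtime wrappers that must be re-applied on load.
    runtime_wrappers: list = field(default_factory=list)
```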
```python
aot_config = pickle.load(f)

def _(*args, **kwargs):
    raise RuntimeError("NYI")
```
You're never gonna be able to serialize these functions, so they're not NYI, right? They're actually full forward/backward compilers (often inductor libraries). Instead, I think you want to ignore these callables (or assume that they are always the same).
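A sketch of that approach, using a hypothetical stand-in for the real aot_config rather than the PR's actual serialization code: drop the compiler callables on save and re-bind live ones on load.

```python
import pickle

class AOTConfigLike:
    """Hypothetical stand-in for the real aot_config object."""
    def __init__(self, fw_compiler, bw_compiler, num_params_buffers=0):
        self.fw_compiler = fw_compiler
        self.bw_compiler = bw_compiler
        self.num_params_buffers = num_params_buffers

def save_config(cfg, path):
    state = cfg.__dict__.copy()
    # The compiler callables (often inductor entry points) are not picklable,
    # so drop them instead of treating them as "not yet implemented".
    state.pop("fw_compiler")
    state.pop("bw_compiler")
    with open(path, "wb") as f:
        pickle.dump(state, f)

def load_config(path, fw_compiler, bw_compiler):
    with open(path, "rb") as f:
        state = pickle.load(f)
    # Re-bind the live compilers, assuming they are always the same.
    return AOTConfigLike(fw_compiler, bw_compiler, **state)

cfg = AOTConfigLike(fw_compiler=print, bw_compiler=print, num_params_buffers=2)
save_config(cfg, "/tmp/aot_config.pkl")
restored = load_config("/tmp/aot_config.pkl", fw_compiler=print, bw_compiler=print)
```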
```python
# second half, and serializer should read data from the return
# value of the first half.
# For prototyping we just leave this as a separate class.
class GuardSerializer:
```
What do we think about moving the guard serde into its own PR? That way you could test it separately, and have it be generic, rather than specific to the package API
```python
    value_len, get_verbose_code_parts(code, guard)
)

def HASATTR(self, guard, metadata):
```
What's the protocol for making sure people don't add new guards without adding this?
Would it not be simpler to have each Guard carry its own ser/deser method? That way, when people add new Guard types, they would have to implement serde directly, instead of having to find a separate class for them.
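A sketch of the per-guard alternative, with hypothetical guard classes rather than Dynamo's actual ones; the abstract methods make it impossible to add a guard type without also defining its serde:

```python
import json
from abc import ABC, abstractmethod

class Guard(ABC):
    # Every Guard subclass must decide how it serializes.
    @abstractmethod
    def serialize(self) -> dict: ...

    @classmethod
    @abstractmethod
    def deserialize(cls, data: dict) -> "Guard": ...

class HasAttrGuard(Guard):
    def __init__(self, source: str, attr: str) -> None:
        self.source, self.attr = source, attr

    def serialize(self) -> dict:
        return {"kind": "HASATTR", "source": self.source, "attr": self.attr}

    @classmethod
    def deserialize(cls, data: dict) -> "HasAttrGuard":
        return cls(data["source"], data["attr"])

# round-trip
blob = json.dumps(HasAttrGuard("L['x']", "shape").serialize())
guard = HasAttrGuard.deserialize(json.loads(blob))
```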
```python
load_paths = glob.glob(os.path.join(load_dir, "*.dynamo_code"))
self.assertEqual(len(load_paths), 0)

compiled_fn.save_package()
```
Why not take the path here? Seems a bit weird to have the path be an arg to torch.compile (which won't do anything unless you call save).
test/test_compile_package.py (outdated)
```python
precompiles = []
for i in range(len(load_paths)):
    precompile = torch._dynamo.compile_package._load_precompile(load_dir, i)
```
How is this different from compiled_fn.load_package?
```python
with self.assertRaisesRegex(
    RuntimeError, "Compile package is only supported .*fullgraph=True"
):
    torch.compile(f, dynamic=False, package=self.path())
```
I'm a bit confused about how package= changes the semantics of torch.compile? Ideally there would be no semantic differences (in which case the path could just be an arg to save/load_package).
As discussed yesterday, there will always be corner cases where compile package doesn't work (e.g. unpicklable local classes), so we need something to make torch.compile() start building the package only when asked to.
That being said, we can also provide a global boolean flag instead. (Might be easier for batch testing anyway.)
```python
    return a

with self.assertRaisesRegex(
    RuntimeError, "Compile package is only supported .*fullgraph=True"
```
I want to confirm you plan to remove the limitation, since a lot of key users of this have graph breaks.
@jansel yeah, I intend to remove this limitation very soon
test/test_compile_package.py (outdated)
```python
with self.assertRaisesRegex(
    NotImplementedError,
    "dynamic shape",
```
This is planned, correct?
yep
```python
f = torch.compile(f, fullgraph=True, dynamic=False, package=self.path())
with self.assertRaisesRegex(
    NotImplementedError,
    "backward",
```
This is planned, correct?
yep, planned
```python
from torch._dynamo.exc import InternalTorchDynamoError

with self.assertRaisesRegex(
    InternalTorchDynamoError,
```
This seems like the wrong exception type.
```python
self._test_load(f, args3, expected3, self.path())

# self._test_save(f, (torch.randn(4, 2), torch.randn(4, 3)), None, self.path(), None)
```
Can we add a test for trying to load a saved package generated for function foo on function bar?
- If the name/signature don't match, we should throw an error on load.
- We should also recursively hash the source code and assert the source codes match (with a flag to disable this check).
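A rough sketch of such a test against the package API proposed in this PR (the save/load path handling and the exact error type are assumptions):

```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests

def foo(x):
    return x + 1

def bar(x):
    return x * 2

class TestPackageMismatch(TestCase):
    def test_load_package_on_different_function(self):
        path = "/tmp/compile_package_demo"  # hypothetical package location
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=False, package=path)
        compiled_foo(torch.randn(4))
        compiled_foo.save_package()

        compiled_bar = torch.compile(bar, fullgraph=True, dynamic=False, package=path)
        # Loading artifacts produced for foo onto bar should fail, because the
        # function name/signature and (recursively hashed) source don't match.
        with self.assertRaises(RuntimeError):
            compiled_bar.load_package()

if __name__ == "__main__":
    run_tests()
```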
```python
self.apply_options({"cpp_wrapper": True})
self.apply_options({"aot_inductor.package": True})
```
IMO the initial version of this should use the Python wrapper code. I think this will be required (at least in the short term) to support:
- Training
- Dynamic shapes
- Graph breaks
I think the use case of single-graph, static-shape, inference is already well covered by export+AOTI and doesn't really need a new API.
We should be able to reuse our existing caching infra for this.
I think if we refactor the package code with AOTAutogradCacheEntry properly, we can combine these designs nicely into a single framework:
- Caching creates an AOTAutogradCacheEntry with Python wrapper cache entries
- Precompile creates an AOTAutogradCacheEntry with either CPP wrapper or python wrapper cache entries, + a set of serialized dynamo guards.
The nice thing about this is, if we do it right, we also get regular caching support for cpp wrapper, which we actually don't have today.
@jansel I'm all for training as well. I just disagree with this:
> inference is already well covered by export+AOTI and doesn't really need a new API.
Inference is not well covered by export+AOTI, since the UX is still broken for compiling partial graphs (I wrote a post about this: https://fb.workplace.com/groups/257735836456307/permalink/802058152024070/).
I'd like to still take inference into account, and as @jamesjwu mentioned, we can leverage caching to support both cases.
I was talking about "single-graph, static-shape, inference" (all three at the same time) being well covered.
Force-pushed from eb8b103 to 54bc977.
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Is this still alive?
Package API for torch.compile
Following up PR #145381, we implement
a new API for compiling models using the cpp wrapper, and save/load
compiled artifacts to disk.
Package is now designed to be a per-torch.compile() object living with
the compilation context. Each time a recompilation happens, it will collect
the compiled artifacts into a lookup table. When a new set of inputs is
passed to the compiled callable, before we enter the dynamo cache, we will
perform a lookup in compile package first, and match by the serialized guards.
API names are tentative but the workflow roughly looks like the following:
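```python
def f(...):
    ...

compiled_f = torch.compile(f, package="my_dir/my_model")
compiled_f(*args)
compiled_f.save_package(prefix="/dir1")

...

compiled_f.load_package(prefix="/dir2")
```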
Fixes #ISSUE_NUMBER
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov