Conversation

@zhxchen17 (Contributor) commented on Jan 22, 2025

Summary:
Design doc: https://docs.google.com/document/d/1Z15cBBPjoZ7gH00TSgCdgaYko7a7Br-ERd3_hA-g2IU/edit?usp=sharing

This diff introduces a stateful API to enable a global mode that forces Inductor to use AOTI as a backend. Unlike PR #141700, we do not try to populate the package file into the caching system; instead we bypass caching to simplify the implementation in its current form.

As in PR #141700, I ran a quick benchmark of the loading time:

- Precompile:
  buck run mode/opt scripts/zhxchen17:precompile
- Load using cache:
  time buck run mode/opt scripts/zhxchen17:precompile -- --loader cache

  Output:
  real    0m24.593s
  user    0m59.342s
  sys     0m17.201s
- Load using load_fullgraph_package:
  time buck run mode/opt scripts/zhxchen17:precompile -- --loader precompile

  Output:
  real    0m10.907s
  user    0m9.210s
  sys     0m1.173s

Test Plan:
buck run mode/opt caffe2/test:test_export -- -r test_fullgraph_package_basic_function

Differential Revision: D68459341

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot bot commented on Jan 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145381

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit fdbbd81 with merge base f951d21.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D68459341



zhxchen17 added a commit to zhxchen17/pytorch that referenced this pull request Jan 22, 2025
zhxchen17 requested a review from ezyang on January 22, 2025
zhxchen17 added the topic: not user facing label on Jan 22, 2025
zhxchen17 added a commit to zhxchen17/pytorch that referenced this pull request Jan 22, 2025

zhxchen17 added a commit to zhxchen17/pytorch that referenced this pull request Jan 23, 2025

zhxchen17 requested review from jamesjwu and removed the request for anijain2305 on January 23, 2025
Comment on lines +12 to +21
def setUp(self):
    if not os.path.exists(os.path.expandvars("/tmp/torchinductor_$USER/")):
        os.makedirs(os.path.expandvars("/tmp/torchinductor_$USER/"))

def tearDown(self):
    super().tearDown()
    pathlib.Path(self.path()).unlink(missing_ok=True)

def path(self):
    return os.path.expandvars(f"/tmp/torchinductor_$USER/model_{self.id()}.pt2")

Use existing helper for cache dir:

def cache_dir() -> str:
    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if cache_dir is None:
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir

There is also an inductor-specific TestCase base class that sets the cache dir to a temporary place with automatic cleanup.
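For illustration, a minimal sketch of the fixture rewritten on top of those helpers. The import paths are assumptions (cache_dir has lived in different torch._inductor modules across releases), and torch._inductor.test_case.TestCase is the inductor-specific base class mentioned above:

```python
# Sketch only -- not this PR's code. Assumes:
#   * cache_dir() is importable from torch._inductor.runtime.runtime_utils
#     (its exact module has moved between releases), and
#   * torch._inductor.test_case.TestCase gives each test a temporary
#     TORCHINDUCTOR_CACHE_DIR with automatic cleanup.
import os
import pathlib

from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.test_case import TestCase


class TestFullgraphPackage(TestCase):
    def path(self) -> str:
        # Key the package file on the test id so parallel tests don't collide.
        return os.path.join(cache_dir(), f"model_{self.id()}.pt2")

    def tearDown(self) -> None:
        super().tearDown()
        pathlib.Path(self.path()).unlink(missing_ok=True)
```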

    return os.path.expandvars(f"/tmp/torchinductor_$USER/model_{self.id()}.pt2")

@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_fullgraph_package_basic_function(self):

Add some additional test cases. Test on CPU. Test training. Test errors (like wrong shapes passed). Etc.
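As one concrete example of the kind of case being asked for, a hedged sketch of a CPU variant (the test name is hypothetical, and the package save/load steps specific to this PR's API are deliberately left out):

```python
# Sketch only: a CPU companion to the CUDA-gated test above, intended to live
# on the same TestCase class. It exercises fullgraph compilation on CPU but
# omits the package save/load calls, which are specific to this PR's API.
import torch


def test_fullgraph_package_basic_function_cpu(self):
    def f(x, y):
        return x + y

    compiled = torch.compile(f, fullgraph=True)
    x = torch.randn(8, device="cpu")
    y = torch.randn(8, device="cpu")
    torch.testing.assert_close(compiled(x, y), f(x, y))
```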

mode=mode,
options=options,
disable=disable,
name=name,

Why is the name needed? It seems a bit clunky to specify both a path and a name.

def __init__(
    self,
    *,
    path: Optional[str] = None,

Should this be a required arg? If the user doesn't specify it, the semantics of a default path seem odd. Similar APIs like model.save() don't have a default path.


 if backend == "inductor":
-    backend = _TorchCompileInductorWrapper(mode, options, dynamic)
+    backend = _TorchCompileInductorWrapper(mode, options, dynamic, fullgraph, name)

Why does inductor need to know about fullgraph mode?

)


_PRECOMPILES: Dict[str, List[Any]] = {}

Not thread safe. I'd think we should be able to eliminate the name and store this on the torch.compile object (just need to thread a pointer to that object down into this function).
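To make the suggestion concrete, a hedged sketch of one shape this could take: a per-torch.compile() holder with its own lock instead of a module-level dict. Every name here (CompiledWrapper, _input_key) is hypothetical, and the guard matching is stubbed out to shape/dtype/device only:

```python
# Sketch only -- illustrates "store the table on the object torch.compile
# returns, behind a lock" rather than in a module-level _PRECOMPILES dict.
import threading
from typing import Any, Callable, Dict, Optional, Tuple

import torch


def _input_key(example_inputs: Tuple[torch.Tensor, ...]) -> Tuple[Any, ...]:
    # Hypothetical stand-in for real guard matching: key on shape/dtype/device.
    return tuple((t.shape, t.dtype, t.device.type) for t in example_inputs)


class CompiledWrapper:
    """Hypothetical per-torch.compile() holder for precompiled artifacts."""

    def __init__(self) -> None:
        self._precompiles: Dict[Tuple[Any, ...], Callable[..., Any]] = {}
        self._lock = threading.Lock()  # serializes concurrent add/lookup

    def add_precompile(
        self, example_inputs: Tuple[torch.Tensor, ...], fn: Callable[..., Any]
    ) -> None:
        with self._lock:
            self._precompiles[_input_key(example_inputs)] = fn

    def lookup(
        self, example_inputs: Tuple[torch.Tensor, ...]
    ) -> Optional[Callable[..., Any]]:
        with self._lock:
            return self._precompiles.get(_input_key(example_inputs))
```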

Comment on lines +657 to +658
if precompile := _get_precompile(graph_kwargs.get("name"), example_inputs):
    return precompile

This seems like an odd place to check the cache. By this point we have already run dynamo and AOT Autograd, which incur a lot of compile time -- but then we just throw out the graph we worked so hard to generate. If we moved this check up to the object returned by torch.compile then we could get the compile time down close to zero.
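Continuing the hypothetical sketch above, checking the table in the callable returned by torch.compile would let a hit bypass dynamo tracing and AOT Autograd entirely (all names remain hypothetical; dynamo_compile stands in for the existing slow path):

```python
# Sketch only: early lookup before any tracing, reusing the hypothetical
# CompiledWrapper from the previous sketch.
def make_compiled_callable(wrapper, fn, dynamo_compile):
    def compiled(*args):
        hit = wrapper.lookup(args)
        if hit is not None:
            return hit(*args)  # cache hit: no dynamo, no AOT Autograd
        compiled_fn = dynamo_compile(fn, args)  # hypothetical slow path
        wrapper.add_precompile(args, compiled_fn)
        return compiled_fn(*args)

    return compiled
```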

current_callable, 1, graph.device_type
).run # type: ignore[attr-defined]
)
elif graph.device_type == "cpu":

Not tested?

current_callable, 1
).run # type: ignore[attr-defined]
)
elif graph.device_type == "xpu":

Not tested?

else:
    current_callable = compiled_fn

if graph.device_type.startswith("cuda"):

What about a graph with both CPU and CUDA?
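As an illustration of that concern, a hedged sketch that collects every device type seen in the example inputs rather than relying on a single graph.device_type string (illustrative only, not the PR's code):

```python
# Sketch only: a graph holding both CPU and CUDA tensors yields {"cpu", "cuda"},
# which a single graph.device_type string cannot represent.
from typing import Iterable, Set

import torch


def graph_device_types(example_inputs: Iterable[object]) -> Set[str]:
    return {
        t.device.type for t in example_inputs if isinstance(t, torch.Tensor)
    }
```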

# explicitly package precompiled artifacts into a single file.
# TODO Eventually we should come up with a context manager style API. To
# reduce the complexity of landing changes, we first introduce a set of
# stateful interfaces as the future building blocks to begin with.

This is intended to become a real API eventually, right? Do you think it's too early to put our best foot forward on the API (which includes not having underscores)?

zhxchen17 added a commit that referenced this pull request Feb 20, 2025
Summary:

Design doc: https://docs.google.com/document/d/1Z15cBBPjoZ7gH00TSgCdgaYko7a7Br-ERd3_hA-g2IU/edit?usp=sharing

This diff introduces a new per-torch.compile() object API which forces Inductor to use AOTI as a backend, different from PR #141700.


Differential Revision: D68459341
zhxchen17 added a commit that referenced this pull request Feb 20, 2025
Following up on PR #145381, we implement a new API for compiling fullgraph models using the cpp wrapper and for saving/loading compiled artifacts to disk.

The sticky cache is now designed as a per-torch.compile() object living with the compilation context. Each time a recompilation happens, it collects the compiled artifacts into a lookup table. When a new set of inputs is passed to the compiled callable, we perform a lookup in the sticky cache before entering the dynamo cache, matching by the guards on inputs only.

API names are tentative but the workflow roughly looks like the following:

```
def f(...): ...

compiled_f = torch.compile(f, fullgraph=True, sticky_cache="my_dir/my_model")

compiled_f(*args)

compiled_f.save_sticky_cache(prefix="/dir1")

...

compiled_f.load_sticky_cache(prefix="/dir2")
```

Since this touches many layers of the torch.compile system, we start from the simple case of a forward-only graph, static shapes, and flat tensor inputs/outputs. Once the overall API converges, we can gradually remove the sticky_cache.unimplemented() calls from the code.
zhxchen17 added a commit that referenced this pull request Feb 21, 2025
@zhxchen17 (Contributor, Author) commented:

Continue in #147528

zhxchen17 closed this on Feb 22, 2025
zhxchen17 added a commit that referenced this pull request Feb 28, 2025
zhxchen17 added a commit that referenced this pull request Mar 8, 2025
zhxchen17 added a commit that referenced this pull request Mar 11, 2025
zhxchen17 added a commit that referenced this pull request Mar 27, 2025
Following up on PR #145381, we implement a new API for compiling models using the cpp wrapper and for saving/loading compiled artifacts to disk.

The package is now designed as a per-torch.compile() object living with the compilation context. Each time a recompilation happens, it collects the compiled artifacts into a lookup table. When a new set of inputs is passed to the compiled callable, we perform a lookup in the compile package before entering the dynamo cache, matching by the serialized guards.

API names are tentative but the workflow roughly looks like the following:

```
def f(...): ...

compiled_f = torch.compile(f, package="my_dir/my_model")

compiled_f(*args)

compiled_f.save_package(prefix="/dir1")

...

compiled_f.load_package(prefix="/dir2")
```
zhxchen17 added a commit that referenced this pull request Mar 28, 2025