
Conversation

@ezyang ezyang commented Oct 16, 2024

Stack from ghstack (oldest at bottom):

The important comment:

# Enables a local, filesystem "profile" which can be used for automatic
# dynamic decisions, analogous to profile-guided optimization.  The idea is
# that if we observe that a particular input is dynamic over multiple
# iterations on one run, we can save a profile with this information so the
# next time we run we can just make it dynamic the first time around, skipping
# an unnecessary static compilation.  The profile can be soundly stale: if it
# is wrong, it just means we may make more things dynamic than was actually
# necessary (NB: this /can/ cause a failure if making something dynamic causes
# the compiler to stop working because you tickled a latent bug.)
#
# The profile is ONLY guaranteed to work if the user source code is 100%
# unchanged.  Applying the profile if there are user code changes is only
# best effort otherwise.  In particular, we identify particular code objects
# by filename, line number and name of their function, so adding/removing newlines
# will typically cause cache misses.  Once a profile is created, it will
# never be subsequently updated, even if we discover on a subsequent run that
# more inputs are dynamic (TODO: add some way of manually clearing the
# profile in a convenient way; TODO: add a way of not doing this behavior).
#
# Enabling this option can potentially change the automatic dynamic behavior
# of your program, even when there is no profile.  Specifically, we uniquely
# identify a code object by its filename/line number/name.  This means if you
# have multiple distinct code objects that have identical filename/line
# number, we will share automatic dynamic information across them (TODO:
# change default automatic dynamic behavior so it also crosstalks in this way)
automatic_dynamic_local_pgo = False

This is the dumbest, simplest thing I could manage to code in 1.5hrs. Here's how I tested it:

(/home/ezyang/local/c/pytorch-env) [[email protected] ~/local/c/pytorch (9b2e453e)]$ cat a.py
import torch

@torch.compile(backend="eager")
def f(x, y):
    return x + y

f(torch.randn(3), torch.randn(3))
f(torch.randn(4), torch.randn(4))

(/home/ezyang/local/c/pytorch-env) [[email protected] ~/local/c/pytorch (9b2e453e)]$ TORCH_LOGS=+torch._dynamo.pgo python a.py
V1015 20:19:06.171000 1676841 torch/_dynamo/pgo.py:31] [0/0] get_code_object_cache_path /data/users/ezyang/c/pytorch/a.py 3 f = /tmp/torchinductor_ezyang/dynamo/mr4g5czsmyoe3dhkfs6xyxlvavkokjqpdqe3q23ogz7wd546l6e
I1015 20:19:06.176000 1676841 torch/_dynamo/pgo.py:73] [0/0] put_code_object_cache len(frame_state)=3
V1015 20:19:06.177000 1676841 torch/_dynamo/pgo.py:74] [0/0] put_code_object_cache {'_id': 0, "L['x']": FrameStateSizeEntry(scalar=None, size=[3], stride=[1]), "L['y']": FrameStateSizeEntry(scalar=None, size=[3], stride=[1])}
I1015 20:19:06.249000 1676841 torch/_dynamo/pgo.py:73] [0/1] put_code_object_cache len(frame_state)=3
V1015 20:19:06.250000 1676841 torch/_dynamo/pgo.py:74] [0/1] put_code_object_cache {'_id': 0, "L['x']": FrameStateSizeEntry(scalar=None, size=[None], stride=[1]), "L['y']": FrameStateSizeEntry(scalar=None, size=[None], stride=[1])}
(/home/ezyang/local/c/pytorch-env) [[email protected] ~/local/c/pytorch (9b2e453e)]$ TORCH_LOGS=+torch._dynamo.pgo python a.py
V1015 20:19:56.495000 1697047 torch/_dynamo/pgo.py:31] [0/0] get_code_object_cache_path /data/users/ezyang/c/pytorch/a.py 3 f = /tmp/torchinductor_ezyang/dynamo/mr4g5czsmyoe3dhkfs6xyxlvavkokjqpdqe3q23ogz7wd546l6e
I1015 20:19:56.496000 1697047 torch/_dynamo/pgo.py:50] [0/0] get_code_object_cache hit len(frame_state)=3
V1015 20:19:56.496000 1697047 torch/_dynamo/pgo.py:51] [0/0] get_code_object_cache {'_id': 0, "L['x']": FrameStateSizeEntry(scalar=None, size=[None], stride=[1]), "L['y']": FrameStateSizeEntry(scalar=None, size=[None], stride=[1])}
V1015 20:19:56.496000 1697047 torch/_dynamo/pgo.py:63] [0/0] get_automatic_dynamic_initial_frame_state L['x'] = FrameStateSizeEntry(scalar=None, size=[None], stride=[1])
V1015 20:19:56.541000 1697047 torch/_dynamo/pgo.py:63] [0/0] get_automatic_dynamic_initial_frame_state L['y'] = FrameStateSizeEntry(scalar=None, size=[None], stride=[1])

You can see that on the second run it only compiles once and gets a cache hit.
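
For intuition, here is a rough sketch of the kind of entry being persisted per code object. The field names follow the FrameStateSizeEntry repr visible in the logs above, but the dataclass itself is illustrative, not the PR's actual definition:

from dataclasses import dataclass
from typing import List, Optional

# Illustrative stand-in for the entries shown in the logs above; not the
# actual torch._dynamo definition.
@dataclass
class FrameStateSizeEntry:
    scalar: Optional[int]
    size: List[Optional[int]]    # None in a slot means "this dim is dynamic"
    stride: List[Optional[int]]

# First run: x is seen with size [3], then [4], so its single dim collapses
# to None (dynamic), and that is what gets written to the profile.
profile = {"L['x']": FrameStateSizeEntry(scalar=None, size=[None], stride=[1])}

# Second run: loading this profile seeds the initial frame state, so x is
# treated as dynamic on the very first compile and the static recompile is skipped.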

There is a lot of extra polish needed, hopefully mostly noted in TODOs. One big gap is how exactly to invalidate this cache; the config might want to be some sort of epoch number you can bump to invalidate old caches. I won't polish this until we agree it's a good approach.
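
As a sketch of what that epoch knob could look like (the config name and helper below are hypothetical, not part of this PR), folding the epoch into the cache key means old profiles simply stop being looked up after a bump:

import hashlib
import os

# Hypothetical config: bump to invalidate all previously written profiles.
automatic_dynamic_local_pgo_epoch = 0

def profile_path(cache_root: str, filename: str, firstlineno: int, name: str) -> str:
    # The epoch participates in the key, so bumping it makes every old entry
    # unreachable without needing an explicit deletion pass.
    key = f"{automatic_dynamic_local_pgo_epoch}:{filename}:{firstlineno}:{name}"
    return os.path.join(cache_root, "dynamo", hashlib.sha256(key.encode("utf-8")).hexdigest())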

Signed-off-by: Edward Z. Yang [email protected]

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

[ghstack-poisoned]

pytorch-bot bot commented Oct 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138052

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 114 Pending

As of commit 9835513 (merge base: failed to retrieve merge base, please contact dev infra):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions
Contributor

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

@github-actions github-actions bot left a comment

Please commit the suggested changes from pytorch's linter.

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Oct 16, 2024
Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: a300afb
Pull Request resolved: #138052
# TODO: this scheme makes manual inspection of cache entries difficult,
# consider adding some breadcrumbs in the name for ease of use
r = os.path.join(
    cache_dir(), "dynamo", sha256_hash(pickle.dumps((filename, firstlineno, name)))
)
Contributor

pickle.dumps is slow; I'd suggest f"{filename},{firstlineno},{name}".encode("utf-8")
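
Plugged into the path computation above, the suggestion would look roughly like this (using hashlib directly in place of the PR's sha256_hash helper, and a cache_root parameter in place of cache_dir(), purely for illustration):

import hashlib
import os

def code_object_cache_path(cache_root: str, filename: str, firstlineno: int, name: str) -> str:
    # Cheap, deterministic key derivation without pickling.
    key = f"{filename},{firstlineno},{name}".encode("utf-8")
    return os.path.join(cache_root, "dynamo", hashlib.sha256(key).hexdigest())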

try:
    r = pickle.load(f)
except Exception:
    log.warning("get_code_object_cache failed while reading %s", path)
Contributor

include error?
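
One way to include the error, if that's the intent, is to let logging attach the active exception (standard library behavior, shown here as a sketch of the suggestion, not the PR's final code):

try:
    r = pickle.load(f)
except Exception:
    # exc_info=True appends the current exception and traceback to the record.
    log.warning("get_code_object_cache failed while reading %s", path, exc_info=True)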

# have multiple distinct code objects that have identical filename/line
# number, we will share automatic dynamic information across them (TODO:
# change default automatic dynamic behavior so it also crosstalks in this way)
automatic_dynamic_local_pgo = False
Contributor Author

NB: this isn't wired up yet

ezyang added a commit to ezyang/pytorch that referenced this pull request Oct 23, 2024
Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: a300afb
Pull Request resolved: pytorch#138052
ezyang added a commit that referenced this pull request Oct 27, 2024
Previously: #138052
but the implementation is done from scratch, so I open a new PR.

Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: e1f0761
Pull Request resolved: #139001
@ezyang
Copy link
Contributor Author

ezyang commented Oct 27, 2024

Obsoleted by #139001

@ezyang ezyang closed this Oct 27, 2024
pytorchmergebot pushed a commit that referenced this pull request Oct 27, 2024
While working on automatic dynamic PGO (#138052) one abstract property I was looking for out of profile information is that it formed a semilattice: I could join together two profiles and get a merged profile that is consistent with the profiles that I saw in both cases. While working on this data structure that supported joins, I realized that the base automatic dynamic algorithm could be implemented in this way, therefore this refactor.

The basic recipe is that we now support a join operation on FrameStateSizeEntry. Intuitively, if you join two sizes that are equal, you get back that size (join(2, 2) == 2), but if you join two different sizes you get a special singleton auto_dynamic indicating that the size of the tensor is dynamic (join(2, 3) == auto_dynamic). So now, the automatic dynamic algorithm is: (1) compute the FrameStateSizeEntry that corresponds to the concrete values we've seen, and (2) join it into the ambient FrameStateSizeEntry. As a bonus, compiler collectives can buy into the same abstraction (we're simply distributing FrameStateSizeEntry from each node to every other node). For convenience, I also added the necessary `auto_unset` extra state which is the identity element (which makes our semilattice bounded from both top and bottom). Here, join(2, auto_unset) == 2.
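
A minimal standalone sketch of the join being described (illustrative only, not the actual FrameStateSizeEntry code):

# Bounded semilattice over observed sizes: AUTO_UNSET is the identity
# (nothing observed yet), AUTO_DYNAMIC is absorbing (the size varies).
AUTO_UNSET = object()
AUTO_DYNAMIC = object()

def join_size(a, b):
    if a is AUTO_UNSET:
        return b
    if b is AUTO_UNSET:
        return a
    if a is AUTO_DYNAMIC or b is AUTO_DYNAMIC or a != b:
        return AUTO_DYNAMIC
    return a

assert join_size(2, 2) == 2                  # equal sizes stay concrete
assert join_size(2, 3) is AUTO_DYNAMIC       # differing sizes become dynamic
assert join_size(AUTO_UNSET, 2) == 2         # identity element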

While doing this, there was a complication: the infer stride algorithm wasn't technically a semilattice. Here, I did what I suggested in the original code review #130232, which is to stop using a heuristic and instead replicate the stride inference algorithm in automatic dynamic. This means that when I join strides together, I don't join their concrete values; instead, if a stride can be inferred as the contiguous stride for a particular inner dimension, you represent it as InferStride(dim). There's an example in code which I recommend looking at.
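
For the stride side, a toy illustration of the symbolic representation being described (again not the real implementation):

from dataclasses import dataclass

@dataclass(frozen=True)
class InferStride:
    # "The stride of this dim is the contiguous stride implied by the dims
    # after it", rather than a concrete integer.
    dim: int

def contiguous_strides_symbolic(ndim: int):
    # A contiguous tensor's strides are all inferable regardless of the
    # concrete sizes, so profiles with different sizes still join cleanly.
    return [InferStride(d) for d in range(ndim)]

# e.g. a contiguous (2, 3, 4) tensor has concrete strides (12, 4, 1), but both
# it and a contiguous (5, 7, 9) tensor are represented symbolically as
# [InferStride(0), InferStride(1), InferStride(2)].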

Some other extra things that are happening in this PR:

* I tried to deduplicate the size/stride automatic dynamic logic as much as possible. So hopefully less code to review here.
* I had to reimplement all the logging. For the most part I tried to match the original logging as closely as possible, but I think we could be emitting fewer Chrome events here
* The `marked_dynamic` handling is still preserved as is, but I kind of don't like it and we should figure out how to put it somewhere else

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: #138717
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #138693
ezyang added a commit that referenced this pull request Oct 28, 2024
Previously: #138052
but the implementation is done from scratch, so I open a new PR.

Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: 6dfd835
Pull Request resolved: #139001
ezyang added a commit that referenced this pull request Oct 28, 2024
Previously: #138052
but the implementation is done from scratch, so I open a new PR.

Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: 868296a
Pull Request resolved: #139001
ezyang added a commit that referenced this pull request Oct 30, 2024
Previously: #138052
but the implementation is done from scratch, so I open a new PR.

Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: 79b1b5c
Pull Request resolved: #139001
ezyang added a commit that referenced this pull request Nov 1, 2024
Previously: #138052
but the implementation is done from scratch, so I open a new PR.

Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: 9e834e0
Pull Request resolved: #139001
pytorchmergebot pushed a commit that referenced this pull request Nov 1, 2024
Previously: #138052 but the implementation is done from scratch, so I open a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs but for generic OSS users you will have to explicitly opt into it.

Signed-off-by: Edward Z. Yang <[email protected]>

Differential Revision: [D65065497](https://our.internmc.facebook.com/intern/diff/D65065497)
Pull Request resolved: #139001
Approved by: https://github.com/oulgen
pytorchmergebot pushed a commit that referenced this pull request Nov 2, 2024
Previously: #138052 but the implementation is done from scratch, so I open a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs but for generic OSS users you will have to explicitly opt into it.

Signed-off-by: Edward Z. Yang <[email protected]>

Differential Revision: [D65065497](https://our.internmc.facebook.com/intern/diff/D65065497)
Pull Request resolved: #139001
Approved by: https://github.com/oulgen
ezyang added a commit that referenced this pull request Nov 2, 2024
Previously: #138052
but the implementation is done from scratch, so I open a new PR.

Signed-off-by: Edward Z. Yang <[email protected]>

ghstack-source-id: 78f0975
Pull Request resolved: #139001
pytorchmergebot pushed a commit that referenced this pull request Nov 3, 2024
Previously: #138052 but the implementation is done from scratch, so I open a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs but for generic OSS users you will have to explicitly opt into it.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: #139001
Approved by: https://github.com/oulgen
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Previously: pytorch#138052 but the implementation is done from scratch, so I open a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs but for generic OSS users you will have to explicitly opt into it.

Signed-off-by: Edward Z. Yang <[email protected]>

Differential Revision: [D65065497](https://our.internmc.facebook.com/intern/diff/D65065497)
Pull Request resolved: pytorch#139001
Approved by: https://github.com/oulgen
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Previously: pytorch#138052 but the implementation is done from scratch, so I open a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs but for generic OSS users you will have to explicitly opt into it.

Signed-off-by: Edward Z. Yang <[email protected]>

Differential Revision: [D65065497](https://our.internmc.facebook.com/intern/diff/D65065497)
Pull Request resolved: pytorch#139001
Approved by: https://github.com/oulgen
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
Previously: pytorch#138052 but the implementation is done from scratch, so I open a new PR.

This implements the ability to save and load profiles of automatic dynamic decisions, so on subsequent runs we can directly make something automatically dynamic. Unlike the previous implementation, this cache is never enabled by default; instead, you have to specify a "job id" that says it's OK to share results. We will be able to automatically populate this id for internal MAST jobs but for generic OSS users you will have to explicitly opt into it.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#139001
Approved by: https://github.com/oulgen
@github-actions github-actions bot deleted the gh/ezyang/2967/head branch November 28, 2024 02:12