Skip to content

Conversation

@ydwu4
Copy link
Contributor

@ydwu4 ydwu4 commented Aug 16, 2023

Stack from ghstack (oldest at bottom):

Motivation:
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures.

Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:

def foo()


# Inside torch.cond, we'd like to do something like 
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else. 
# Dynamo will use the cached code of foo for "eager" backend 
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

Implementation:

  1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.

Note: More lines are printed for debug log due to newly added context manager and guard adds .

Test Plan:
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @anijain2305

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107337

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit ce3c4fc with merge base 916183a (image):

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…hanges"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
…hanges"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
…hanges"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 17, 2023
@ydwu4 ydwu4 changed the title [dynamo] Add BACKEND_MATCH guard to recompile if backend changes [dynamo] Add BACKEND_MATCH guard to guard backend changes Aug 17, 2023
@ydwu4 ydwu4 changed the title [dynamo] Add BACKEND_MATCH guard to guard backend changes [dynamo] Add BACKEND_MATCH guard to detect and recompile when backend changes Aug 17, 2023
@ydwu4 ydwu4 added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 17, 2023
…hen backend changes"


**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures. 

Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()


# Inside torch.cond, we'd like to do something like 
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else. 
# Dynamo will use the cached code of foo for "eager" backend 
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.

**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Aug 17, 2023
"calling `torch._dynamo.reset()` to take effect"
)
most_recent_backend = compiler_fn
guarded_backend_cache[id(most_recent_backend)] = most_recent_backend
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it would be better to have a thread local which is set when we enter a torch.compile region specifying what the current backend is. At the very least it needs to be thread local, which it doesn't seem like most_recent_backend is. (Yes, I know some of this is preexisting problems, but now you know.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've modified accordingly and added a test for multi-threads but I'm not sure if the implementation is good enough.

…hen backend changes"


**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures. 

Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()


# Inside torch.cond, we'd like to do something like 
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else. 
# Dynamo will use the cached code of foo for "eager" backend 
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.

**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
…hen backend changes"


**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures. 

Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()


# Inside torch.cond, we'd like to do something like 
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else. 
# Dynamo will use the cached code of foo for "eager" backend 
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.

**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
@huydhn
Copy link
Contributor

huydhn commented Sep 8, 2023

@pytorchbot revert -m 'Sorry for reverting your change but inductor perf smoke test starts to regress after this' -c ignoresignal

Here is an example failure https://hud.pytorch.org/pytorch/pytorch/commit/1a64ec7dd48408d6839a5c2cceb55b0c4be2243b

2023-09-07T23:51:08.8980334Z + python benchmarks/dynamo/check_hf_bert_perf_csv.py -f /var/lib/jenkins/workspace/test/test-reports/inductor_training_smoketest.csv
2023-09-07T23:51:09.3421344Z hf_Bert                            0.997172
2023-09-07T23:51:09.3421670Z 
2023-09-07T23:51:09.3421819Z Error 1 models performance regressed
2023-09-07T23:51:09.3422217Z     hf_Bert

@pytorch pytorch deleted a comment from pytorch-bot bot Sep 8, 2023
@huydhn
Copy link
Contributor

huydhn commented Sep 8, 2023

@pytorchbot revert -m 'Sorry for reverting your change but inductor perf smoke test starts to regress after this' -c ignoredsignal

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Copy link
Collaborator

@ydwu4 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Sep 8, 2023
… backend changes (#107337)"

This reverts commit 1a64ec7.

Reverted #107337 on behalf of https://github.com/huydhn due to Sorry for reverting your change but inductor perf smoke test starts to regress after this ([comment](#107337 (comment)))
…hen backend changes"


**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures. 

Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()


# Inside torch.cond, we'd like to do something like 
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else. 
# Dynamo will use the cached code of foo for "eager" backend 
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.

2. Then newly added context manager and guard adds more lines for debug log so we change the uppper limit from 50 to 55.

**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Sep 8, 2023
@ydwu4 ydwu4 reopened this Sep 8, 2023
…hen backend changes"


**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures. 

Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()


# Inside torch.cond, we'd like to do something like 
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else. 
# Dynamo will use the cached code of foo for "eager" backend 
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.

2. Then newly added context manager and guard adds more lines for debug log so we change the uppper limit from 50 to 55.

**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Sep 11, 2023
@zou3519 zou3519 removed their request for review September 12, 2023 17:20
…hen backend changes"


**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures. 

Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()


# Inside torch.cond, we'd like to do something like 
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else. 
# Dynamo will use the cached code of foo for "eager" backend 
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.


Note: More lines are printed for debug log due to newly added context manager and guard adds .

**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
…hen backend changes"


**Motivation:**
We try to make torch.cond use torch.compile automatically so that we could error out when there is side-effects in the branches and correctly handle the closures. 

Before this PR, we have a warning if we don't turn on a config raise_on_backend_change (turning it on gives us an error) for the following code:
```python
def foo()


# Inside torch.cond, we'd like to do something like 
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else. 
# Dynamo will use the cached code of foo for "eager" backend 
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.


Note: More lines are printed for debug log due to newly added context manager and guard adds .

**Test Plan:**
Removed original tests that raise on different backend and add a new test to test whether the BACKEND_MATCH guard can guard against backend change.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov anijain2305

[ghstack-poisoned]
ydwu4 added a commit that referenced this pull request Sep 13, 2023
@ydwu4
Copy link
Contributor Author

ydwu4 commented Sep 14, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants