[inductor] Preserve metadata across replace_by_example and register_replacement patterns #138089
replace_by_example is used to implement some pattern-matching passes in inductor. Previously, replace_by_example would generate nodes with very little metadata. In particular, `meta["original_aten"]` would be lost; that meant that when generating triton kernel names, you could get empty names like `triton_tem_fused_0` if the input nodes to the fused kernel were the result of a pattern-matching pass that used replace_by_example. This fixes the issue by copying metadata from the original nodes to the replacement nodes. If there are multiple original nodes (e.g. if you're pattern matching `add(mm(x, y), z) -> addmm(z, x, y)`), then we just arbitrarily pick one of the source nodes.
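To illustrate why a missing `meta["original_aten"]` leads to a name like `triton_tem_fused_0`, here is a minimal sketch. The helper `fused_kernel_name` and its naming scheme are assumptions made up for illustration; this is not Inductor's actual naming code, only a model of the behavior described above.

```python
from typing import Any, Dict, List


def fused_kernel_name(origin_metas: List[Dict[str, Any]], index: int = 0) -> str:
    """Hypothetical naming helper; not Inductor's real implementation."""
    # Collect the ATen op names recorded on each origin node, if any.
    ops = sorted(
        {str(meta["original_aten"]) for meta in origin_metas if "original_aten" in meta}
    )
    # With no recorded ops, the descriptive middle part of the name is empty.
    return "_".join(["triton_tem_fused", *ops, str(index)])


# A replacement node that lost its metadata yields an uninformative name.
print(fused_kernel_name([{}]))  # triton_tem_fused_0
# A node that kept meta["original_aten"] yields a descriptive one.
print(fused_kernel_name([{"original_aten": "mm"}]))  # triton_tem_fused_mm_0
```

Under this simplified model, the only difference between the two names is whether the origin node still carries `original_aten`, which is the information this PR tries to preserve.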
    def _transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None:
        # transfer metadata after pattern matching occurs.
        # skip "val" and "tensor_meta" because this info is too specific; it's unlikely
        # to remain accurate after pattern matching has occurred.
        new_meta.update(
            (k, v) for k, v in old_meta.items() if k not in {"val", "tensor_meta"}
        )
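As a rough illustration of what this hunk does, the sketch below runs the same logic on plain dictionaries. The metadata values are made up for the example; real FX `node.meta` entries hold richer objects such as fake tensors.

```python
from typing import Any, Dict


def _transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None:
    # Same logic as the hunk above: copy everything except "val" and "tensor_meta".
    new_meta.update(
        (k, v) for k, v in old_meta.items() if k not in {"val", "tensor_meta"}
    )


# Made-up metadata for illustration only.
old_meta = {
    "original_aten": "aten.mm.default",
    "stack_trace": "model.py:12",
    "val": "FakeTensor(shape=(8, 8))",
    "tensor_meta": "TensorMetadata(shape=(8, 8))",
}
new_meta = {"val": "FakeTensor(shape=(8, 16))"}

_transfer_meta(new_meta, old_meta)
print(new_meta)
```

The replacement node keeps its own `val`, gains `original_aten` and `stack_trace`, and never receives the stale `tensor_meta`.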
Is there a pre-existing utility to do something like this? And if not - am I missing anything other than "val" and "tensor_meta" that should be skipped?
We do have `_COPY_META_FIELDS` (line 93 in 9b01d17), although I'm not sure it's directly applicable.
The other thing is that, in some places, we do use the FakeTensorUpdater to ensure that the faketensor metadata is accurate.
great, thanks! _COPY_META_FIELDS seems good enough to me. It looks like the faketensor metadata propagation is already handled in these cases and I don't see any issues with it, so I'll stick with just using _COPY_META_FIELDS for now.
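For comparison with the skip-list approach, an allowlist-based transfer could look roughly like the sketch below. `EXAMPLE_COPY_FIELDS` and `transfer_allowlisted_meta` are hypothetical names, and the field list is an assumption; the actual contents of `_COPY_META_FIELDS` are not shown in this thread and may differ.

```python
from typing import Any, Dict, List

# Hypothetical allowlist for illustration; the real _COPY_META_FIELDS may differ.
EXAMPLE_COPY_FIELDS: List[str] = ["original_aten", "stack_trace", "source_fn_stack"]


def transfer_allowlisted_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None:
    # Copy only the fields that are explicitly listed as safe to carry over.
    for key in EXAMPLE_COPY_FIELDS:
        if key in old_meta:
            new_meta[key] = old_meta[key]


old_meta = {"original_aten": "aten.mm.default", "val": "stale", "custom_key": 1}
new_meta: Dict[str, Any] = {}
transfer_allowlisted_meta(new_meta, old_meta)
print(new_meta)
```

An allowlist fails closed: an unknown key such as `custom_key` is dropped rather than copied, whereas a skip-list copies anything it does not explicitly exclude.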
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
            )
            if len(self.nodes) > 0:
                for n in replacement.graph.nodes:
                    _transfer_meta(new_meta=n.meta, old_meta=self.nodes[0].meta)
What if we have multiple nodes in self.nodes? Can we end up with more confusion downstream?
Yes. I'd argue it's still better than having nothing though:
- In practice, it looks like very few replace_by_example patterns actually have multiple source nodes. (The main case where you see this is in the `add(z, mm(x, y)) -> addmm(z, x, y)` case.)
- original_aten may be confusing, but you'll also get other information like source location that will still be useful.
Do you have a suggestion on a better way to get this information?
Or alternatively, how about we only preserve metadata if self has exactly 1 node?
> Or alternatively, how about we only preserve metadata if self has exactly 1 node?
IMO, this is potentially less confusing.
I've changed the PR to do this (only preserve metadata if self has exactly 1 node)
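The agreed-upon behavior (transfer metadata only when exactly one node was matched) can be sketched as follows. `FakeNode` and `maybe_transfer_meta` are hypothetical stand-ins written for this example; they are not the classes or functions used in the PR.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeNode:
    """Stand-in for torch.fx.Node; it only carries a name and a meta dict."""
    name: str
    meta: Dict[str, Any] = field(default_factory=dict)


def maybe_transfer_meta(matched: List[FakeNode], replacements: List[FakeNode]) -> bool:
    # With several matched nodes there is no unambiguous source, so skip.
    if len(matched) != 1:
        return False
    source = matched[0].meta
    for node in replacements:
        node.meta.update(
            (k, v) for k, v in source.items() if k not in {"val", "tensor_meta"}
        )
    return True


# Single matched node: metadata is carried over to the replacement.
mm = FakeNode("mm", {"original_aten": "aten.mm.default"})
padded = FakeNode("padded_mm")
print(maybe_transfer_meta([mm], [padded]), padded.meta)

# add(z, mm(x, y)) matches two nodes: the transfer is skipped.
add = FakeNode("add", {"original_aten": "aten.add.Tensor"})
addmm = FakeNode("addmm")
print(maybe_transfer_meta([add, mm], [addmm]), addmm.meta)
```

The trade-off discussed above is visible here: the `addmm` replacement ends up with no `original_aten` at all, but it also never carries a label that belongs to only one of its two sources.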
replace_by_example is used to implement some pattern-matching passes in inductor. Previously, replace_by_example would generate nodes with very little metadata. In particular, `meta["original_aten"]` would be lost; that meant that when generating triton kernel names, you could get empty names like `triton_tem_fused_0` if the input nodes to the fused kernel were the result of a pattern-matching pass that used replace_by_example. This also adds metadata to register_replacement patterns, including pad_mm. This fixes the issue by copying metadata from the original node to the replacement nodes. If there are multiple original nodes, we skip the metadata transfer; so if you have an `add(z, mm(x, y))`, the metadata won't be transferred right now.
Merge failed. Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).
@davidberard98 I was watching CI this weekend - suspect this might be breaking
@pytorchbot revert -m 'Sorry for reverting your PR but the new test_original_aten_preserved_pad_mm test runs OOM in trunk https://hud.pytorch.org/pytorch/pytorch/commit/fb44658415e50b5be6a187ff3f14243c0fdf3daf' -c nosignal
Failing test: inductor/test_pad_mm.py::PadMMTest::test_original_aten_preserved_pad_mm
If possible, I think a smaller tensor is needed. Otherwise, you could add test_pad_mm to https://github.com/pytorch/pytorch/blob/main/test/run_test.py#L225 so that it runs serially and can use all the GPU memory.
@pytorchbot successfully started a revert job.
@davidberard98 your PR has been successfully reverted.
…gister_replacement patterns (#138089)" This reverts commit fb44658. Reverted #138089 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but the new test_original_aten_preserved_pad_mm test runs OOM in trunk https://hud.pytorch.org/pytorch/pytorch/commit/fb44658415e50b5be6a187ff3f14243c0fdf3daf
Shrunk the test to use 2MB. Tests are passing and the memory is smaller, so I'll try merging again.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours).
# Motivation
Fix #138577.

# Solution
1. All UTs in `test/inductor/test_compiled_optimizers.py` are fixed by #134170
2. UT in `test/inductor/test_pattern_matcher.py` is introduced by #138089, we will skip this UT due to the unsupported feature `max_autotune_gemm_backends:Triton`.
3. We have a new impl related to `histc`, so we remove the expected failure from `test/inductor/test_torchinductor_opinfo.py`
4. We support `avg_pool3d` for `fp16` data type, so we remove the expected failure from `test/inductor/test_torchinductor_opinfo.py`
5. CUDA-bias code is introduced by #138472, we just generalize it to `GPU_TYPE`.

# Additional Context
> Why update torch-xpu-ops commit pin here?

We have to update commit pin to avoid the build failure raised by the code change [C10_UNUSED](#138364).

> What does the feature of torch-xpu-ops update?

1. Add some foreach ops, like `unary ops` and `foreach_clamp_max` etc;
2. Add some maxpool ops forward and backward, like `averge_pool3d` and `max_pool3d`
3. Add some other ops, like `log_normal_`, `index_copy`, and `mode` etc;
4. fix build failure related to `C10_UNUSED`;

Pull Request resolved: #138548
Approved by: https://github.com/malfet, https://github.com/EikanWang
Differential Revision: D64480755