[inductor] print triton float64 constants correctly #135260
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135260
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit a66a77e with merge base 58f2477.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Also tested the repro provided in #134720.
        f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
    )

def _print_Float(self, expr):
I admit I don't know if it's always valid to assume constants are float64. I'm operating on the assumption that any float literal originated from Python and is technically float64.
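For context, here is a minimal sketch of the kind of printer override being discussed. It is not the actual inductor diff (the class name and exact formatting here are hypothetical), but it shows the idea of emitting a sympy `Float` as an explicit fp64 Triton constant; the follow-up PRs further down later switch the shape argument from `[1]` to `[]`.

```python
# Hypothetical sketch, not the actual inductor TritonPrinter code.
import sympy
from sympy.printing.str import StrPrinter


class TritonPrinterSketch(StrPrinter):
    def _print_Float(self, expr: sympy.Float) -> str:
        # Python float literals are IEEE-754 double precision, so emit them as
        # explicit float64 constants rather than letting Triton infer a narrower type.
        return f"tl.full([1], {float(expr)}, tl.float64)"


# Example: TritonPrinterSketch().doprint(sympy.Float(0.12))
# -> 'tl.full([1], 0.12, tl.float64)'
```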
def f(x):
    return x * (0.12 * x.shape[0])

x = torch.ones(200, device=GPU_TYPE, dtype=torch.float64)
Can you parameterize the dtype here?
@isuruf, I don't think so? At least, there's no dtype available that I know about: we just have a `sympy.core.numbers.Float` object that we're printing.
Oh, I totally misunderstood your comment. You mean run this test for a few different dtypes?
done
... except I got it wrong. omg. fix forthcoming
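For reference, a rough sketch of what a dtype-parameterized version of the repro could look like (`run_repro` and the device string are placeholders, not the actual test added in the PR):

```python
# Rough sketch of a dtype-parameterized repro; not the actual PyTorch test.
import torch


def f(x):
    # 0.12 * x.shape[0] is computed as a Python float, i.e. float64
    return x * (0.12 * x.shape[0])


def run_repro(device: str = "cuda") -> None:
    compiled = torch.compile(f)
    for dtype in (torch.float16, torch.float32, torch.float64):
        x = torch.ones(200, device=device, dtype=dtype)
        # compiled output should match eager, which uses the float64 constant
        torch.testing.assert_close(compiled(x), f(x))
```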
Unfortunately, I can't tell if this is correct. In particular, I don't know what the correct types in the Triton IR are supposed to be in this codegen case... Use of float64 here seems right, or at least it is consistent with some of the other float codegen.
I believe this is only for constants in indexing expressions, not all constants...
This would turn every computation within indexing that involves floats into fp64 computation. triton-lang/triton#4613 would fix this in the repro from #134720, just upcasting to …
@lezcano, what problems would you anticipate? Of the people here, I certainly know the least about it. But if there's a float constant in an indexing expression, shouldn't that constant always be treated as fp64 to match eager semantics?
GPUs have less dedicated silicon for fp64 than for fp32 compute, and even less so on consumer GPUs. Regarding eager semantics, if you do …
@jansel WDYT about @lezcano's input here? How do you prefer to fix it? I've verified that triton-lang/triton#4613 indeed fixes the original repro from #134720. I don't know how big of a deal it is to get that Triton change. Do we update the pin, or is it a cherry-pick situation?
While @lezcano is right, I kind of think we should ship this change anyway, mostly because I don't actually see any way to recover Python-style float64 semantics when we move them to CUDA without doing it in float64. This is different from int32/int64, where we have a chance of optimizing via value-range analysis; I don't see how to go from float64 to float32 and guarantee bit-for-bit equivalence in the end.
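A small illustration of that point (this snippet is mine, not from the PR): a Python float constant is a float64, and once it has been rounded to float32 the original value cannot be recovered, so matching eager bit-for-bit requires keeping the constant in float64.

```python
import torch

c = 0.12  # Python float literal: IEEE-754 float64
as_f64 = torch.tensor(c, dtype=torch.float64)
via_f32 = torch.tensor(c, dtype=torch.float32).to(torch.float64)

print(as_f64.item())   # 0.12
print(via_f32.item())  # 0.11999999731779099 -- precision lost in the float32 round-trip
print(torch.equal(as_f64, via_f32))  # False
```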
Ok, landing this now given @ezyang's guidance (and because we have an internal use case waiting on a fix). If the discussion continues and we decide this is the wrong choice, we can always revert.
Yeah, indexing code usually runs in Python, which uses float64. I also think this will be uncommon and not matter much for perf; if it does, we could optimize it.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary: Landed #135260 too soon and the test in that PR doesn't do exactly what I tested; this follow-up makes it actually test different dtypes.
Test Plan: `python test/inductor/test_triton_kernels.py -k float64_constant`
Pull Request resolved: #135583
Approved by: https://github.com/isuruf, https://github.com/eellison, https://github.com/Skylion007
… creating f64 constant instead of 1-element tensor (#136594)
Summary: We have an internal report of a Triton compiler error `ValueError: Cannot broadcast, rank mismatch: [1], [1, 2048]` coming from a line like this:
`tmp25 = tl.broadcast_to(((tl.full([1], 1.00000000000000, tl.float64)) + ((ks0 // 3278).to(tl.float64))) / (((tl.full([1], 0.500000000000000, tl.float64))*(libdevice.sqrt((1 + ((ks0 // 3278)*(ks0 // 3278)) + ((-2)*(ks0 // 3278))).to(tl.float64).to(tl.float32)))) + ((tl.full([1], 0.500000000000000, tl.float64))*((1 + (ks0 // 3278)).to(tl.float64)))), [XBLOCK, RBLOCK])`
#135260 is the cause, presumably because we turn a constant into a 1-element tensor with `(tl.full([1], const, tl.float64))`. It looks like changing the syntax to `(tl.full([], const, tl.float64))` gives us what we want.
Differential Revision: [D63465169](https://our.internmc.facebook.com/intern/diff/D63465169)
Pull Request resolved: #136594
Approved by: https://github.com/mengluy0125, https://github.com/jansel
… creating f64 constant instead of 1-element tensor (#136858)
This is a retry of #136594, which was having trouble landing; see the summary above.
Differential Revision: [D63540693](https://our.internmc.facebook.com/intern/diff/D63540693)
Pull Request resolved: #136858
Approved by: https://github.com/atalman
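To make the `[1]` vs `[]` distinction concrete, here is a standalone sketch (the kernel and names are illustrative, not inductor's generated code): `tl.full([1], c, tl.float64)` builds a rank-1 tensor, which can hit the rank-mismatch error above when broadcast against a 2-D block like `[XBLOCK, RBLOCK]`, while `tl.full([], c, tl.float64)` builds a 0-d scalar that broadcasts against blocks of any rank.

```python
# Illustrative kernel only; not the code inductor generates.
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # 0-d float64 constant: broadcasts cleanly against blocks of any rank.
    c = tl.full([], 0.12, tl.float64)
    # c = tl.full([1], 0.12, tl.float64)  # rank-1 constant: the shape that triggered
    #                                     # the rank-mismatch error with [XBLOCK, RBLOCK]
    y = (x.to(tl.float64) * c).to(x.dtype)
    tl.store(out_ptr + offs, y, mask=mask)


x = torch.ones(200, device="cuda")
out = torch.empty_like(x)
scale_kernel[(triton.cdiv(x.numel(), 128),)](x, out, x.numel(), BLOCK=128)
```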
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang