[ROCm][CK][Inductor] enable dynamic shapes for CK backend to gemm max autotune #133285
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133285
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 unrelated failures) As of commit 1ef26dc with merge base 3965f11. FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 92ab278 to 34035a1.
@pytorchbot rebase -s
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased; force-pushed from 34035a1 to 686936c.
test/inductor/test_ck_backend.py
Might be worth testing with a differently sized a, like this?
new_a = torch.randn(2345, 256, **tensor_options)
yy = mm(new_a, b)
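For context, a self-contained sketch of why the extra call exercises the dynamic-shape path: compile once, then feed a second M dimension through the same compiled function. The shapes, dtype, and tensor_options values below are illustrative assumptions, not the test's actual settings.

```python
import torch

def mm(a, b):
    return a @ b

# assumed options; the real test defines its own tensor_options
tensor_options = {"device": "cuda", "dtype": torch.float16}

# compile once with dynamic shapes enabled
compiled_mm = torch.compile(mm, dynamic=True)

a = torch.randn(1024, 256, **tensor_options)
b = torch.randn(256, 512, **tensor_options)
y = compiled_mm(a, b)

# a second call with a different M should reuse the same compiled kernel
new_a = torch.randn(2345, 256, **tensor_options)
yy = compiled_mm(new_a, b)
```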
qq: what does sympy.expand do? I'm wondering if there's a better alternative in our big dynamic shape library =p
for example, there's this from the PT2 core library:
pytorch/torch/fx/experimental/symbolic_shapes.py
Line 4679 in c17d26c
def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
there's also this in Inductor which is very similar to above:
pytorch/torch/_inductor/sizevars.py
Lines 91 to 92 in c17d26c
def simplify(self, expr: Expr):
    return sympy.expand(expr).xreplace(self.replacements)
I think I took expand from some part of the Triton kernel code. Not sure which method here is actually correct. From the docs, it transforms polynomial expressions into their canonical (expanded) form: https://docs.sympy.org/latest/tutorials/intro-tutorial/simplification.html#expand
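For reference, a tiny standalone example of what sympy.expand does (nothing Inductor-specific here):

```python
import sympy

s0, s1 = sympy.symbols("s0 s1")
expr = (s0 + 1) * (s0 + s1)
# expand() multiplies everything out into the canonical polynomial form
print(sympy.expand(expr))  # s0**2 + s0*s1 + s0 + s1
```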
Okay, I think it's okay to use sympy.expand then.
If this is used by codegen, then we may benefit from sizevars' simplify, which will substitute all symbols in self.replacements to make expr more canonical.
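To illustrate the difference, a small sketch of what the sizevars-style simplify adds on top of plain expand; the replacements dict below is a made-up stand-in for self.replacements:

```python
import sympy

s0, s1 = sympy.symbols("s0 s1")
replacements = {s1: s0}  # hypothetical: the shape env learned that s1 == s0

expr = s1 * (s0 + 2)
print(sympy.expand(expr))                         # s0*s1 + 2*s1
print(sympy.expand(expr).xreplace(replacements))  # s0**2 + 2*s0
```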
Looks like it would make sense to move the simplification to the kernel call site only.
qq: what should size_args look like? Does it look like something like c_int(s0), or like c_int(123)?
I think with the size_hint call here, size_args would always be a scalar even when it's symbolic?
pytorch/torch/_inductor/codegen/rocm/rocm_template.py
Lines 93 to 101 in 9de023d
extra_args = V.graph.sizevars.size_hints(
    map(sympy.expand, call_args[len(expected_args) :])
)
# create the BenchmarkRequest
bmreq = ROCmBenchmarkRequest(
    kernel_name=kernel_name,
    input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
    output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
    extra_args=extra_args,
size_args is a list of [M, N, K, LDA, LDB, LDC, LDD]. They all need to be scalars. For the kernel call, the scalar may look like c_int(s0), as s0 is obtained in the wrapper from one of the call args. For the benchmark call it should look like c_int(123), where 123 is the result of evaluating the size hint. Not sure if hinting always produces a scalar; we can provide a fallback in case it doesn't.
I see, thanks for explaining, so we need a scalar here.
Then this makes sense! size_hint(x) should produce a scalar if x is a scalar or a backed symint. It won't if x is an unbacked symint, but that's less common.
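Putting the two cases side by side, here is a hedged sketch of the benchmark-side path; as_benchmark_arg is a hypothetical helper, not the PR's actual code, and in the generated wrapper the same argument would instead be emitted textually as c_int(s0):

```python
import ctypes
import sympy

def as_benchmark_arg(size_expr, hints):
    # hypothetical helper: evaluate a (possibly symbolic) size expression to a
    # concrete int using the size hints, then wrap it for the C kernel call
    return ctypes.c_int(int(size_expr.subs(hints)))

s0 = sympy.Symbol("s0")
hints = {s0: 2345}  # assumed hint for the dynamic M dimension

print(as_benchmark_arg(s0, hints))                  # c_int(2345), i.e. M
print(as_benchmark_arg(sympy.Integer(256), hints))  # c_int(256), a static dim
```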
Force-pushed from 686936c to 2a40ba0.
@pytorchbot rebase -s
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased; force-pushed from 2a40ba0 to 23be5a7.
self.size_args() if hasattr(self, "size_args") else ()
)  # subclass should define def size_args()
size_args_ints = [
    V.graph.sizevars.symbolic_hint(arg) for arg in size_args
I think sizevars.size_hint is the preferred method here since it's more unbacked-symint friendly =)
Looks pretty good!
@pytorchbot rebase -s
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased; force-pushed from 698570b to b1e83fd.
Test matmul with dynamic shapes
"""
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
Just wondering if this was disabled to avoid certain ROCm kernels.
Well, it's here to avoid numeric mismatches. Since the CK kernels do the computation in fp32, I just hope this setting enables fp32 accumulation in the aten counterpart.
I think you're right: if tf32 is enabled, then it should do fp32 accumulation.
> Well, it's here to avoid numeric mismatches. Since the CK kernels do the computation in fp32, I just hope this setting enables fp32 accumulation in the aten counterpart.

Btw, this setting doesn't do what you think it does. Aten always does accumulation in fp32. This setting simply allows aten to intermittently truncate intermediate results to fp16 for fp16 matmuls.
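For context, a small standalone check of what that flag controls; the sizes and the comparison below are arbitrary assumptions, and it needs a CUDA/ROCm device:

```python
import torch

# The flag only governs whether fp16 matmuls may use reduced-precision (fp16)
# intermediate reductions; with it disabled, aten accumulates in fp32, which is
# what the fp32-accumulating CK kernels are compared against.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

a = torch.randn(1024, 1024, device="cuda", dtype=torch.half)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.half)

ref = (a.float() @ b.float()).half()  # fp32 reference, cast back to fp16
out = a @ b                           # fp16 matmul with fp32 accumulation
print((out - ref).abs().max())        # expected to stay small
```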
@pytorchbot merge

Merge failed. Reason: This PR needs a release notes: label; if the change is not user facing, please add the topic: not user facing label instead. To add a label, you can comment to pytorchbot, for example @pytorchbot label "topic: not user facing". For more information, see the Bot commands wiki. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot label "topic: not user facing"

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR enables dynamic shapes for the CK backend for gemm max autotune (see #125453).
This is achieved by removing the hardcoded problem sizes from the template body and passing them as parameters instead.
We handle passing the problem sizes for the kernel call as well as for the benchmark call.
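As a rough illustration of the approach (the call template, function names, and sizes below are made up for the example, not the actual CK template):

```python
import sympy

# Before: problem sizes were baked into the template body at codegen time, so
# the rendered kernel only worked for one concrete (M, N, K).
HARDCODED = "run_ck_gemm(a, b, c, /*M=*/2048, /*N=*/512, /*K=*/256);"

# After: the sizes are parameters filled in per call site.
CALL_TEMPLATE = "run_ck_gemm(a, b, c, {M}, {N}, {K});"

s0 = sympy.Symbol("s0")   # dynamic M dimension
hints = {s0: 2345}        # size hint, only needed when benchmarking

print(HARDCODED)

# kernel call: sizes stay symbolic and are read from the call args at runtime
print(CALL_TEMPLATE.format(M="c_int(s0)", N="c_int(512)", K="c_int(256)"))

# benchmark call: the same sizes are evaluated to concrete ints via the hints
print(CALL_TEMPLATE.format(M=f"c_int({int(s0.subs(hints))})", N="c_int(512)", K="c_int(256)"))
```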
Testing
pytest test/inductor/test_ck_backend.py [-k dynamic]

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @zjing14