Conversation

@tenpercent (Collaborator) commented Aug 13, 2024

This PR enables dynamic shapes in the CK backend for gemm max-autotune (see #125453).

This is achieved by removing the hardcoded problem sizes from the template body and passing them as kernel parameters instead.

The problem sizes are passed both at the kernel call site and at the benchmark call site, as sketched below.
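A hypothetical sketch of the idea (invented names, not the actual template code):

# Hypothetical sketch: move the problem sizes out of the rendered template
# body and into the generated kernel's runtime signature.
size_params = ["M", "N", "K", "LDA", "LDB", "LDC", "LDD"]

# Before (static shapes): sizes were literals baked in at render time, e.g.
static_body = "constexpr int M = 2048; constexpr int N = 2048; constexpr int K = 1024;"

# After (dynamic shapes): the signature gains size/stride parameters that the
# wrapper passes on every call (and the benchmark harness passes size hints).
params = ", ".join(f"int32_t {p}" for p in size_params)
print(f"void gemm_kernel(const void* X, const void* W, void* Y, {params});")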

Testing

pytest test/inductor/test_ck_backend.py [-k dynamic]

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @zjing14

@pytorch-bot bot commented Aug 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133285

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 1ef26dc with merge base 3965f11:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the ciflow/inductor, ciflow/rocm (Trigger "default" config CI on ROCm), module: inductor, and module: rocm (AMD GPU support for Pytorch) labels Aug 13, 2024
@tenpercent changed the title from "[ROCm][Indoctor][Draft] enable dynamic shapes" to "[ROCm][Inductor][Draft] enable dynamic shapes" Aug 13, 2024
@tenpercent force-pushed the ck-unhardcode-mm-size branch 4 times, most recently from 92ab278 to 34035a1, August 13, 2024 01:56
@tenpercent (Collaborator, Author):

@pytorchbot rebase -s

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):

Successfully rebased ck-unhardcode-mm-size onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout ck-unhardcode-mm-size && git pull --rebase)

Comment on lines 135 to 136
Contributor:

Might be worth testing with a differently sized a, like this:

new_a = torch.randn(2345, 256, **tensor_options)
Y = mm(new_a, b)

Contributor:

qq: what does sympy.expand do? I'm wondering if there's a better alternative in our big dynamic shape library =p

for example, there's this from the PT2 core library:

def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":

there's also this in Inductor which is very similar to above:

def simplify(self, expr: Expr):
    return sympy.expand(expr).xreplace(self.replacements)

@tenpercent (Collaborator, Author) commented Aug 14, 2024:

I think I took expand from some part of the Triton kernel code. Not sure which method is actually correct here. From the docs, it transforms polynomial expressions into their canonical form: https://docs.sympy.org/latest/tutorials/intro-tutorial/simplification.html#expand
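For illustration, a minimal example of what expand does (plain sympy, independent of this PR):

import sympy

s0 = sympy.Symbol("s0")
expr = (s0 + 1) * (s0 + 2)
# expand() multiplies everything out into a canonical polynomial form
print(sympy.expand(expr))  # s0**2 + 3*s0 + 2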

Contributor:

Okay I think it's okay to use sympy.expand then.

If this is used by codegen, then we may benefit from sizevars' simplify, which substitutes all symbols in self.replacements to make expr more canonical.
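A small sketch of what the substitution adds on top of expand (hypothetical replacements dict, mirroring the simplify body quoted above):

import sympy

s0, s1 = sympy.symbols("s0 s1")
replacements = {s1: s0}  # e.g. two sizes the compiler has proven equal
expr = s1 * (s0 + 1)
# expand first, then substitute known-equal symbols, as sizevars' simplify does
print(sympy.expand(expr).xreplace(replacements))  # s0**2 + s0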

Collaborator (Author):

Looks like it would make sense to move the simplification to the kernel call site only.

@ColinPeppler (Contributor) commented Aug 13, 2024:

qq: what should size_args look like? Does it look like something like c_int(s0), or like c_int(123)?

I think with the size_hint call here, size_args would always be a scalar even when it's symbolic?

extra_args = V.graph.sizevars.size_hints(
    map(sympy.expand, call_args[len(expected_args) :])
)
# create the BenchmarkRequest
bmreq = ROCmBenchmarkRequest(
    kernel_name=kernel_name,
    input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
    output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
    extra_args=extra_args,

@tenpercent (Collaborator, Author) commented Aug 14, 2024:

size_args is a list of [M, N, K, LDA, LDB, LDC, LDD]. They all need to be scalars. For the kernel call, the scalar may look like c_int(s0), as s0 is obtained in the wrapper from one of the call args. For the benchmark call it should look like c_int(123), where 123 is the result of evaluating the size hint. Not sure if hinting always produces a scalar; we can provide a fallback in case it doesn't. See the sketch below.
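A rough illustration of the two forms (hypothetical values; in reality the wrapper generates this code):

import ctypes

# Kernel-call path: the wrapper binds the symbol s0 from one of the call args,
# so the generated argument reads like c_int(s0).
s0 = 2345  # hypothetical runtime value taken from an input tensor's shape
m_kernel = ctypes.c_int(s0)

# Benchmark-call path: the size hint has already been evaluated to a literal,
# so the argument reads like c_int(123).
m_bench = ctypes.c_int(123)

print(m_kernel.value, m_bench.value)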

Contributor:

I see, thanks for explaining; so we need a scalar here.

Then this makes sense! size_hint(x) should produce a scalar if x is a scalar or a backed symint. It won't if x is an unbacked symint, but that's less common.

@tenpercent force-pushed the ck-unhardcode-mm-size branch from 686936c to 2a40ba0, August 14, 2024 03:59
@tenpercent (Collaborator, Author):

@pytorchbot rebase -s

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):

Successfully rebased ck-unhardcode-mm-size onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout ck-unhardcode-mm-size && git pull --rebase)

    self.size_args() if hasattr(self, "size_args") else ()
)  # subclass should define def size_args()
size_args_ints = [
    V.graph.sizevars.symbolic_hint(arg) for arg in size_args
Contributor:

I think sizevars.size_hint is the preferred method here since it's more unbacked-symint friendly =)

@ColinPeppler (Contributor) commented:

Looks pretty good!

@tenpercent marked this pull request as ready for review August 15, 2024 03:00
@tenpercent changed the title from "[ROCm][Inductor][Draft] enable dynamic shapes" to "[ROCm][CK][Inductor] enable dynamic shapes for CK backend to gemm max autotune" Aug 15, 2024
@tenpercent (Collaborator, Author):

@pytorchbot rebase -s

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):

Successfully rebased ck-unhardcode-mm-size onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout ck-unhardcode-mm-size && git pull --rebase)

        Test matmul with dynamic shapes
        """

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
Contributor:

Just wondering if this was disabled to avoid certain ROCm kernels.

Collaborator (Author):

Well, it's here to avoid numeric mismatches. Since the CK kernels do the computation with fp32 dtype, I just hope this setting enables fp32 accumulation in the aten counterpart.

Contributor:

I think you're right; if tf32 is enabled, then it should do fp32 accumulation.

Collaborator:

> Well, it's here to avoid numeric mismatches. Since the CK kernels do the computation with fp32 dtype, I just hope this setting enables fp32 accumulation in the aten counterpart.

Btw, this setting doesn't do what you think it does. Aten always does accumulation in fp32. This setting simply allows aten to truncate intermediate reductions to fp16 for fp16 gemms.
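For reference, a quick way to observe the flag's effect (illustrative; needs a CUDA/ROCm device):

import torch

a = torch.randn(256, 256, device="cuda", dtype=torch.float16)
b = torch.randn(256, 256, device="cuda", dtype=torch.float16)

# Reduced-precision path: the backend may keep intermediate sums in fp16.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
y_reduced = a @ b

# Full-precision path: accumulation stays in fp32 throughout.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
y_full = a @ b

# Any difference comes from the reduced-precision reductions.
print((y_reduced - y_full).abs().max())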

@tenpercent (Collaborator, Author):

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Aug 15, 2024
@pytorchmergebot (Collaborator):

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team: raised by workflow job.

@tenpercent (Collaborator, Author):

@pytorchbot label "topic: not user facing"

@pytorch-bot bot added the topic: not user facing (topic category) label Aug 16, 2024
@tenpercent (Collaborator, Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/inductor, ciflow/rocm (Trigger "default" config CI on ROCm), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: inductor, module: rocm (AMD GPU support for Pytorch), open source, topic: not user facing (topic category)
