Skip to content

Conversation

@Isalia20
Copy link
Collaborator

@Isalia20 Isalia20 commented May 25, 2025

Second most requested op according to #154052

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

@Isalia20 Isalia20 requested review from kulinseth and malfet as code owners May 25, 2025 13:58
@pytorch-bot
Copy link

pytorch-bot bot commented May 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154326

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c06e421 with merge base 53ecb81 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: mps Release notes category label May 25, 2025
@github-actions
Copy link
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@Isalia20 Isalia20 added module: mps Related to Apple Metal Performance Shaders framework ciflow/mps Run MPS tests (subset of trunk) labels May 25, 2025
mtl_setArgs<8>(computeEncoder, input_strides_buffer, output_strides_buffer, source_strides_buffer);
}
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
MTLSize threadgroupSize = MTLSizeMake(std::min(numThreads, 256), 1, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 256?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah probably a leftover, changed it to maxtotal thread groups

@pytorch-bot pytorch-bot bot removed the ciflow/mps Run MPS tests (subset of trunk) label May 25, 2025
@Isalia20 Isalia20 added the ciflow/mps Run MPS tests (subset of trunk) label May 25, 2025
@pytorch-bot
Copy link

pytorch-bot bot commented May 25, 2025

To add the ciflow label ciflow/mps please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/mps Run MPS tests (subset of trunk) label May 25, 2025
@malfet malfet added ciflow/mps Run MPS tests (subset of trunk) topic: improvements topic category labels May 29, 2025
Copy link
Contributor

@malfet malfet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, though I think it'll have problem with typecasts, but I'll submit followups

Comment on lines +6849 to +6859
test_cases = [
((2, 8, 4, 5), 0, [1]),
((8, 8, 4, 5), 0, [0, 3, 2, 7, 6]),
((2, 8, 4, 5), 1, [0, 3, 2, 7, 6]),
((2, 8, 4, 5), 2, [3, 0, 1]),
((2, 8, 4, 5), 3, [2, 3, 0]),
((2, 3, 3), -1, [1, 2])
]

for args in test_cases:
helper(*args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please stop using this anti-pattern and start using templatized arguments instead.
Also, does it really cover something that OpInfo is not atm?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is for strided tensors only since I noticed they weren't being covered with OpInfo ones

@malfet
Copy link
Contributor

malfet commented May 29, 2025

@pytorchbot merge -f "Lint + MPS is green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged module: mps Related to Apple Metal Performance Shaders framework open source release notes: mps Release notes category topic: improvements topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants