Skip to content

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Sep 11, 2024

It's called from torch.unfold and one of the few remaining vestiges in MPSFallback.mm

Strongly inspired by CUDA implementation from

int64_t idx = index / width_col;
int64_t h_out = idx % height_col;
int64_t channel_in = idx / height_col;
int64_t channel_out = channel_in * kernel_height * kernel_width;
int64_t h_in = h_out * stride_height - pad_height;
int64_t w_in = w_out * stride_width - pad_width;
dt* col = data_col + (channel_out * height_col + h_out) * width_col + w_out;
const dt* im = data_im + (channel_in * height + h_in) * width + w_in;
for (int64_t i = 0; i < kernel_height; ++i) {
for (int64_t j = 0; j < kernel_width; ++j) {
int64_t h = h_in + i * dilation_height;
int64_t w = w_in + j * dilation_width;
*col = (h >= 0 && w >= 0 && h < height && w < width)
? im[i * dilation_height * width + j * dilation_width]
: static_cast<dt>(0);
col += height_col * width_col;
}
}
}

malfet and others added 3 commits September 11, 2024 06:43
It's called from `torch.unfold` and one of the few remaining vestiges in
MPSFallback
@malfet malfet requested a review from kulinseth as a code owner September 11, 2024 16:36
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135706

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 478c8e9 with merge base fc88ba2 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Sep 11, 2024
@github-actions
Copy link
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@malfet malfet added the topic: improvements topic category label Sep 11, 2024
@malfet malfet requested a review from albanD September 11, 2024 18:10
}
});
if (!batched_input) {
output.resize_({n_output_plane, output_length});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
output.resize_({n_output_plane, output_length});
output.squeeze_(0);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same on the cuda code + moving it out of the AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3 block

INSTANTIATE_IM2COL(float2);
INSTANTIATE_IM2COL(half);
INSTANTIATE_IM2COL(half2);
#if __METAL_VERSION__ >= 310
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going to happen if I try to call this with bfloat on older metal? Do I get a nice error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, see

TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
"MPS bfloat16 type is supported on MacOS 14.0 or newer.");

Which will be triggered from here
case ScalarType::BFloat16:
checkSupportsBFloat16();
return "bfloat";

Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good !

@malfet
Copy link
Contributor Author

malfet commented Sep 14, 2024

@pytorchbot merge -f "Lint and MPS are green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

malfet added a commit that referenced this pull request Sep 18, 2024
After #135706 `getGatherScatterScalarType` returns exactly the same results as `scalarToMetalTypeString` , so delete the function and call `scalarToMetalTypeString`
pytorchmergebot pushed a commit that referenced this pull request Sep 18, 2024
After #135706 `getGatherScatterScalarType` returns exactly the same results as `scalarToMetalTypeString` , so delete the function and call `scalarToMetalTypeString`

Pull Request resolved: #136295
Approved by: https://github.com/kit1980
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
After pytorch#135706 `getGatherScatterScalarType` returns exactly the same results as `scalarToMetalTypeString` , so delete the function and call `scalarToMetalTypeString`

Pull Request resolved: pytorch#136295
Approved by: https://github.com/kit1980
@github-actions github-actions bot deleted the malfet/mps-add-im2col branch October 14, 2024 06:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged release notes: mps Release notes category topic: improvements topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants