
Conversation

@Isalia20 (Collaborator) commented Feb 9, 2025

PR #145701 didn't have the experimental version of cholesky (torch.linalg.cholesky_ex). This PR adds that version.
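
For context, a minimal usage sketch of torch.linalg.cholesky_ex on MPS (illustrative, not part of the PR):

```python
import torch

x = torch.rand(3, 3, device="mps")
x = x.mT @ x + 3 * torch.eye(3, device="mps")  # positive definite

# Unlike torch.linalg.cholesky, cholesky_ex does not raise on failure;
# it returns the factor together with an info tensor that is nonzero
# when the decomposition fails.
L, info = torch.linalg.cholesky_ex(x)
assert info.item() == 0
torch.testing.assert_close(L @ L.mT, x)
```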

@pytorch-bot bot commented Feb 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146799

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4ea9ce5 with merge base 91c4bf3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the "release notes: mps" (Release notes category) label Feb 9, 2025
@Isalia20 changed the title from "Mps cholesky ex version" to "[MPS] cholesky ex version" Feb 9, 2025
@github-actions bot (Contributor) commented Feb 9, 2025

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


@malfet added the "ciflow/mps" (Run MPS tests, subset of trunk) label Feb 11, 2025
@pytorch-bot bot commented Feb 11, 2025

To add the ciflow label ciflow/mps please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot bot removed the "ciflow/mps" (Run MPS tests, subset of trunk) label Feb 11, 2025
@malfet added the "topic: improvements" (topic category) and "ciflow/mps" (Run MPS tests, subset of trunk) labels Feb 11, 2025
Comment on lines +1151 to +1152:

```cpp
out.tril_();
upper ? out.transpose_(ndim - 2, ndim - 1) : out;
```
@nikitaved (Collaborator) commented Feb 11, 2025:

This will silently alter the stride structure of out if upper == true. It would be better as upper ? out.triu_() : out.tril_().
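
A minimal sketch of the stride change being pointed out (illustrative, not part of the review comment):

```python
import torch

out = torch.empty(3, 3)  # row-major, strides (3, 1)
print(out.stride())      # (3, 1)
out.transpose_(0, 1)     # in-place transpose only swaps the strides
print(out.stride())      # (1, 3): out is silently no longer row-major
```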

@Isalia20 (Collaborator, Author) replied:

That's not the same. The kernel does the decomposition in the lower part of the matrix. If you do out.triu_() instead of out.tril_() followed by a transpose, you keep the upper triangle of a matrix whose factor lives in the lower triangle, which isn't the correct output.
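
Illustration (not from the thread): for a real SPD matrix the upper factor is just the transpose of the lower one, which is what tril_ followed by a transpose produces:

```python
import torch

x = torch.rand(3, 3)
x = x.mT @ x + 3 * torch.eye(3)  # symmetric positive definite

L = torch.linalg.cholesky(x)              # lower factor
U = torch.linalg.cholesky(x, upper=True)  # upper factor
torch.testing.assert_close(U, L.mT)       # U == L^T in the real case
```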

@nikitaved (Collaborator) commented Feb 12, 2025:

Do you have some stride assumptions in the kernel, or is it stride-agnostic? If it is stride-agnostic, then the kernel could be run on the transposed variant.

@Isalia20 (Collaborator, Author) replied:

It assumes that the input is row-major (contiguous).

@nikitaved (Collaborator) commented Feb 12, 2025:

out can be provided externally as column-major. What would happen in this case?

@Isalia20 (Collaborator, Author) replied:

I printed the data pointer inside the MPS function and outside in Python:

```python
import torch

out = torch.rand(3, 3, 3, device="mps").permute(2, 1, 0)

x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x

data_ptr = out.data_ptr()
print(f"0x{data_ptr:x}")  # lowercase hex
torch.linalg.cholesky(x, out=out)
print(f"0x{out.data_ptr():x}")
```

Yields:

```
0x10a4d68d0
0x10fb19150
0x10a4d68d0
```

The first one is the print from Python, the second is from C++ before launching the kernel, and the third is from Python again. So yeah, confirmed.

@nikitaved (Collaborator) commented Feb 12, 2025:

As per https://github.com/pytorch/pytorch/pull/146799/files#r1952464144, this is expected. Sorry for the confusion. But it seems we should have issues when out is contiguous and upper=True.

@Isalia20 (Collaborator, Author) replied:

No issues from what I checked:

```python
import torch

out = torch.rand(3, 3, 3, device="mps").permute(2, 1, 0)

x = torch.rand(3, 3, 3, device="mps")
x = x.mT @ x

data_ptr = out.data_ptr()
print(f"0x{data_ptr:x}")  # lowercase hex
print(out.stride())
print(out.is_contiguous())
res1 = torch.linalg.cholesky(x, out=out, upper=True)
res2 = torch.linalg.cholesky(x.cpu(), out=out.cpu(), upper=True)
print(f"0x{out.data_ptr():x}")
print(out.stride())
torch.testing.assert_close(res1.cpu(), res2)
```

Yields (the middle 0x114f3a510 is the C++ print from inside the kernel launch):

```
0x113f70cc0
(1, 3, 9)
False
0x114f3a510
0x113f70cc0
(1, 3, 9)
```

@nikitaved (Collaborator) commented Feb 12, 2025:

@Isalia20, could you remove the permute so that out is contiguous? In the Meta function, as per your modification, out is re-used only if it is contiguous.

@Isalia20 (Collaborator, Author) replied:

Ah, I see the issue now: with a contiguous out and upper=True, the data pointer stays the same but the strides of out are silently changed from (9, 3, 1) to (9, 1, 3):

```
0x10bc7b840
(9, 3, 1)
True
0x10bc7b840
0x10bc7b840
(9, 1, 3)
```

@nikitaved (Collaborator) left a review comment:

Looks good, thank you! I have left some comments regarding the silent stride-altering behavior in out and the values of the info vector.


```diff
 // L
-auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/true);
+auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig*=*/A.device().type() != at::kMPS);
```
@nikitaved (Collaborator) commented:

Why is MPS different?

@Isalia20 (Collaborator, Author) commented Feb 12, 2025:

The MPS kernel assumes a row-major layout for the matrix in which it does the decomposition.
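
(Illustrative sketch, not the kernel's actual code: "row-major" here means flat offsets are computed as i * n + j, which is only valid for a C-contiguous matrix.)

```python
# Hypothetical row-major offset helper: element (i, j) of an n x n
# C-contiguous matrix lives at flat index i * n + j. A column-major
# (F-contiguous) matrix would instead need j * n + i.
def rowmajor_offset(i: int, j: int, n: int) -> int:
    return i * n + j
```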

@nikitaved (Collaborator) replied:

Can the kernel be made row-major/col-major agnostic so as to be able to preserve consistency across backends?

@Isalia20 (Collaborator, Author) replied:

I'll take a look ~next week to see if I can make it work for col-major so we don't need to special-case row-major for MPS only. But why do we want to preserve consistency across backends? A lot of ops on MPS use a row-major layout and require a contiguous() call before the tensor is passed to an MPS kernel.

@nikitaved (Collaborator) replied:

In linalg, LAPACK seems to be the source of truth, and it is written in Fortran, where col-major is the standard layout :(
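
A quick way to see this from Python (illustrative, not from the thread): the factor returned on CPU comes back F-contiguous, matching the LAPACK convention:

```python
import torch

x = torch.rand(3, 3)
x = x.mT @ x + 3 * torch.eye(3)

L = torch.linalg.cholesky(x)
print(L.stride())  # (1, 3): column-major (F-contiguous), per LAPACK
```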

@nikitaved (Collaborator) commented Feb 12, 2025:

I believe we can re-use the kernel without that much code change (i.e. no need to make it stride-agnostic for now). In the Meta function we request C-contiguous when upper=False and F-contiguous when upper=True for MPS. Then we only need to remove the line upper ? out.transpose_(...) : out (and probably replace it with upper ? out.triu_() : out.tril_()), or something along these lines. That should resolve the issue with out for now, before the kernel is adapted for better memory accesses in column-major mode...
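
A minimal sketch of the proposed idea (illustrative; as noted below, it didn't pan out in practice): with an F-contiguous out, out.mT is a row-major view of the same storage, so a row-major lower-Cholesky written through out.mT leaves the upper factor in out:

```python
import torch

x = torch.rand(3, 3)
x = x.mT @ x + 3 * torch.eye(3)

# Allocate out F-contiguous (column-major): strides (1, 3).
out = torch.empty(3, 3).t().contiguous().t()
assert out.stride() == (1, 3)

# out.mT is row-major over the same storage; writing the lower factor
# there (a stand-in for the row-major MPS kernel) makes out the upper one.
out.mT.copy_(torch.linalg.cholesky(x))
torch.testing.assert_close(out, torch.linalg.cholesky(x, upper=True))
```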

@Isalia20 (Collaborator, Author) replied:

I've tried it, but I'm afraid it doesn't work. I'll address this in a follow-up PR with the kernel change for column-major mode rather than going down the rabbit hole now for a temporary fix.

@Isalia20 (Collaborator, Author) commented:

Thanks, I'll address the comments a little later today

@zou3519 added the "triaged" (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 11, 2025
Comment on lines +6541 to +6542:

```python
output_cpu = torch.linalg.cholesky_ex(input_cpu, upper=upper)
output_mps = torch.linalg.cholesky_ex(input_mps, upper=upper)
```
@nikitaved (Collaborator) commented:

Let us also check that info is the same, since its behavior is altered?

@Isalia20 (Collaborator, Author) replied:

output_cpu and output_mps are tuples of the L and info tensors, so assertEqual is comparing both of them. Do you mean to add a separate test where info might be > 1?

@nikitaved (Collaborator) replied:

Yes, when erroring on non-PSD inputs :)
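
A hypothetical sketch of such a test (the names and structure are illustrative, not the PR's actual test code):

```python
import torch

# A matrix that is not positive definite: info should be nonzero and
# should agree between the CPU and MPS backends.
x = torch.eye(3)
x[0, 0] = -1.0

L_cpu, info_cpu = torch.linalg.cholesky_ex(x)
L_mps, info_mps = torch.linalg.cholesky_ex(x.to("mps"))
assert info_cpu.item() != 0
assert info_cpu.item() == info_mps.cpu().item()
```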

@Isalia20 (Collaborator, Author) replied:

I'll do it a bit later today and also adapt the error message.

@Isalia20 (Collaborator, Author) replied:

Added a better error message.

@malfet (Contributor) commented Feb 13, 2025

@pytorchbot merge -f "MPS is green"

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
PR pytorch#145701 didn't have experimental version of cholesky. This PR adds that version

Pull Request resolved: pytorch#146799
Approved by: https://github.com/malfet

Labels

ciflow/mps (Run MPS tests, subset of trunk)
Merged
open source
release notes: mps (Release notes category)
topic: improvements (topic category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
