Conversation

Contributor

@alugorey alugorey commented Mar 22, 2023

Enables the hipSolver backend for ROCm builds

  • Minimum ROCm version requirement: 5.3
  • Introduces a new macro, USE_LINALG_SOLVER, that controls enablement of both cuSOLVER and hipSOLVER
  • Adds the hipSOLVER API to the hipification process
  • Combines hipSOLVER and hipSPARSE mappings into a single SPECIAL map that takes priority over the normal mappings
  • Torch APIs moved to the hipSOLVER backend (as opposed to MAGMA) include: torch.svd(), torch.geqrf(), torch.orgqr(), torch.ormqr()
  • Will enable 100+ linalg unit tests for ROCm
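
The priority behavior of the combined SPECIAL map can be sketched roughly as follows (a simplified model for illustration; the map contents and the hipify function here are stand-ins, not the actual hipify implementation):

```python
import re

# Illustrative entries only; the real hipify tables are much larger.
SPECIAL_MAP = {"cusolverDnHandle_t": "hipsolverHandle_t"}
REGULAR_MAP = {"cudaMalloc": "hipMalloc",
               "cusolverDnHandle_t": "hipsolverDnHandle_t"}  # overridden by SPECIAL_MAP

def hipify(source):
    """Replace CUDA identifiers, consulting the SPECIAL map before the regular one."""
    def substitute(match):
        token = match.group(0)
        # SPECIAL mappings take priority over normal mappings.
        return SPECIAL_MAP.get(token, REGULAR_MAP.get(token, token))
    return re.sub(r"\w+", substitute, source)
```

For example, hipify("cudaMalloc(&p, n); cusolverDnHandle_t h;") yields "hipMalloc(&p, n); hipsolverHandle_t h;", with the SPECIAL entry winning over the regular one.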

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10

@pytorch-bot

pytorch-bot bot commented Mar 22, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97370

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit f508f9c:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: rocm and release notes: linalg_frontend labels Mar 22, 2023
@lezcano lezcano removed their request for review March 22, 2023 22:38
import scipy
from functools import wraps

def setLinalgBackendsToDefaultFinally(fn):
    @wraps(fn)
Collaborator

@jithunnair-amd jithunnair-amd Mar 23, 2023

Since this decorator definition moved to a different file, the lint error about the import statement for wraps is legitimate.
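
For reference, a minimal illustration of why the wraps import matters (illustrative example, not the PR's code): without it, the wrapper function would shadow the decorated test's name, which test discovery and reporting rely on.

```python
from functools import wraps

def passthrough(fn):
    @wraps(fn)  # copies fn.__name__, __doc__, etc. onto the wrapper
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@passthrough
def test_example():
    return 42
```

Here test_example.__name__ stays "test_example" instead of becoming "wrapper".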

@ngimel ngimel added the triaged label Mar 24, 2023
@jithunnair-amd jithunnair-amd added the rocm priority label Mar 27, 2023
@jithunnair-amd
Collaborator

@malfet @ngimel Please review this PR with priority, if possible. It adds hipSolver support for ROCm.

Contributor

@malfet malfet left a comment

Instead of introducing another call (which is mutually exclusive with hasCuSolver), wouldn't it be better to just reuse the same call (and perhaps rename it to something vendor-agnostic, say hasAcceleratedLapack())?

Collaborator

We don't have anything specific in Python tests for hipBLAS or rocBLAS. Why should we have it for hipSOLVER? Why can't it be so that cuSOLVER == hipSOLVER on the ROCm platform?

Contributor Author

Wrapped the logic from skipCUDAIfNoCusolverAndNoHipsolver into skipCUDAIfNoCusolver.
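
A minimal sketch of that consolidation (the availability flags here are illustrative stand-ins for the real build and runtime probes, not PyTorch's actual helpers):

```python
from functools import wraps
import unittest

HAS_CUSOLVER = False    # illustrative flag: a CUDA build with cuSOLVER
HAS_HIPSOLVER = True    # illustrative flag: a ROCm build with hipSOLVER

def skipCUDAIfNoCusolverAndNoHipsolver(fn):
    """Skip the test when neither GPU solver backend is available."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not (HAS_CUSOLVER or HAS_HIPSOLVER):
            raise unittest.SkipTest("neither cuSOLVER nor hipSOLVER is available")
        return fn(*args, **kwargs)
    return wrapper

# The old decorator name reuses the combined check, so existing tests
# pick up hipSOLVER support without any edits.
skipCUDAIfNoCusolver = skipCUDAIfNoCusolverAndNoHipsolver
```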

Collaborator

@IvanYashchuk IvanYashchuk left a comment

I think hipSPARSE required a considerably smaller number of intrusive code changes.

Collaborator

The copy-pasted comment should be modified.

Collaborator

Why is a separate PYTORCH_SOLVER_MAP needed here?
PYTORCH_SPARSE_MAP and PYTORCH_SOLVER_MAP should be unified, and the comment describing this part of the code should be updated.

Collaborator

Why is this needed in a file that tests meta tensors?

Comment on lines 122 to 127
Collaborator

Why is rocSOLVER added here? We don't add rocSPARSE for example.

Collaborator

Is rocSOLVER necessary here?

Collaborator

This is extremely confusing: you are returning a non-contiguous tensor if the make_contiguous argument is true.

Contributor Author

@ngimel Could you please elaborate? The stanza starting at line 2185 is only entered when make_contiguous is true. As such, we set the memory format to at::MemoryFormat::Contiguous on line 2189. Am I misunderstanding how this API works?

Contributor

@alugorey, why are the .mT() calls needed there?

Contributor Author

@malfet After investigating and talking amongst the team, we discovered this change to be an artifact left over from an earlier version of development. I have pushed up a new commit removing this unnecessary transposition.

Contributor

And I think this addresses @ngimel's previous comment.

Contributor

I believe I've left this comment already, but why are two defines needed here? Why not #ifdef USE_GPU_SOLVER, which in the common header file is defined for both CUDA and ROCm platforms?

Contributor Author

@malfet We keep these two defines separate because hipSOLVER still hasn't implemented all of the features cuSOLVER supports. If you look in BatchLinearAlgebra.cpp, you'll see instances where the existing #ifdef USE_CUSOLVER was left alone without adding a check for USE_HIPSOLVER. Keeping these two separate is how we control feature enablement for hipSOLVER. Once it is 1-to-1, we can consolidate.
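
A rough Python model of the gating described above (flag and function names are illustrative, not PyTorch's actual code): ops at feature parity consult a combined flag, while cuSOLVER-only ops keep their narrower guard until hipSOLVER catches up.

```python
# Illustrative model of keeping the two defines separate (names hypothetical).
USE_CUSOLVER = False    # would be set by a CUDA build
USE_HIPSOLVER = True    # would be set by a ROCm build

# Ops that both backends implement check the combined condition,
# mirroring "#if defined(USE_CUSOLVER) || defined(USE_HIPSOLVER)".
USE_LINALG_SOLVER = USE_CUSOLVER or USE_HIPSOLVER

def dispatch_svd():
    # At parity: either GPU solver backend can serve this op.
    return "gpu_solver" if USE_LINALG_SOLVER else "magma"

def dispatch_cusolver_only_op():
    # Not yet implemented by hipSOLVER: stays gated on USE_CUSOLVER alone,
    # so ROCm builds fall back to MAGMA here.
    return "cusolver" if USE_CUSOLVER else "magma"
```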

@jeffdaily jeffdaily mentioned this pull request Apr 19, 2023
@linux-foundation-easycla

linux-foundation-easycla bot commented Apr 20, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: alugorey / name: Andres Lugo (05a6da1a9af5e836cda80bfca16f446c377f167f)

@alugorey alugorey changed the base branch from main to viable/strict April 21, 2023 21:22
@alugorey alugorey force-pushed the hipsolver_enablement branch 2 times, most recently from 05a6da1 to dd0c046 Compare April 21, 2023 21:33
@alugorey
Contributor Author

@malfet Had to rebase onto viable/strict and squash due to an administrative issue. Ready for review again.

@alugorey alugorey force-pushed the hipsolver_enablement branch 3 times, most recently from 617934f to 809fd70 Compare April 25, 2023 15:31
@jeffdaily jeffdaily requested a review from ngimel May 22, 2023 15:54
@alugorey alugorey force-pushed the hipsolver_enablement branch from 6a0d3b1 to b8e6ae3 Compare May 25, 2023 17:32
@facebook-github-bot
Contributor

@malfet has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jithunnair-amd
Collaborator

@alugorey Looking at the latest CI runs, I see two failures that seem to be related to this PR, based on history.
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCUDA::test_comprehensive_norm_nuc_cuda_float32
test_linalg.py::TestLinalgCUDA::test_pca_lowrank_cuda

Did you already check these?

@jithunnair-amd
Collaborator

@pytorchbot merge -f "CI failures are unrelated to PR"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

shaoyf42 pushed a commit to shaoyf42/pytorch that referenced this pull request Jun 1, 2023
Enables the hipSolver backend for ROCm builds

Pull Request resolved: pytorch#97370
Approved by: https://github.com/malfet
jeffdaily added a commit to ROCm/pytorch that referenced this pull request Oct 11, 2023

Labels

  • ciflow/trunk: Trigger trunk jobs on your pull request
  • ciflow/unstable: Run all experimental or flaky jobs on PyTorch unstable workflow
  • Merged
  • module: inductor
  • module: rocm: AMD GPU support for PyTorch
  • open source
  • release notes: linalg_frontend: release notes category
  • rocm priority: high priority ROCm PRs from performance or other aspects
  • rocm: This tag is for PRs from the ROCm team
  • triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

8 participants