Conversation

@IvanYashchuk (Collaborator)
torch.cholesky_solve now works for complex inputs on GPU.
I moved the existing tests to test_linalg.py and modified them to test complex and float32 dtypes.
Differentiation also works correctly with complex inputs now.

Ref. #33152

@IvanYashchuk IvanYashchuk added module: complex Related to complex number support in PyTorch module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul labels Oct 29, 2020
@IvanYashchuk IvanYashchuk force-pushed the complex-cholesky-solve branch from 4c73f09 to d2ffbf2 Compare October 29, 2020 11:03
dr-ci bot commented Oct 29, 2020

💊 CI failures summary and remediations

As of commit 09d7262 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚




codecov bot commented Oct 29, 2020

Codecov Report

Merging #47047 (09d7262) into master (b726a1b) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master   #47047   +/-   ##
=======================================
  Coverage   80.79%   80.79%           
=======================================
  Files        1865     1865           
  Lines      201074   201074           
=======================================
+ Hits       162456   162459    +3     
+ Misses      38618    38615    -3     

@albanD albanD added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Oct 29, 2020
@facebook-github-bot (Contributor)

Hi @IvanYashchuk!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

  MAGMAQueue magma_queue(self.get_device());

- constexpr int64_t batch_limit = 262140;
+ int64_t batch_limit = self.is_complex() ? 65535 : 262140;
Contributor

why do we have different batch limit for complex and non-complex dtypes? can you link me to where this is documented?

Collaborator Author

I don't know whether it's documented anywhere; I determined this value experimentally.

@IvanYashchuk (Collaborator Author) Nov 2, 2020

CUDA limits the y and z grid dimensions of a kernel launch to 65535. Maybe batching is implemented differently for non-complex dtypes, allowing 262140 batches.

Contributor

Synced with @ngimel offline. We should check the magma manual and better document this difference in batch_limit, since the original comments are uninformative.

Collaborator Author

It's not documented in magma.
CUDA limits the y and z grid dimensions of a kernel launch to 65535:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications

Maximum x-dimension of a grid of thread blocks: 2^31 - 1
Maximum y- or z-dimension of a grid of thread blocks: 65535

I haven't checked the source code for how batching is done for non-complex dtypes, but apparently the complex variants use the z-dimension of the grid of thread blocks for batching.
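To make the limit concrete: a wrapper that honors `batch_limit` has to hand MAGMA at most that many matrices per call and loop over the remainder. A minimal pure-Python sketch of that chunking (the helper `batch_chunks` is illustrative, not PyTorch code):

```python
def batch_chunks(batch_size, batch_limit):
    """Split a batch of `batch_size` matrices into contiguous chunks of at
    most `batch_limit` matrices each, returned as (start, end) index pairs."""
    chunks = []
    start = 0
    while start < batch_size:
        end = min(start + batch_limit, batch_size)
        chunks.append((start, end))
        start = end
    return chunks

# With the complex-dtype limit of 65535, a batch of 200000 matrices
# would be processed in four MAGMA calls.
print(batch_chunks(200000, 65535))
```

Each chunk's size stays within the 65535 cap, so batching via the z grid dimension never exceeds the CUDA launch limit.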

Contributor

I spent a little time looking at the complex path and didn't figure it out, but I did see this:

    if ( n > 2048 ) {
        #ifndef MAGMA_NOWARNING
        printf("=========================================================================================\n"
               "   WARNING batched routines are designed for small sizes. It might be better to use the\n"
               "   Native/Hybrid classical routines if you want good performance.\n"
               "=========================================================================================\n");
        #endif
    }

in magma_cpotrf_lg_batched

Collaborator

Yeah we should use cusolver for those, if we don't already.

Collaborator

cc @heitorschueroff, @xwang233: can you guys please create a tracking issue listing which linalg functions, under which conditions, use magma, cusolver, or cublas, and which functions still need to be weaned off magma and switched to cusolver?

Collaborator

Thanks, I'll create a tracking issue.

Collaborator

@xwang233's issue is here: #47953

A = root.tril()
return torch.cholesky_solve(b, A, upper)

gradcheck(func, [root, b, upper])
Contributor

@IvanYashchuk please move the autograd tests to common_methods_invocations.py

Collaborator Author

I think common_methods_invocations.py does not allow specifying the input function to be tested; it allows specifying only the postprocessing function. Finite differencing doesn't work correctly on torch.cholesky_solve directly, therefore

def func(A, b, upper):
    if upper:
        A = A.triu()
    else:
        A = A.tril()
    return torch.cholesky_solve(b, A, upper)

is tested instead.
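For context, gradcheck compares analytic (autograd) gradients against numerical ones computed with central finite differences. A stripped-down sketch of the numerical side, for a scalar-valued function of real inputs (not the actual torch.autograd.gradcheck implementation):

```python
def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at the
    point x (a list of floats), perturbing one coordinate at a time."""
    grads = []
    for i in range(len(x)):
        xp = list(x); xp[i] += eps   # forward perturbation of coordinate i
        xm = list(x); xm[i] -= eps   # backward perturbation of coordinate i
        grads.append((f(xp) - f(xm)) / (2 * eps))
    return grads
```

Since every coordinate gets perturbed, presumably the tril/triu wrapping in `func` ensures the output depends only on the triangle that cholesky_solve actually reads, so perturbing the ignored entries doesn't create a mismatch between the numerical and analytic gradients.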

Contributor

I see. I also synced with @mruberry offline, and we came to the conclusion that it's OK to add the autograd tests in test_linalg.py.

@facebook-github-bot (Contributor) left a comment

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@IvanYashchuk IvanYashchuk force-pushed the complex-cholesky-solve branch from a55cd87 to 24979e6 Compare November 29, 2020 16:15
@IvanYashchuk (Collaborator Author)

@mruberry, I think we are ready to import this PR.

@facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry (Collaborator) commented Dec 3, 2020

Sorry @IvanYashchuk, looks like this picked up a merge conflict. Would you rebase?

@IvanYashchuk (Collaborator Author)

Done.

@facebook-github-bot (Contributor) left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@mruberry merged this pull request in 85121a7.

Labels

cla signed Merged module: complex Related to complex number support in PyTorch module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul open source triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

9 participants