Skip to content

Conversation

@anjali411
Copy link
Contributor

@anjali411 anjali411 commented Aug 4, 2020

Stack from ghstack:

Ports torch.bmm and torch.baddbmm from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions.

Closes #24539

Differential Revision: D24893511

anjali411 added a commit that referenced this pull request Aug 4, 2020
ghstack-source-id: b737e98
Pull Request resolved: #42553
@dr-ci
Copy link

dr-ci bot commented Aug 4, 2020

💊 CI failures summary and remediations

As of commit 909b366 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 162 times.

@anjali411 anjali411 requested a review from zasdfgbnm August 4, 2020 19:30
@anjali411 anjali411 changed the title Port bmm and baddbmm from TH to ATen [WIP] Port bmm and baddbmm from TH to ATen Aug 4, 2020
anjali411 added a commit that referenced this pull request Aug 5, 2020
ghstack-source-id: 3eba02a
Pull Request resolved: #42553
@anjali411 anjali411 requested a review from zasdfgbnm August 5, 2020 22:26
anjali411 added a commit that referenced this pull request Aug 5, 2020
ghstack-source-id: 047dffe
Pull Request resolved: #42553
@anjali411 anjali411 changed the title [WIP] Port bmm and baddbmm from TH to ATen Port bmm and baddbmm from TH to ATen Aug 6, 2020
Copy link
Collaborator

@zasdfgbnm zasdfgbnm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not finished yet. Will post more comment later.

Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 7, 2020
ghstack-source-id: f45d4e1
Pull Request resolved: #42553
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 12, 2020
ghstack-source-id: 61c04d4
Pull Request resolved: #42553
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Aug 13, 2020
ghstack-source-id: b43d308
Pull Request resolved: #42553
@ngimel
Copy link
Collaborator

ngimel commented Aug 13, 2020

Per @gchanan's request ports from TH to ATen should also beef up test coverage (in particular, various discontiguity patterns on input/output, and proper runtime errors for arguments on the different devices).

@zasdfgbnm
Copy link
Collaborator

@anjali411 Could you please rebase? Looks like there are lots of flaky tests.

Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

[ghstack-poisoned]
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

Closes #24539

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Nov 11, 2020
ghstack-source-id: 2a2540a
Pull Request resolved: #42553
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

Closes #24539

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Nov 11, 2020
ghstack-source-id: 3469afd
Pull Request resolved: #42553
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

Closes #24539

[ghstack-poisoned]
@skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
@onlyOnCPUAndCUDA
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=AMPERE_OR_ROCM) +
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't do so. We test on all dtypes on purpose to make sure that all dtypes are tested: if it is supported, then it should run well. If it is not supported, it should raise an error.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc: @ngimel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

synced offline: the cpu bmm and baddbmm has multiple code paths, some of them supports bfloat16 and float16, some don't. So depending on the input, half and bfloat could or could not be supported. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LinearAlgebra.cpp#L498

So @zasdfgbnm , @ngimel and I agreed to add full support for torch.float16 and torch.bfloat16 in a follow-up PR and leave this one as is.

Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

Closes #24539

Differential Revision: [D24893511](https://our.internmc.facebook.com/intern/diff/D24893511)

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Nov 11, 2020
ghstack-source-id: 0bbc3aa
Pull Request resolved: #42553
@anjali411 anjali411 requested a review from zasdfgbnm November 11, 2020 18:57
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

Closes #24539

Differential Revision: [D24893511](https://our.internmc.facebook.com/intern/diff/D24893511)

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Nov 12, 2020
ghstack-source-id: d7ff7cd
Pull Request resolved: #42553
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen, as well as adds support for complex dtypes. Also removes dead TH code for Level 2 functions. 

Closes #24539

Differential Revision: [D24893511](https://our.internmc.facebook.com/intern/diff/D24893511)

[ghstack-poisoned]
anjali411 added a commit that referenced this pull request Nov 12, 2020
ghstack-source-id: c52815b
Pull Request resolved: #42553
Copy link
Collaborator

@zasdfgbnm zasdfgbnm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for working on this!

@codecov
Copy link

codecov bot commented Nov 12, 2020

Codecov Report

Merging #42553 (909b366) into gh/anjali411/46/base (4738672) will increase coverage by 0.00%.
The diff coverage is 0.00%.

@@                  Coverage Diff                  @@
##           gh/anjali411/46/base   #42553   +/-   ##
=====================================================
  Coverage                 81.22%   81.22%           
=====================================================
  Files                      1837     1837           
  Lines                    198087   198087           
=====================================================
+ Hits                     160893   160897    +4     
+ Misses                    37194    37190    -4     

@facebook-github-bot
Copy link
Contributor

@anjali411 merged this pull request in e1ee3bf.

@facebook-github-bot facebook-github-bot deleted the gh/anjali411/46/head branch November 16, 2020 15:17
facebook-github-bot pushed a commit that referenced this pull request Nov 18, 2020
Summary:
Now when #42553 is merged we can delete a bit of code from the tests and enable some of the skipped complex tests.

Unfortunately, `test_pinverse_complex_xfailed` and `test_symeig_complex_xfailed` had bugs and it wasn't caught automatically that these tests xpass. Need to be careful next time with `unittest.expectedFailure`.

Pull Request resolved: #47910

Reviewed By: zhangguanheng66

Differential Revision: D25052130

Pulled By: mruberry

fbshipit-source-id: 29512995c024b882f9cb78b7bede77733d5762d0
/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this macro CUDA_VERSION >= 11200 intended? If you mean cuda 11.2, it should be 11020. I'm not sure if cuda 11.2 was a thing back in November 2020. 😅

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No harm done, workaround is good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xwang233 no my bad! we should fix that to avoid confusion in future

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants