Support 0-batch size for nn.Linear. #27211
Conversation
This pull request was exported from Phabricator. Differential Revision: D17599915
aten/src/TH/generic/THTensorMath.cpp (Outdated)
I'm a bit confused... If the result is empty, what are these assignments doing here?
There are two main cases where BLAS routines fail:
-
n or m is 0, which results in a final matrix with 0 as one of its dimensions (for example, Linear called on a batch of size 0). In this case the only thing needed from this function is a resize of the final matrix to the correct output dimensions.
-
The other case is when k is 0, which can happen when one of the strides is 0 (this occurs on the backward pass for Linear with an empty batch). In that case I'm dropping the part of ADDMM that becomes a no-op, as sketched below:
alpha * A x B + beta * C -> beta * C
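To make the k = 0 case concrete, here is a minimal sketch using the public torch.addmm API (not the TH internals this diff touches): when mat1 is (n, 0) and mat2 is (0, m), the alpha * A x B term contributes nothing, so only beta * C survives.

```python
import torch

# k == 0: mat1 is (n, 0) and mat2 is (0, m), so alpha * (mat1 @ mat2)
# is an all-zero (n, m) matrix and addmm reduces to beta * C.
n, m = 3, 4
C = torch.randn(n, m)
mat1 = torch.randn(n, 0)
mat2 = torch.randn(0, m)

out = torch.addmm(C, mat1, mat2, beta=2.0, alpha=1.0)
assert out.shape == (n, m)
assert torch.allclose(out, 2.0 * C)  # only the beta * C term survives
```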
aten/src/TH/generic/THTensorMath.cpp (Outdated)
Usually beta is 1.0, so we can skip this loop altogether.
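A rough sketch of the suggested fast path (a hypothetical helper, not the actual TH loop): once the matmul term is known to vanish, the only remaining work is scaling C by beta, and when beta == 1.0 even that can be skipped.

```python
def scale_result_by_beta(C, beta):
    # Hypothetical sketch: with k == 0 the matmul term vanishes, so the
    # only remaining work is the beta * C scaling. When beta == 1.0 the
    # result is C unchanged and the whole loop can be skipped.
    if beta == 1.0:
        return C
    for i in range(len(C)):
        for j in range(len(C[i])):
            C[i][j] *= beta
    return C
```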
facebook-github-bot left a comment:
@kennyhorror has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
At the current moment, nn.Linear (and its internal functional code) fails in THBlas:

RuntimeError: invalid argument 8: lda should be at least max(1, 0), but have 0 at caffe2/aten/src/TH/generic/THBlas.cpp:363

This diff fixes that bug. Based on the current dispatcher logic, I was able to identify two places where changes need to be made:
1. The file touched in this diff
2. caffe2/aten/src/THC/generic/THCTensorMathBlas.cu

I didn't find better places than injecting the logic into those files: the only non-generated function on the forward pass, plus the mm_mat2_backward function family on the backward pass.

Pull Request resolved: #27211

Test Plan: New unit tests are passing. Code that was failing earlier now works. Other backends still need to be tested.

Differential Revision: D17599915
Pulled By: kennyhorror
fbshipit-source-id: cd47fe9fddb3bad1be5ddaddb1c1d9b95a76d258
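For context, a minimal sketch of the failing scenario; the asserts describe the intended post-fix behavior:

```python
import torch
import torch.nn as nn

# Before this fix, running nn.Linear on an empty batch hit the THBlas
# "lda should be at least max(1, 0)" error on the backward (and on some
# paths the forward) pass.
linear = nn.Linear(5, 3)
x = torch.randn(0, 5, requires_grad=True)  # batch of size 0

out = linear(x)
assert out.shape == (0, 3)   # forward: correctly-shaped empty result

out.sum().backward()         # backward: previously failed inside gemm
assert x.grad.shape == (0, 5)
assert linear.weight.grad.shape == (3, 5)
```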
@kennyhorror merged this pull request in a891e92.
This seems to break the py2.7.9 test.