Numpy-style broadcasting for all mathematical functions #1563
Conversation
Here is numpy-style broadcasting for pointwise math and reduction functions. This also changes the keepdim default from True to False. I have a longer writeup in the works categorizing the numpy semantics and showing that these are the only backwards incompatible changes necessary to get numpy-style broadcasting (i.e. what's left from here only adds functionality that would currently give you an error). I'll link the writeup when that's done, but feel free to review before then (although you may wish to hold off on merging).
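For a quick sense of the two user-visible changes described above, a minimal sketch in current-API syntax (illustrative only, not code from this diff):

    import torch

    x = torch.randn(5, 1, 4)
    y = torch.randn(3, 1)

    # pointwise ops now broadcast NumPy-style instead of requiring equal sizes
    print((x + y).size())    # torch.Size([5, 3, 4])

    # reductions now default to keepdim=False, so the reduced dimension is squeezed
    print(x.sum(0).size())   # torch.Size([1, 4]); was (1, 1, 4) under the old keepdim=True default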
let's aim to merge this after the NIPS deadline on Friday
ezyang
left a comment
I love it. I went through and wrote some documentation for the new functions while reading. I didn't do a very careful "is this logic correct" code review.
docs/source/notes/broadcasting.rst
# x and y are not broadcastable, because x does not have at least 1 dimension
>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> x=torch.FloatTensor()
>>> y=torch.FloatTensor(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension
torch/lib/TH/THStorage.c
TH_API void THLongStorage_calculateExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est) {
ptrdiff_t ndim = THLongStorage_size(sizes);
long numUnsqueezed = ndim - tensorDim;
TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB, int raiseErrors) {
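For context, THLongStorage_inferSize2 computes the broadcast result size of two shape arrays; a minimal Python sketch of the NumPy-style rule it implements (illustrative, not a translation of the C code):

    def infer_size2(sizes_a, sizes_b):
        """Broadcast two shapes NumPy-style, or raise if they are incompatible."""
        ndim = max(len(sizes_a), len(sizes_b))
        out = []
        for i in range(1, ndim + 1):                  # walk trailing dimensions first
            a = sizes_a[-i] if i <= len(sizes_a) else 1
            b = sizes_b[-i] if i <= len(sizes_b) else 1
            if a != b and a != 1 and b != 1:
                raise ValueError("shapes %r and %r are not broadcastable" % (sizes_a, sizes_b))
            out.append(max(a, b))
        return tuple(reversed(out))

    print(infer_size2((5, 1, 4, 1), (3, 1, 1)))       # (5, 3, 4, 1)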
torch/lib/TH/THStorage.c
expandedSizes[i] = 1;
expandedStrides[i] = expandedSizes[i+1] * expandedStrides[i+1];
}
TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors) {
torch/lib/TH/generic/THTensor.c
return 0;
}

int THTensor_(expand2)(THTensor *ra, THTensor *rb, THTensor *opa, THTensor *opb, int raiseErrors) {
tools/cwrap/plugins/Broadcast.py
from string import Template


class Broadcast(CWrapPlugin):
self.assertEqual(t0.size(), r0.size())
self.assertEqual(t1.size(), r1.size())

# case 4: not broadcastable and not nElement equal -- tested by test_fallback
torch/lib/TH/THStorage.c
return 0;
}

TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors) {
@ezyang thank you, I will incorporate these suggestions. If you are interested in how this works / future plans, you may find these notes interesting: https://github.com/gchanan/pytorch/wiki/Broadcasting-Notes.
@gchanan could it be possible to have a global switch (at compile time if not possible at runtime) that would disable the automatic broadcasting? For example, one that just disables the cwrap plugin?
@albanD good question, and I understand the motivation for tightly controlling your own code. What I'm not sure about is how to make the distinction between your code and library code (or whether that distinction is even well defined). To support what you describe, we'd have to ensure that our libraries work both with and without broadcasting, which would greatly increase the maintenance burden. It's also not clear to me that the distinction between user and library code is well defined anyway: let's say some NN layer is written in such a way that it multiplies the input tensor by a (4,1) tensor and you pass in a (4) tensor. Was it library or user code that caused broadcasting? What I have now is a UserWarning if your code does not broadcast but uses the old (deprecated) 1-d pointwise operations (previously, many functions only required the number of elements to match). I could also add an optional warning for the backwards-incompatible case, that is, the sizes don't match, so you previously would have used the 1-d pointwise operations, but you are now broadcasting. It sounds like what you want is a warning if broadcasting changes the sizes at all (I'd be hesitant to make it an error rather than a warning, given the argument about library code above). That shouldn't be too difficult to implement, although I'm not sure how useful it will be if our library code broadcasts a lot (you might get a lot of warnings). What do you think?
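To make the (4,1)-times-(4) case concrete, here is a minimal sketch (current-API syntax, illustrative only) contrasting the old nElement-matching pointwise behavior with broadcasting:

    import torch

    a = torch.arange(4.).view(4, 1)   # shape (4, 1)
    b = torch.arange(4.)              # shape (4,)

    # old (deprecated) 1-d pointwise behavior: only the number of elements had
    # to match, so the op was effectively computed on the flattened buffers
    old_style = (a.view(-1) * b).view(4, 1)
    print(old_style.size())           # torch.Size([4, 1])

    # broadcasting behavior: (4, 1) and (4,) broadcast to (4, 4)
    print((a * b).size())             # torch.Size([4, 4])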
Ah right, I assumed that our library would stay broadcasting-free (as it is right now).
What about a per-tensor switch instead of a global switch? That would, however, increase the complexity of user code.
If you pass in a non-broadcasting tensor to the library, does it switch to a broadcasting tensor? If it doesn't, you have the same issue of the library having to work with both (and having to define in what cases the flag is copied).
I think this is ready for review and now implements all functions mentioned in https://github.com/gchanan/pytorch/wiki/Broadcasting-Notes (i.e. it no longer includes only the pointwise mathematical and comparison functions the title mentions). Re: turning off broadcasting, the only case I implemented this for was copy, which allows you to pass a "broadcast" parameter (default: True); there were a number of places in the code where it was clearly intended to copy tensors as 1-d, which wasn't really true for other functions. This PR now also includes some warnings that you can enable to detect the backwards-incompatible changes, in particular the deprecated 1-d pointwise case and the case where sizes differ but now broadcast, as discussed above.
I've manually enabled these and verified that the only places these warnings are triggered are directly in the tests, i.e. there are no library calls (at least none covered by tests) where we rely on the old behavior. This should make turning on these warnings less noisy.
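For reference, a sketch of what enabling such a warning looks like; the torch.utils.backcompat module path is taken from later PyTorch documentation rather than from this diff:

    import torch
    import torch.utils.backcompat

    # enable the backwards-incompatibility warning for broadcasting
    torch.utils.backcompat.broadcast_warning.enabled = True

    a = torch.ones(4, 1)
    b = torch.ones(4)
    c = a + b   # should warn: this previously used 1-d pointwise behavior, now broadcasts to (4, 4)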
(force-pushed from 48d91a1 to 6430b6e)
rebased the commits.
killeent
left a comment
In the future I think it would be helpful if we could come up with some sort of mechanism for breaking up PRs into smaller chunks, although I'm not blaming you for that @gchanan.
Due to the large scope of this PR, I didn't really look at the logic so much; I guess I'll just have to trust you. I added a few nits and questions where I generally didn't understand why things were done.
docs/source/notes/broadcasting.rst
>>> x=torch.FloatTensor(5,7,3)
>>> y=torch.FloatTensor(5,7,3)
# same shapes are always broadcastable
docs/source/notes/broadcasting.rst
>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor(3,1,1)
# x and y are broadcastable
Many PyTorch operations support :any:`NumPy Broadcasting Semantics <numpy.doc.broadcasting>`.

In short, if a PyTorch operation supports broadcast, then its Tensor arguments can be
automatically expanded to be of equal sizes (without making copies of the data).
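A small illustration of the "without making copies" point, using the current tensor API (not code from this diff): the expanded view shares storage with the original, and the broadcast dimension simply gets stride 0.

    import torch

    x = torch.ones(3, 1)
    y = x.expand(3, 4)       # no data is copied; the size-1 dim is repeated "virtually"
    print(y.size())          # torch.Size([3, 4])
    print(y.stride())        # (1, 0) -- stride 0 in the expanded dimension
    x[0, 0] = 42
    print(y[0])              # all 42s, because y views the same storage as x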
# arguments to broadcast specified argument (usually "self") against
# [inplace] will generate code for in-place function, which doesn't allow the in-place
# argument to be broadcast
# [fallback] if tensors aren't broadcastable, preserves "element number" pointwise behavior,
tools/cwrap/plugins/Broadcast.py
# argument to be broadcast
# [fallback] if tensors aren't broadcastable, preserves "element number" pointwise behavior,
# where only number of elements need to match, and tensors are viewed as 1-dimensional.
# [dims] if the tensors shouldn't be broadcast to specific tensor or tensors, but a combination
TensorDst* dst = THPTypeInfo<TensorDst>::cdata(dst_);
TensorSrc* src = THPTypeInfo<TensorSrc>::cdata(src_);

TensorSrc *src_save = src;
@@ -0,0 +1,188 @@
#include "torch/csrc/cuda/THCP.h"
@@ -0,0 +1,155 @@
#ifndef THP_EXPAND_UTILS_H
template class THPPointer<THPGenerator>;

static bool backCompatBroadcastWarn = false;
THLongStorage_resize(output, ndim);
memcpy(THLongStorage_data(output), expandedSizes, sizeof(long)*ndim);
THFree(expandedSizes);
return 0;
docs/source/notes/broadcasting.rst
# x and y are not broadcastable, because x does not have at least 1 dimension
>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor(3,1,1)
docs/source/notes/broadcasting.rst
# but:
>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor(3,1,1)
test/common_nn.py
module_name='LogSoftmax',
input_size=(1, 3, 10, 20),
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1).expand_as(i)).log_(),
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False).expand_as(i)).log_(),
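The updated reference_fn reflects the keepdim default change; roughly, in current syntax (illustrative only):

    import torch

    x = torch.randn(2, 3)
    print(x.sum(1).size())                 # torch.Size([2])    -- keepdim=False is the new default
    print(x.sum(1, keepdim=True).size())   # torch.Size([2, 1]) -- the old default behavior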
test/test_torch.py
sm2 = m2[:, 4]
res1 = torchfn(sm1, sm2)
# suppress broadcastable warning
with warnings.catch_warnings(record=True):
test/test_torch.py
dims_full = []
ndims = random.randint(1, 4)
for _ in range(ndims):
    dims_full = dims_full + [random.randint(1, 8)]
torch/lib/TH/THStorage.c
}
}

expandedSizes[i] = max_dim_size;
torch/lib/TH/THStorage.c
return 0;
}

TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors) {
torch/lib/TH/generic/THTensor.c
THArgCheck(THLongStorage_size(sizes) >= THTensor_(nDimension)(tensor), 1, "the number of sizes provided \
must be greater or equal to the number of dimensions in the tensor");
THArgCheck(THTensor_(nDimension)(tensor) > 0, 0, "can't expand an empty tensor");
THTensor* THTensor_(newExpand)(THTensor *tensor, THLongStorage *sizes, int raiseErrors) {
# reshape batches back into result
total_expansion = expand_batch_portion + (self_exp_size[-2], other_exp_size[-1])
return self_expanded.bmm(other_expanded).view(*(total_expansion))
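For readers following the matmul logic quoted above, a self-contained sketch of the same expand-bmm-view approach using current-API helpers (torch.broadcast_shapes is an assumption here, not part of this PR):

    import torch

    # broadcast the batch dimensions of both operands, run a single bmm over the
    # flattened batch, then reshape back to the broadcast batch shape
    a = torch.randn(2, 1, 4, 5)   # batch shape (2, 1)
    b = torch.randn(3, 5, 6)      # batch shape (3,)

    batch_shape = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])   # (2, 3)
    a_exp = a.expand(*batch_shape, *a.shape[-2:]).reshape(-1, *a.shape[-2:])
    b_exp = b.expand(*batch_shape, *b.shape[-2:]).reshape(-1, *b.shape[-2:])
    out = torch.bmm(a_exp, b_exp).view(*batch_shape, a.shape[-2], b.shape[-1])
    print(out.shape)              # torch.Size([2, 3, 4, 6]), same as torch.matmul(a, b).shape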
enabled = property(get_enabled, set_enabled)

sys.modules[__name__] = Warning()
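The sys.modules assignment quoted above is the usual trick for giving a module property-like attributes; a standalone sketch of the pattern (file and names are illustrative, not the actual PyTorch module):

    # backcompat_warning.py -- illustrative module, not the actual PyTorch file
    import sys


    class Warning(object):                 # new-style class
        def __init__(self):
            self._enabled = False

        def get_enabled(self):
            return self._enabled

        def set_enabled(self, value):
            self._enabled = bool(value)

        enabled = property(get_enabled, set_enabled)


    # Replacing the module object with an instance makes
    # `import backcompat_warning; backcompat_warning.enabled = True`
    # go through the property setter instead of rebinding a plain attribute.
    sys.modules[__name__] = Warning()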
@apaszke I incorporated all your suggestions except the THLongStorage suggestion (due to the @colesbury comment above) and removing raiseErrors from the expand, expand2, and expand3 functions, since that comment thread didn't conclude. There were also new merge conflicts, so the last commit resolves them via a merge, but let me know if you want me to force-push everything via a rebase.
I guess we can leave the THSize change, but it'd be better to remove the error flag.
OK, it took some effort to handle the error cases correctly without leaking memory, but I removed the error flag from THTensor.
1) Line up trailing dimensions in broadcast docs.
2) Remove unnecessary expand_as in common_nn test.
3) Use view in tensor_str instead of resize_.
4) newExpand: remove raiseErrors change.
5) Clarify expandedSizes/expandedStrides parameters in inferExpandGeometry.
6) Simplify inferSize2/inferSizeN implementations.
7) Use new-style classes for warning.
Take an error_buffer to return a proper error message while being able to handle memory management correctly from the calling function.
They weren't documented as having those semantics, but tests on master show they do.
in Broadcast plugin when fallback = false.
I fixed the remaining lint and merged this into master! THIS WAS AWESOME @gchanan
I forgot that parts of this PR touch TH / THC etc., so I have to do the annoying reverse-merge scheme. For now I've reverted this PR on master (force-pushed, sorry), and I'll merge it in properly tomorrow-ish.
(force-pushed from f04b3a1 to ca54693)
This is now properly reverse-merged into master.
Skip test_typing to avoid `Error importing plugin "numpy.typing.mypy_plugin": No module named 'numpy.typing.mypy_plugin'`. It happens because we have numpy==1.20.3 in some of our images (those with python3.9), but `mypy` can be used only with numpy>=1.21. Will check the numpy version in run_tests.py and add test_typing to ROCM_BLOCKLIST if the numpy version is less than 1.21. Fix ROCm/frameworks-internal#8497 (cherry picked from commit 3b54c45)
Implements all functions mentioned in https://github.com/gchanan/pytorch/wiki/Broadcasting-Notes