Implement torch.tensordot #10025
Conversation
fmassa left a comment
Thanks for the PR!
From a first look, and by comparing side by side against the numpy implementation, this looks good.
Also, I'd prefer if we named the ATen functions without an underscore.
aten/src/ATen/native/Linear.cpp
Outdated
// implements tensordot, a matrix-multiplication-like contraction, but the dimensions given
// in the two dimension lists
Tensor _tensordot(const Tensor& input1, const Tensor& input2, IntList dims1, IntList dims2) {
torch/functional.py
Outdated
:attr`b` respectively
When called with an integer argument :attr:`dims` = :math:`d`, and the number of
dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math`n`, repsectively,
Also fix typo in the docstring. Thank you, fmassa, for the review comments.
Nice! Just wondering if there is a good reason to differ from the numpy API regarding renaming the keyword argument.
PyTorch consistently uses dim instead of axis, so I'm trying to blend in.
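A minimal sketch of what that naming difference looks like in practice, assuming the merged API keeps the `dims` keyword (shapes here are illustrative only):

```python
import numpy as np
import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)

# NumPy spells the contraction argument "axes"...
r_np = np.tensordot(a.numpy(), b.numpy(), axes=1)

# ...while PyTorch conventions use "dim"/"dims", as in this PR.
r_pt = torch.tensordot(a, b, dims=1)

assert np.allclose(r_np, r_pt.numpy())
```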
I don't think the CI failures come from this patch. Could you have it retested, please?
@pytorchbot retest this please
torch/functional.py
Outdated
if dim is None:
    return torch.sort(input, -1, descending)[1]
return torch.sort(input, dim, descending)[1]
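For context, a small usage sketch of the pattern in the quoted hunk (assuming the usual torch.sort semantics): the second element returned by torch.sort is the index permutation that sorts the input.

```python
import torch

x = torch.tensor([3.0, 1.0, 2.0])

# torch.sort returns (values, indices); taking [1] keeps only the indices,
# i.e. the permutation that would sort x along the given dim.
idx = torch.sort(x, -1, False)[1]

print(idx)     # tensor([1, 2, 0])
print(x[idx])  # tensor([1., 2., 3.])
```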
@pytorchbot retest this please
ssnl left a comment
LGTM other than small nits
torch/functional.py
Outdated
dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math`n`, respectively,
it computes
:math:`r_{i_0,...,i_{m-d}, i_d,...,i_n}`
torch/functional.py
Outdated
[4796., 5162.],
[4928., 5306.]])
>>> a = torch.randn(3, 4, 5, device=d)
torch/functional.py
Outdated
:math:`= \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} * b_{k_0,...,k_{d-1}, i_d,...,i_n}`.
When called with :attr:`dims` of the list form, the given dimensions will be contracted
in place of the last :math:`d` of :attr:`a` and the first :math:`d` of `b`. The sizes
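A short sketch of the two calling conventions described in the quoted docstring, integer `dims` versus list-form `dims` (shapes and names are purely illustrative):

```python
import torch

a = torch.randn(3, 4, 5)
b = torch.randn(4, 5, 6)

# Integer form: contract the last d=2 dims of `a` with the first d=2 dims of `b`.
r1 = torch.tensordot(a, b, dims=2)                  # shape (3, 6)

# List form: name the contracted dimensions of `a` and `b` explicitly.
r2 = torch.tensordot(a, b, dims=([1, 2], [0, 1]))   # same contraction, shape (3, 6)

assert torch.allclose(r1, r2)
```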
facebook-github-bot left a comment
yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Anything I can do to move this further?
@t-vi I'll push this through
Summary: Fixes: #8988
Pull Request resolved: pytorch/pytorch#10025
Reviewed By: ezyang
Differential Revision: D9540967
Pulled By: yf225
fbshipit-source-id: 6ba2a7777162983977db884b693e6f4543b31aeb
resolve conflict in data parallel model

* master: (201 commits)
  Add cost inference to ConvGradient and WeightedSum operators (pytorch#10744)
  Move collapse dims into a single place (pytorch#11272)
  Fix some more warnings (pytorch#11257)
  Fix the batchnorm onnx exporting when affine=False
  Improve error message to include return types too (pytorch#11245)
  Check doxygen output in travis (pytorch#11124)
  Accept more numpy scalars as doubles (pytorch#9659)
  Fixed log message (pytorch#10874)
  Fix to distribution.__repr__ with lazy attributes (pytorch#11263)
  Add import export step to end to end tests
  Add complex hooks for out of tree complex implementation. (pytorch#11216)
  Unify opt flag for cmake codegen (pytorch#11227)
  nomnigraph - fix memory error in NN subgraph matchOp (pytorch#11127)
  Port PackedSequences functions to C++ (pytorch#11224)
  Treat numerical differences as warnings instead of errors when tracing (pytorch#11246)
  add a Float16UniformFill (pytorch#11123)
  Implement torch.tensordot (pytorch#10025)
  keep net type info when generating model complete net (pytorch#11032)
  Get rid of some uses of type() (pytorch#11215)
  Reorganize methods in Type, add CPUTypeDefault/CUDATypeDefault (pytorch#11205)
  ...
Fixes: #8988