cdist: pairwise distances between two sets of tensors with batch mode #20934
Conversation
aten/src/ATen/native/Distance.cpp (outdated)
tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});

int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), 1, std::multiplies<int64_t>());
std::vector<int64_t> tensor1_view({expand_batch_product});
You can replace these three lines with a single declaration: std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
Changed
aten/src/ATen/native/Distance.cpp (outdated)
auto dim1 = x1.dim();
auto dim2 = x2.dim();

IntArrayRef batch_tensor1(x1.sizes().data(), std::max<int64_t>(dim1 - 2, 0));
Can you add some comments to this block of code explaining what it does?
Added
You are checking above that the dims are >= 2, so why do we need the std::max calls?
Fixed
Tensor tensor1_expanded = x1.expand(tensor1_expand_size).contiguous().view(tensor1_view);
Tensor tensor2_expanded = x2.expand(tensor2_expand_size).contiguous().view(tensor2_view);

std::vector<int64_t> output_shape(expand_batch_portion);
Why do you sometimes use foo.insert(foo.end(), {bar, quux}); and sometimes call foo.push_back twice? This should be consistent.
Made it consistent
facebook-github-bot
left a comment
@ifedan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
… (#20934)

Summary: Batch implementation for cdist function

Pull Request resolved: pytorch/pytorch#20934
Differential Revision: D15609458
Pulled By: ifedan
fbshipit-source-id: 31c12e120d168baec6a6af913f599838a44034d7