Conversation
@resistor requested a review from @suo on June 12, 2019 at 17:35
@pytorchbot added the oncall: jit label on June 12, 2019
@facebook-github-bot (Contributor) commented:

@resistor has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor:

Does moving all the elements of a and b create a use-after-move issue if a or b is accessed later?

@resistor (Contributor, Author):

Is it possible for someone to have another reference to them?

Member:

oh, yes it is. Add is a functional operation.
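
A minimal, self-contained illustration of the guard being discussed, with std::shared_ptr and std::vector standing in for the JIT's refcounted list (this is a sketch, not the PR's actual code):

```
#include <memory>
#include <string>
#include <vector>

using List = std::vector<std::string>;

// Arguments are taken by value so that use_count() == 1 really means this
// call holds the only reference, mirroring the interpreter popping its
// inputs off the stack.
List concat(std::shared_ptr<List> a, std::shared_ptr<List> b) {
  List ret;
  ret.reserve(a->size() + b->size());
  if (a.use_count() == 1) {
    // Nothing else can observe `a`, so stealing its elements is safe.
    for (auto& e : *a) ret.push_back(std::move(e));
  } else {
    // `a` is aliased elsewhere; moving would leave those aliases holding
    // moved-from elements, so copy instead.
    for (const auto& e : *a) ret.push_back(e);
  }
  // `b` would get the same treatment; copied here for brevity.
  for (const auto& e : *b) ret.push_back(e);
  return ret;
}
```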


@resistor force-pushed the list branch 2 times, most recently from fbefa94 to 5b0e8b9 on June 12, 2019 at 18:32
@resistor (Contributor, Author):

@suo @driazati Can you look at the new version?


@driazati (Contributor) left a comment:

Looks good, though the resulting ops could be a little simpler by doing the dispatch on use_count at compile time, returning an Operation instead.
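
A rough, self-contained sketch of the pattern being suggested, using stand-in types rather than the real JIT headers; the idea is that whatever can be decided per graph node is decided once when the Operation is created, so the returned closure only does the per-call work:

```
#include <functional>
#include <vector>

struct Value {};                          // stand-in for IValue
using Stack = std::vector<Value>;         // stand-in for the interpreter stack
using Operation = std::function<void(Stack&)>;
struct Node { bool inputs_unaliased; };   // stand-in for a graph node

Operation createListConcat(const Node& node) {
  // Any dispatch that depends only on the node happens here, once, when the
  // graph is compiled; each branch returns a closure specialized for its case.
  if (node.inputs_unaliased) {
    return [](Stack& /*stack*/) { /* move elements out of the inputs */ };
  }
  return [](Stack& /*stack*/) { /* copy elements from the inputs */ };
}
```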

@resistor (Contributor, Author):

This currently conflicts with #21170, which I think might have the same use-after-free bug as the original version of this...

@Chillee (Collaborator) commented on Jun 13, 2019:

I assume you meant to link to the c10::List PR and not my math.log standard library PR?

@pytorchbot added the module: internals label on Jun 25, 2019
@smessmer (Contributor):

I'd rather have the optimization inside the List class; then more code could potentially profit from it. AFAIK, #21896 is already doing that?


@resistor (Contributor, Author):

Is this going in soon? I have a PR to rebase past it.

facebook-github-bot pushed a commit that referenced this pull request on Jun 27, 2019:
Summary:
In talks with smessmer, we decided that it'd be better to put the logic in `list`, as optimal behavior requires knowing `.capacity()`

Results on my cpu (for the benchmark here: https://twitter.com/VahidK/status/1138674536679821312) now look like this:
```
Pytorch batch_gather took 0.018311 seconds.
Pytorch batch_gather jit took 0.013921 seconds.
Pytorch vectorized batch_gather took 0.001384 seconds.
```
Previously, `batch_gather jit` took 3x as long as `batch_gather`.

Some logic taken from #21690. Note that these two PRs are somewhat orthogonal: that PR handles this benchmark by looking at the alias analysis, while this PR specializes for `+=`.

Note that we can't jit the vectorized version as we think `torch.arange` returns a float tensor.
Pull Request resolved: #21896

Differential Revision: D15998628

Pulled By: Chillee

fbshipit-source-id: b0085960da4613578b94deb98ac62c0a4532a8c3
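
A minimal sketch of why the fast path wants to live inside the list type (standard-library stand-ins, not c10::List): only the list itself knows its refcount and spare capacity, so `+=` can append in place when the left-hand side is uniquely owned and fall back to a copy otherwise:

```
#include <memory>
#include <vector>

// lhs += rhs for a refcounted list: when `lhs` holds the only reference,
// append in place and reuse any spare capacity (amortized O(1) per element);
// otherwise build a fresh list so other aliases of `lhs` are unaffected.
template <typename T>
void list_iadd(std::shared_ptr<std::vector<T>>& lhs, const std::vector<T>& rhs) {
  if (lhs.use_count() == 1) {
    lhs->insert(lhs->end(), rhs.begin(), rhs.end());
  } else {
    auto fresh = std::make_shared<std::vector<T>>(*lhs);
    fresh->insert(fresh->end(), rhs.begin(), rhs.end());
    lhs = fresh;
  }
}
```
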
zdevito pushed a commit to zdevito/ATen that referenced this pull request on Jun 27, 2019 (same summary as above).

@smessmer (Contributor) left a comment:

looks good, thanks

On these lines:

```
for (T b_element : b) {
  ret.push_back(std::move(b_element));

if (a.use_count() == 1) {
```

Contributor:

Wondering if we should have a List::make_unshared_copy() that ensures that, after calling it, use_count == 1 (doing a copy if it needs to), or something for this. Feels like this could be useful in other ops as well.
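
A hypothetical sketch of the helper being floated here; make_unshared_copy is not an existing c10::List API, and std::shared_ptr stands in for the list's refcounting:

```
#include <memory>
#include <vector>

// Ensure exclusive ownership: after this returns, use_count() == 1 holds,
// copying the underlying storage only if someone else still has a reference.
template <typename T>
void make_unshared_copy(std::shared_ptr<std::vector<T>>& list) {
  if (list.use_count() > 1) {
    list = std::make_shared<std::vector<T>>(*list);
  }
}
```

An op could then call this once before mutating or moving out of a list, instead of open-coding the use_count check.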

@facebook-github-bot (Contributor):

This pull request has been merged in 7cc8f37.

xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request on Jul 5, 2019 (same summary as above).
xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request on Jul 5, 2019:
…terpreter. (pytorch#21690)

Summary:
This fixes the JIT performance gap reported in https://twitter.com/VahidK/status/1138677898439561216
Pull Request resolved: pytorch#21690

Differential Revision: D15783709

fbshipit-source-id: 23bb4acda6b60c27e95667e1d53c7d261a87167d

Labels

Merged, module: internals, oncall: jit

8 participants