Make tril_ and triu_ actually in-place #17031
Conversation
I still don't think this fix is right. If I do an in-place operation on a view of some big tensor, I expect the in-place update to be reflected in the appropriate memory locations of the big tensor. That's not what this patch does. What you should do is compute it out of place, and then copy the result back into the old tensor.

Agree with @ezyang.

Alternatively, it shouldn't be that hard to fix the kernels to work with arbitrarily strided tensors.
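A minimal sketch of the write-back approach described in the first comment, assuming the ATen C++ API (the function name is made up and the real kernels differ; this is an illustration, not the PR code):

```cpp
#include <ATen/ATen.h>

// Sketch only: when `self` is not contiguous, compute on a contiguous copy,
// then copy the result back into `self`'s own storage so that a view of a
// larger tensor still sees the update.
at::Tensor& tril_write_back_sketch(at::Tensor& self, int64_t k) {
  if (self.is_contiguous()) {
    self.tril_(k);                                       // kernel can work in place
  } else {
    at::Tensor result = at::tril(self.contiguous(), k);  // out-of-place on a copy
    self.copy_(result);                                  // write back into the original memory
  }
  return self;
}
```

The strided-kernel alternative mentioned above would avoid the extra allocation and copy entirely.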
| "x_nc should remain non-contiguous") | ||
| elif s < -3: | ||
| self.assertTrue(x_nc.is_contiguous(), | ||
| "x_nc should become contiguous") |
Since making a non-contiguous tensor contiguous is necessarily an out-of-place operation, which would break the in-place property of triu_ and tril_, we should never expect x_nc to become contiguous in triu_ and tril_.
facebook-github-bot left a comment:
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
```cpp
    apply_triu_tril<scalar_t, true, false>(self, self, k);
  });
  if (checkTrilTriuBatchContiguous(self)) {
    AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
```
nit: I think you can avoid calling the dispatch twice (remember, it will expand into something large) by just setting result and, e.g., self_contiguous using a ternary based on checkTrilTriuBatchContiguous, and then copying at the end if necessary.
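For illustration, a rough sketch of that single-dispatch structure, reusing the helper names that appear in the diff (checkTrilTriuBatchContiguous, apply_triu_tril, AT_DISPATCH_ALL_TYPES); the function name and template flags here are taken from context and may not match the merged code:

```cpp
// Pick the working tensor once, dispatch once, and copy back only if needed.
Tensor& tril_cpu_(Tensor& self, int64_t k) {
  bool inplace = checkTrilTriuBatchContiguous(self);
  Tensor self_c = inplace ? self : self.contiguous();
  AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
    apply_triu_tril<scalar_t, true, false>(self_c, self_c, k);  // single dispatch site
  });
  if (!inplace) self.copy_(self_c);  // write the result back when a copy was made
  return self;
}
```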
```cpp
}
if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
bool inplace = checkTrilTriuBatchContiguous(self);
Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
```
use inplace instead of checking again?
facebook-github-bot left a comment:
@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Currently, when the input tensor `self` is not contiguous, `tril_` and `triu_` call `self = self.contiguous()`, which allocates a new contiguous tensor and assigns it to `self`. This effectively changes the input tensor `self`'s pointer and will break downstream code after the Variable/Tensor merge. This PR fixes it so that `tril_` and `triu_` always update the input tensor in-place and preserve the input tensor's TensorImpl.

Pull Request resolved: pytorch/pytorch#17031
Differential Revision: D14069592
Pulled By: yf225
fbshipit-source-id: d188218f426446a44ccc1d33fc28ac3f828c6a05
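The invariant described in the summary can be illustrated as follows (a hedged sketch assuming the ATen C++ API; the function name and shapes are made up):

```cpp
#include <ATen/ATen.h>
#include <cassert>

// After an in-place tril_, `self` must still refer to the same TensorImpl;
// doing `self = self.contiguous()` inside the op would have rebound it to a
// freshly allocated tensor instead of updating the original storage.
void tensorimpl_preserved_sketch() {
  at::Tensor self = at::rand({4, 4}).t();            // non-contiguous input view
  auto* impl_before = self.unsafeGetTensorImpl();
  self.tril_();                                      // in-place triangular op
  assert(self.unsafeGetTensorImpl() == impl_before); // same TensorImpl afterwards
  assert(!self.is_contiguous());                     // strides are untouched
}
```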