
Conversation

@krshrimali (Contributor) commented Sep 10, 2021

This PR attempts to port baddbmm and bmm to structured kernels. Both ops are handled in the same PR because they share most of their checks and implementation.
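
For context, here is a rough sketch of the structured-kernel split this port follows (simplified and illustrative; the wrapper code is generated from native_functions.yaml, so the final signatures may differ slightly):

// Sketch: the meta function declares the output and runs the checks shared
// between bmm and baddbmm; the impl function computes into the allocated output.
TORCH_META_FUNC(bmm)(const Tensor& self, const Tensor& mat2) {
  set_output({self.sizes()[0], self.sizes()[1], mat2.sizes()[2]}, self.options());
  auto& result = maybe_get_output(0);
  common_checks_baddbmm_bmm(*this, self, mat2, Scalar(0.0), Scalar(1.0), true, result);
}

TORCH_IMPL_FUNC(bmm_out_cpu)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
  // 'result' already has the declared shape and dtype here; the batched
  // matrix multiplication is performed into it.
}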

Issue tracker: #55070

cc: @ysiraichi @ezyang

@facebook-github-bot added the oncall: jit and cla signed labels Sep 10, 2021
@facebook-github-bot (Contributor) commented Sep 10, 2021


💊 CI failures summary and remediations

As of commit eca37b4 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@krshrimali requested review from ysiraichi and removed request for IvanYashchuk, lezcano and nikitaved on September 10, 2021 07:53
@krshrimali marked this pull request as draft on September 10, 2021 10:12
@krshrimali (Author):

There are a few failures; I'm working on them. I will mark this ready for review once they are resolved.

@krshrimali removed the request for review from ezyang on September 10, 2021 10:14
@ysiraichi (Collaborator) left a comment

Hi @krshrimali. That's one nice looking PR. I left some minor comments and some other observations. Check if they make sense, and let me know if there's anything unclear!

Comment on lines 101 to 103
if (beta.to<c10::complex<double>>() != 0.0) {
  result.copy_(self);
}
Collaborator:

I think this belongs in the IMPL function (it's neither a check nor output metadata setting).

Collaborator:

Then, you would have to replicate it to the CUDA IMPL function, too.
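
Roughly what I mean, as a sketch (the exact IMPL signature is generated, so details may differ):

// Sketch only: seed the output from self inside the impl function,
// rather than doing the copy in the meta function.
TORCH_IMPL_FUNC(baddbmm_out_cpu)
(const Tensor& self, const Tensor& batch1, const Tensor& batch2,
 const Scalar& beta, const Scalar& alpha, const Tensor& result) {
  if (beta.to<c10::complex<double>>() != 0.0) {
    result.copy_(self);
  }
  // ... the batched matmul accumulation follows ...
}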

@krshrimali (Author):

The reason it's here is that the size of self is modified using expand_size (not in-place, though; we create another tensor and use that one for the checks).

Comment on lines 88 to 92
TORCH_META_FUNC(bmm)(const Tensor& self, const Tensor& mat2) {
  set_output({self.sizes()[0], self.sizes()[1], mat2.sizes()[2]}, self.options());
  auto& result = maybe_get_output(0);
  common_checks_baddbmm_bmm(*this, self, mat2, Scalar(0.0), Scalar(1.0), true, result);
}
Collaborator:

Nice job! It came out brief and straightforward.

Tensor& baddbmm_out_cpu(const Tensor& self_, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor &result) {
  auto self = expand_size(self_, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm");
  result.resize_(self->sizes());
  result.copy_(*self);
Collaborator:

I feel like we are losing this copy. But, somehow all the tests are passing.

@krshrimali (Author) commented Sep 18, 2021:

I thought about this before, and I've given it more thought now.

My observation is that expand_size(self_, {batch1.size(0), batch1.size(1), batch2.size(2)}) calls the self_.expand method (if required).

Now, if the result tensor's size is already {batch1.size(0), batch1.size(1), batch2.size(2)} (or we set it explicitly using set_output) and self is not of that size, then copy_ will broadcast self into the result tensor's shape:

In [17]: M = torch.randn(10, 1, 5)
In [18]: _out = torch.empty((10, 3, 5))

In [20]: _out.copy_(M).size()
Out[20]: torch.Size([10, 3, 5])

And, following your suggestion in the review comment above, we now do this copy_ in the impl function (earlier it was in the meta function), but only if result is not the same tensor as self.

Now, result will be the same as self if:

  1. It's an in-place call (the auto-generated code sets result as self in the meta function itself).
  2. out is self (e.g. torch.baddbmm(M, batch1, batch2, out=M)). I'm not resizing for this case, and that's where it goes wrong; please see below.

On upstream:

In [43]: M = torch.randn(10, 1, 5)

In [44]: torch.baddbmm(M, batch1, batch2, out=M).size()
Out[44]: torch.Size([10, 3, 5])

On this branch:

In [8]: M = torch.randn(10, 1, 5)

In [9]: torch.baddbmm(M, batch1, batch2, out=M).size()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-4d21e89d2e08> in <module>
----> 1 torch.baddbmm(M, batch1, batch2, out=M).size()

RuntimeError: Expected self_sizes[0] == bs && self_sizes[1] == res_rows && self_sizes[2] == res_cols to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Now, the interesting case is the in-place one: this copy and size expansion doesn't happen in upstream (rightly so). The problem is finding a way to tell whether the meta function is called from the in-place dispatch or the method/function dispatch. Not sure how to do this :/ Will have to ponder on this one.

In [42]: M.baddbmm_(batch1, batch2).size()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-42-ffc510337fae> in <module>
----> 1 M.baddbmm_(batch1, batch2).size()

RuntimeError: Expected self_sizes[0] == bs && self_sizes[1] == res_rows && self_sizes[2] == res_cols to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

We need to test this behavior as well in test_torch.py. I'll try to send a fix for this one soon.

Thanks for asking this question, @ysiraichi - sorry for the long message. ;)

@krshrimali (Author):

I'm stuck here:

How would you differentiate these two calls (see below) from the meta function? (Note: please see the comment above on why we need to differentiate between these calls.)

>>> M.baddbmm_(batch1, batch2)  # [1] In-Place
>>> torch.baddbmm(M, batch1, batch2, out=M)  # [2] out=self tensor

From the auto-generated code for the in-place variant in build/aten/src/ATen/RegisterCPU.cpp:

struct structured_baddbmm_out_cpu_inplace final : public at::native::structured_baddbmm_out_cpu {
     structured_baddbmm_out_cpu_inplace(Tensor& self) : outputs_{std::ref(self)} {}
     // ...
 };

As outputs_ is initialized with std::ref(self), result.is_same(self) is always going to be true.

cc: @ezyang @bdhirsh @ysiraichi

Similarly, for the [2] out=M call, result.is_same(self) will be true. How should we differentiate between these calls?
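
For reference, the generated out= wrapper stores the user-provided out tensor the same way (paraphrased from the generated RegisterCPU.cpp, so details may differ):

struct structured_baddbmm_out_cpu_out final : public at::native::structured_baddbmm_out_cpu {
     structured_baddbmm_out_cpu_out(Tensor& out) : outputs_{std::ref(out)} {}
     // ...
 };

So when out is self, the meta function sees exactly what it sees in the in-place case.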

@bdhirsh (Contributor) commented Sep 20, 2021:

Hmm, this actually seems more in line with our out= behavior: for most other out= ops, if you pass in an output tensor with an incorrect size we raise an error instead of transparently resizing the output for you (unless the tensor has zero size).

I'm just not sure if we want a deprecation cycle, since it's not clear to me how commonly this behavior is relied on (if I had to guess though, it seems like a pretty uncommon use case?). I also don't think you have an easy way with structured kernels to differentiate those two cases.

Collaborator:

Now, I encountered this behaviour in a test in test/test_linalg.py the other day:

out = torch.zeros(*shape, dtype=torch.int64).to(x.device)

Note that the shape is not being passed correctly:

pytorch/test/test_linalg.py

Lines 4988 to 4995 in 88032d8

n = 1
for m in range(1, 8):
    for p in range(1, 8):
        for o in range(1, 5):
            # 1d, 3d, inner dimensions C
            x = torch.arange(m, device=device)
            y = torch.arange(o * m * p, device=device).reshape(o, m, p)
            self.check_single_matmul(x, y, (o, n, p))

This will be fixed in #64387, but it signals that there may be other places where this behaviour is used.

Note that, at the moment, if an incorrectly shaped out is used, a warning is issued. As such, I think that just making that warning into an error may be a reasonable behaviour here.

@krshrimali (Author):

Thanks for the valuable inputs, @lezcano and @bdhirsh. Following the suggestions, we now raise an error for an incorrectly shaped out (earlier we allowed resizing). However, this behavior applies only to baddbmm; we still resize the output for bmm, since there is a test that uses incorrect shapes for out here: https://github.com/pytorch/pytorch/blob/master/test/test_torch.py#L8394.

The only thing that concerns me is the error messages for the in-place call and the out=M call. The error for the in-place call should mention self_sizes instead of result_sizes, but that would require distinguishing between the two calls.

>>> M = torch.randn(10, 1, 5)  # Incorrect shape
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)

>>> M.baddbmm_(batch1, batch2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected result_sizes[0] == bs && result_sizes[1] == res_rows && result_sizes[2] == res_cols to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

>>> torch.baddbmm(M, batch1, batch2, out=M)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected result_sizes[0] == bs && result_sizes[1] == res_rows && result_sizes[2] == res_cols to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

The tests might need another look, which I personally feel can be done in a follow-up PR.

@ysiraichi (Collaborator) commented Sep 21, 2021:

How would you differentiate these 2 calls (see below) from meta function? (Note: please see the comment above on why we need to differentiate b/w these calls)

While I agree with @bdhirsh (i.e. you can't), you can resize the output tensors conditionally by using set_output. So, instead of writing:

auto& result = meta.maybe_get_output(0);
if (is_bmm || !result.defined()) {
  meta.set_output({bs, res_rows, res_cols}, batch1.options());
} else {
  const auto result_sizes = result.sizes();
  TORCH_CHECK(result_sizes[0] == bs && result_sizes[1] == res_rows && result_sizes[2] == res_cols);
}

You could write:

// 'set_output' does not resize for in-place calls
meta.set_output({bs, res_rows, res_cols}, batch1.options());
// Error is raised if called from the in-place overload with incorrect self shape
TORCH_CHECK(result_sizes[0] == bs && result_sizes[1] == res_rows && result_sizes[2] == res_cols);

Collaborator:

@ezyang @bdhirsh @lezcano @krshrimali
As far as I remember, the idea was to turn the warning into an error (as @lezcano pointed out) in the future, by using the resize_output function (for consistency). Since that function is called by set_output, maybe we could rely on that for now (instead of directly raising an error for the out-of-place). Not sure, though. What do you think?

// TODO: make all operations that resize given outputs use this function
// for consistency and maintainability.
// Some operations like `cat` might not be able to make the use of
// resize_output directly. For more details to understand how it works in `cat`,
// see https://github.com/pytorch/pytorch/pull/62560#discussion_r687363362
// Resizes outputs
// Functions accepting output tensors, like with the "out" kwarg, should
// call this function to handle resizing their output tensor.
// Issues a warning if the output tensor has one or more elements and
// needs resizing
// NOTE: In the future the warning will become an error
// Returns a bool saying whether or not the resize actually happened or not
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);

Contributor:

@ysiraichi good catch, that seems like it should work (and also raise the warning). It's probably worth quickly confirming that with that code change, we now get the deprecation warning:

>>> M = torch.randn(10, 1, 5)  # Incorrect shape
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)

>>> torch.baddbmm(M, batch1, batch2, out=M)
# (we should see a deprecation warning, because meta.set_output calls at::native::set_output)

@krshrimali (Author) commented Sep 22, 2021:

This is really cool, @ysiraichi! I wasn't aware that meta.set_output() doesn't resize for in-place calls. This helps!

Here is how the use cases look with the revised changes (cc: @bdhirsh):

In [11]: M = torch.randn(10, 1, 5)
In [12]: batch1 = torch.randn(10, 3, 4)
In [13]: batch2 = torch.randn(10, 4, 5)

In [14]: torch.baddbmm(M, batch1, batch2).size()
Out[14]: torch.Size([10, 3, 5])

In [15]: M.baddbmm_(batch1, batch2).size()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-15-ffc510337fae> in <module>
----> 1 M.baddbmm_(batch1, batch2).size()

RuntimeError: Expected result_sizes[0] == bs && result_sizes[1] == res_rows && result_sizes[2] == res_cols to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

In [16]: torch.baddbmm(M, batch1, batch2, out=M).size()
<ipython-input-16-4d21e89d2e08>:1: UserWarning: An output with one or more elements was resized since it had shape [10, 1, 5], which does not match the required output shape [10, 3, 5].This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (Triggered internally at  ../aten/src/ATen/native/Resize.cpp:16.)
  torch.baddbmm(M, batch1, batch2, out=M).size()
Out[16]: torch.Size([10, 3, 5])

We do see a deprecation warning with out=M, and we also get an error for the in-place call. I've made the required commits; the tests should pass, and this PR should hopefully look good now. Welcoming reviews. :)

@ysiraichi (Collaborator) left a comment

LGTM! Nice job, @krshrimali
@ezyang could you take a look at it?

meta.set_output({bs, res_rows, res_cols}, batch1.options());
const auto result_sizes = result.sizes();
// Error is raised if called from in-place overload with incorrect shape
TORCH_CHECK(result_sizes[0] == bs && result_sizes[1] == res_rows && result_sizes[2] == res_cols);
Collaborator:

Nit: an error message here would be nice!

Contributor:

You don't have to spell it out individually; ArrayRef has equality defined. This also fixes a bug where you can get an OOB access if the dimensionality differs. Also another one for #57827.
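
For example, something along these lines (just a sketch; the exact message wording is up to you):

// Compare the whole shape at once; ArrayRef defines operator==, which also
// avoids out-of-bounds indexing when 'result' has a different number of dims.
const std::vector<int64_t> output_size{bs, res_rows, res_cols};
TORCH_CHECK(result.sizes() == IntArrayRef(output_size),
    "Expected an output tensor with shape ", IntArrayRef(output_size),
    " but got shape ", result.sizes());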

Contributor:

@ysiraichi maybe you can fix the underlying issue, so we don't have to keep pasting this in? :)

@krshrimali (Author):

Right, figuring out the error message gets tricky here (though I've still made an effort to be as correct as possible in the recent commit), because an error is raised if:

  1. out tensor is passed with an incorrect shape.
  2. (in-place call), input tensor is passed with an incorrect shape.

I've added an error message in the recent commit; please let me know if it sounds good or if you have any suggestions. However, we won't reach this error frequently; most of the time it will fail earlier, at expand_size:

In [8]: M = torch.randn(1, 3, 2)

In [9]: M.baddbmm_(batch1, batch2)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-a9e665d29b16> in <module>
----> 1 M.baddbmm_(batch1, batch2)

RuntimeError: The expanded size of the tensor (5) must match the existing size (2) at non-singleton dimension 2.  Target sizes: [10, 3, 5].  Tensor sizes: [1, 3, 2]

The only time we'll get this error is for in-place calls, as expected:

In [12]: M = torch.randn(10, 1, 5)

# batch1 size: (10, 3, 4) and batch2 size: (10, 4, 5)
In [13]: M.baddbmm_(batch1, batch2).size()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-13-ffc510337fae> in <module>
----> 1 M.baddbmm_(batch1, batch2).size()

RuntimeError: Expected an output tensor with shape [10, 3, 5] but got shape [10, 1, 5]

Hope this error message looks good now. For completeness, I've also added a similar error message for the self_sizes check in the same function. :)

@krshrimali (Author) commented Sep 23, 2021:

You don't have to spell it out individually, ArrayRef has equality defined. This also solves a bug where you can OOB access if the dimensionality differs. Also another one for #57827

Sorry, I just noticed this. It should be fixed now; I defined output_size (an IntArrayRef) since it's used twice.

Also added error messages here:

# batch2 is of incorrect size
In [7]: M.baddbmm_(batch1, batch2)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-a9e665d29b16> in <module>
----> 1 M.baddbmm_(batch1, batch2)

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [10, 4] but got: [10, 5].
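
For reference, the check behind that message looks roughly like this (a sketch; contraction_size is just an illustrative name for batch1.size(2), not necessarily the variable name used in the PR):

const int64_t bs = batch1.size(0);
const int64_t contraction_size = batch1.size(2);
TORCH_CHECK(batch2.size(0) == bs && batch2.size(1) == contraction_size,
    "Expected size for first two dimensions of batch2 tensor to be: [",
    bs, ", ", contraction_size, "] but got: [",
    batch2.size(0), ", ", batch2.size(1), "].");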

Relevant commit

cc: @ezyang @ysiraichi

@ezyang (Contributor) commented Sep 23, 2021

after final review round this looks good to go

@facebook-github-bot (Contributor):

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@krshrimali (Author):

after final review round this looks good to go

Thanks, @ezyang! I have added error messages as suggested by @ysiraichi. One thing I missed was removing the out skips from OpInfo, which is now taken care of in commit 3ea56a9.

Thanks for the patience, and thanks to everyone for the reviews.

@krshrimali (Author):

Update: working on resolving the test failures.

@facebook-github-bot (Contributor):

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


Labels

cla signed, module: structured kernels, oncall: jit, open source

7 participants