
Conversation

@RockingJavaBean (Contributor) commented Mar 30, 2021

Related #54261

This PR ports the method_tests() entries of torch.copysign to OpInfo.

While porting the tests, the test_out cases from test_ops.py failed because the out variant of torch.copysign does not support scalar inputs:

>>> x = torch.randn(2)
>>> y = torch.empty_like(x)
>>> torch.copysign(x, 1.)
tensor([1.4836, 1.2156])
>>> torch.copysign(x, 1., out=y)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: copysign(): argument 'other' (position 2) must be Tensor, not float

This PR fixes the tests by adding an overload entry to native_functions.yaml and re-dispatching scalar inputs to the existing copysign_out function.
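For illustration, a minimal Python-level sketch of the semantics the new overload provides: the scalar is wrapped and re-dispatched to the existing tensor-tensor out variant (torch.tensor stands in here for the C++ wrapped_scalar_tensor helper):

import torch

x = torch.randn(2)
y = torch.empty_like(x)

# The Scalar_out overload re-dispatches to the tensor-tensor out variant
# with the scalar wrapped as a 0-dim tensor, i.e. it behaves like:
torch.copysign(x, torch.tensor(1.), out=y)
print(y)  # matches torch.copysign(x, 1.)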

@facebook-github-bot (Contributor) commented Mar 30, 2021

💊 CI failures summary and remediations

As of commit 27eb2e1 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@codecov bot commented Mar 30, 2021

Codecov Report

Merging #54945 (27eb2e1) into master (aeedd5c) will decrease coverage by 0.20%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #54945      +/-   ##
==========================================
- Coverage   77.44%   77.23%   -0.21%     
==========================================
  Files        1893     1893              
  Lines      186472   186477       +5     
==========================================
- Hits       144404   144033     -371     
- Misses      42068    42444     +376     

@H-Huang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Mar 30, 2021

- func: copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: copysign_out
Contributor

The implementation is:

Tensor& copysign_out(const Tensor& self, const Scalar& other, Tensor& result) {
  return at::copysign_out(result, self, wrapped_scalar_tensor(other));
}

which is fully dispatched, so CPU, CUDA is overly conservative; CompositeImplicitAutograd would be OK. Even better, though, would be to make the kernel structured (not in this PR, though; see https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md). If you are interested in making it structured, no action is necessary here; structured will fix this up later.

Contributor Author

Thank you so much for reviewing this PR; I'm really interested in the structured kernel.
A new PR, #55040, has been created to port torch.copysign to structured, please take a look.

        low=None, high=None,
        requires_grad=requires_grad)
    else:
        return case
Contributor

There isn't really any non-tuple case for you to hit in the examples below, right?

@mruberry is there a more well known function that's supposed to be used in this case?

Contributor Author

@ezyang The non-tuple cases are hit when constructing the args of SampleInput; please refer to https://github.com/pytorch/pytorch/pull/54945/files#r605009953.
I think this tuple-checking branch is what causes the ambiguity.
It's nearly midnight in my timezone and it took some hours to build PyTorch from source; I will try improving the readability of this PR tomorrow.
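For context, a plausible reconstruction of the helper the excerpt above comes from (a sketch only; the name _make_case matches the later excerpts, while the make_tensor call and the closure variables device, dtype, and requires_grad are assumptions based on the visible fragment):

def _make_case(case):
    # Tuples describe tensor shapes and are materialized via make_tensor;
    # anything else (e.g. a Python scalar such as 3.14) passes through
    # unchanged -- this is the non-tuple branch under discussion.
    if isinstance(case, tuple):
        return make_tensor(case, device=device, dtype=dtype,
                           low=None, high=None,
                           requires_grad=requires_grad)
    else:
        return case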

@ezyang (Contributor) commented Mar 31, 2021

Overall this PR looks good, but I am not an OpInfo expert, so I would like @mruberry to take a look.


return [SampleInput(_make_case(lhs), args=(_make_case(rhs),))
        for lhs, rhs in cases]

Contributor Author

Besides using _make_case(lhs) to generate the input tensor of SampleInput, _make_case(rhs) is used for args as well.
The rhs corresponds to the second item of each tuple element, so its values are 3.14, 0.0, and -0.0 for the corresponding scalar cases.
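Putting the pieces together, a sketch of how such cases would feed both the input and the args of SampleInput (the tensor shapes are illustrative assumptions; the scalar values come from the comment above):

cases = [
    ((S, S), (S, S)),   # tensor rhs
    ((S, S), 3.14),     # scalar rhs -- hits the non-tuple branch
    ((S, S), 0.0),
    ((S, S), -0.0),
]
return [SampleInput(_make_case(lhs), args=(_make_case(rhs),))
        for lhs, rhs in cases]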

@RockingJavaBean (Contributor Author)

This PR is updated with the following changes.

  • drop the changes for copysign.Scalar_out, as torch.copysign is being ported to structured kernels in #55040 (copysign: port to structured kernel).
  • remove the branch that caused ambiguity in sample_inputs_copysign, for readability.

I think this PR is ready for another look.

facebook-github-bot pushed a commit that referenced this pull request Apr 1, 2021
Summary:
Related #54945

This PR ports `copysign` to structured, and the `copysign.Scalar` overloads are re-dispatched to the structured kernel.

Pull Request resolved: #55040

Reviewed By: glaringlee

Differential Revision: D27465501

Pulled By: ezyang

fbshipit-source-id: 5cbabfeaaaa7ca184ae0b701b9692a918a90b117
@mruberry (Collaborator) left a comment

LGTM! Nice simplification, @RockingJavaBean.

@facebook-github-bot (Contributor)

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@mruberry merged this pull request in b074a24.

@RockingJavaBean deleted the opinfo_copysign branch on April 2, 2021 at 08:55
# broadcast rhs
(_make_tensor(S, S, S), _make_tensor(S, S)),
# broadcast lhs
(_make_tensor(S, S), _make_tensor(S, S, S)),
Collaborator

@RockingJavaBean @mruberry Any idea how this is working? AFAIK this shouldn't work. Reference: #50747

Surprisingly, I haven't seen any related failed test on master.

But running locally, I get the error as expected.

============================================================================= FAILURES =============================================================================
________________________________________________ TestCommonCPU.test_variant_consistency_eager_copysign_cpu_float32 _________________________________________________
Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_inplace_broadcast_test/test/test_ops.py", line 306, in test_variant_consistency_eager
    _test_consistency_helper(inplace_samples, inplace_variants)
  File "/home/kshiteej/Pytorch/pytorch_inplace_broadcast_test/test/test_ops.py", line 290, in _test_consistency_helper
    variant_forward = variant(cloned,
RuntimeError: output with shape [5, 5] doesn't match the broadcast shape [5, 5, 5]
_______________________________________________ TestCommonCUDA.test_variant_consistency_eager_copysign_cuda_float32 ________________________________________________
Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_inplace_broadcast_test/test/test_ops.py", line 306, in test_variant_consistency_eager
    _test_consistency_helper(inplace_samples, inplace_variants)
  File "/home/kshiteej/Pytorch/pytorch_inplace_broadcast_test/test/test_ops.py", line 290, in _test_consistency_helper
    variant_forward = variant(cloned,
RuntimeError: output with shape [5, 5] doesn't match the broadcast shape [5, 5, 5]
========================================================================= warnings summary =========================================================================

Collaborator

While trying to see the CI build log, I noticed something odd.

Looking at the CI run for pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1.

There is no result for test_variant_consistency_eager_copysign_cpu_float32.

Relevant lines from the log linked below

Apr 01 17:21:54   test_variant_consistency_eager_broadcast_to_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_broadcast_to_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_ceil_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_cholesky_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.006s)
Apr 01 17:21:54   test_variant_consistency_eager_cholesky_cpu_float32 (__main__.TestCommonCPU) ... ok (0.005s)
Apr 01 17:21:54   test_variant_consistency_eager_cholesky_inverse_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.003s)
Apr 01 17:21:54   test_variant_consistency_eager_cholesky_inverse_cpu_float32 (__main__.TestCommonCPU) ... ok (0.005s)
Apr 01 17:21:54   test_variant_consistency_eager_clamp_cpu_float32 (__main__.TestCommonCPU) ... ok (0.006s)
Apr 01 17:21:54   test_variant_consistency_eager_conj_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_conj_cpu_float32 (__main__.TestCommonCPU) ... ok (0.003s)
Apr 01 17:21:54   test_variant_consistency_eager_cos_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_cos_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_cosh_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_cosh_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_cummax_cpu_float32 (__main__.TestCommonCPU) ... ok (0.003s)
Apr 01 17:21:54   test_variant_consistency_eager_cummin_cpu_float32 (__main__.TestCommonCPU) ... ok (0.003s)
Apr 01 17:21:54   test_variant_consistency_eager_cumprod_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.007s)
Apr 01 17:21:54   test_variant_consistency_eager_cumprod_cpu_float32 (__main__.TestCommonCPU) ... ok (0.006s)
Apr 01 17:21:54   test_variant_consistency_eager_cumsum_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.005s)
Apr 01 17:21:54   test_variant_consistency_eager_cumsum_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_deg2rad_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_diag_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.005s)
Apr 01 17:21:54   test_variant_consistency_eager_diag_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 17:21:54   test_variant_consistency_eager_diff_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.005s)
Apr 01 17:21:54   test_variant_consistency_eager_diff_cpu_float32 (__main__.TestCommonCPU) ... ok (0.005s)
Apr 01 17:21:54   test_variant_consistency_eager_digamma_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)

@mruberry Did I miss something?

https://circleci.com/api/v1.1/project/github/pytorch/pytorch/12041172/output/107/0?file=true&allocation-id=6065fe626bfe87630fa4b686-0-build%2F48AA87AC

@kshitij12345 (Collaborator) commented Apr 5, 2021

This line causes the issue:

variants = (v for v in (method, inplace) + aliases if v is not None)

variants is a generator, so it is exhausted on the first run (see the Python snippet below).

Thus the following loop does not run for any but the first sample:

pytorch/test/test_ops.py, lines 268 to 286 at c821b83:

# Test eager consistency
for variant in variants:
    # Skips inplace ops
    if variant in inplace_ops and skip_inplace:
        continue
    # Compares variant's forward
    # Note: copies the to-be-modified input when testing the inplace variant
    tensor.grad = None
    cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
    variant_forward = variant(cloned,
                              *sample.args,
                              **sample.kwargs)
    self.assertEqual(expected_forward, variant_forward)
    # Compares variant's backward
    if expected_grad is not None and (variant not in inplace_ops or op.supports_inplace_autograd):
        variant_forward.sum().backward()
        self.assertEqual(expected_grad, tensor.grad)

Example

>>> def print_iterable(iterable):
...     for x in iterable:
...             print(x)
... 
>>> l = [x for x in range(3)] # list comprehension
>>> print_iterable(l)
0
1
2
>>> print_iterable(l)
0
1
2
>>> l = (x for x in range(3)) # generator expression
>>> print_iterable(l)
0
1
2
>>> print_iterable(l)
>>>

Contributor Author

> While trying to see the CI build log, I noticed something odd.
>
> Looking at the CI run for pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1.
>
> There is no result for test_variant_consistency_eager_copysign_cpu_float32.
>
> Relevant lines from the log linked below
>
> @mruberry Did I miss something?
>
> https://circleci.com/api/v1.1/project/github/pytorch/pytorch/12041172/output/107/0?file=true&allocation-id=6065fe626bfe87630fa4b686-0-build%2F48AA87AC

@kshitij12345 Thank you so much for pointing this out.

According to pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1 triggered by the last commit of this PR (27eb2e1), test_variant_consistency_eager_copysign_cpu did run and pass:

https://circleci.com/api/v1.1/project/github/pytorch/pytorch/12027499/output/107/0?file=true&allocation-id=606569f28b208810a6a496b9-0-build%2F5C11284D

Apr 01 06:48:33   test_variant_consistency_eager_conj_cpu_float32 (__main__.TestCommonCPU) ... ok (0.003s)
Apr 01 06:48:33   test_variant_consistency_eager_copysign_cpu_float32 (__main__.TestCommonCPU) ... ok (0.004s)
Apr 01 06:48:33   test_variant_consistency_eager_cos_cpu_complex64 (__main__.TestCommonCPU) ... ok (0.004s)

After checking test_variant_consistency_eager in test_ops.py, I agree with you that the exhausted variants generator leads to the remaining SampleInputs being skipped, which is why the cases with a broadcast self tensor pass without resolving issue #50747.

I really appreciate your PR #53014 for fixing this issue; I think it will enable testing a broadcast self tensor with OpInfo.

Collaborator

@RockingJavaBean
Right. My bad. I wonder why the log I checked did not have it.

But I can confirm it is running.

Apologies for the false alarm.

Thanks for looking into it.

Collaborator

But this exposes a horrible problem with the use of generator expressions like this in sample inputs; many tests are "running" but not actually performing the test properly.

@kshitij12345, @vfdev-5, what's the best way to fix this and validate the fix so the behavior isn't regressed later? One simple approach might be to validate that len(sample_inputs) > 0 in all the tests that enumerate them for now?

Collaborator

We can't call len on a generator, so for now we can just do variants = tuple(<generator expression>). But how to detect this in the general case is worth some thought 🤔
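A minimal sketch of that fix, applied to the line identified above (materializing the generator so it can be iterated once per sample):

# Before: a generator expression, exhausted after the first sample.
# variants = (v for v in (method, inplace) + aliases if v is not None)

# After: a tuple can be iterated any number of times.
variants = tuple(v for v in (method, inplace) + aliases if v is not None)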

Collaborator

Yeah... maybe we'll just have to be vigilant to detect it. But it'd be nice if we could not only detect that a test ran but that it actually did something.
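One possible shape for such a guard (a hypothetical sketch, not existing test-suite code): count the iterations and fail loudly when nothing was exercised.

def check_variants_ran(variants, run_one):
    # run_one is a callable that executes a single variant; we count how
    # many variants actually ran so an exhausted generator fails the test
    # instead of silently passing.
    tested = 0
    for variant in variants:
        run_one(variant)
        tested += 1
    assert tested > 0, "no variants were tested (exhausted generator?)"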

@kshitij12345 (Collaborator) commented Apr 7, 2021

I think looking at the answers to this question on Stack Overflow might give us a nice idea.


Labels

cla signed · Merged · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
