merge interfaces that have an optional scalartype parameter #21088
Conversation
@pytorchbot rebase this please

@pytorchbot rebase this please

With the latest changes the existing JIT tests pass (at least test_jit.py, when I ran them pre-commit), but it seems to me like I should still need additional changes to support the dtype arg in:
def mean_0(self, *, dtype: Optional[int]):
    self_size = self.size()
    self_numel = self.numel()
    self_scalar_type = self.dtype
This does not correctly unwrap the dtype, I believe; self_scalar_type is still an Optional[int]. You will probably need to unwrap it using boolean refinement, e.g. `if dtype is not None:`.
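For illustration, a minimal sketch of the boolean-refinement pattern in TorchScript (this is not the actual derivative formula, just the unwrapping idiom; the function name is made up):

```python
from typing import Optional

import torch

@torch.jit.script
def pick_scalar_type(t: torch.Tensor, dtype: Optional[int]) -> int:
    # `dtype` starts out as Optional[int]; the `is not None` check refines it to int.
    if dtype is not None:
        return dtype
    # Otherwise fall back to the input tensor's scalar type (an int in TorchScript).
    return t.dtype
```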
I wasn't trying to unwrap; I was just trying to be consistent with how the rest of the code is written. E.g., self.scalar_type() gets converted to self_scalar_type (see REPLACEMENTS in def saved_variables in load_derivatives.py).
I think torch.mean() and to() both accept an optional. If this is wrong, though, how would I write a test that would catch it? It seems like this is working.
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L2704 Hmm, that's pretty weird. From what I see in the schema, to() does not take an optional dtype, and it also does not have default values, so I am not sure why this is working.
For mean, it looks like you are changing the API. Can you add some tests to exercise the API in eager mode and see if autodiff works in the JIT? https://github.com/pytorch/pytorch/blob/master/test/common_methods_invocations.py#L339
The tests right now only have one test case, which does not include a case with dtype.
Ah, I didn't notice this comment. Added a test: ('mean', (S, S, S), (), 'dtype', (True,), (), (), ident, {'dtype': torch.float64}). Is that what you mean?
It passed when I ran it locally with:
python test_jit.py TestJitGeneratedAutograd.test_mean_dtype
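For reference, a rough eager-mode equivalent of what that test entry exercises (illustrative only):

```python
import torch

x = torch.rand(3, 4, 5)          # float32 input, analogous to (S, S, S)
y = x.mean(dtype=torch.float64)  # the {'dtype': torch.float64} kwarg from the test entry
assert y.dtype == torch.float64
```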
    *,
    dtype: Optional[int]):
    self_size = self.size()
    self_scalar_type = self.dtype
same here
   }
 } else if (type_ == ParameterType::SCALARTYPE) {
-  if (str == "None") {
+  if (str == "None" || str == "c10::nullopt") {
Why is c10::nullopt here? The Python arg parser should only consume None in Python, rather than nullopt.
Hmm... I hadn't thought about it in detail, since this did fix an error I was getting at this point in the code. The default value is set in C++ though, right? If the C++ function is defined like Tensor::mean(Tensor self, c10::optional<ScalarType> dtype = c10::nullopt), and the Python call is tensor.mean(), where should the parser find the conversion from c10::nullopt to None?
Is that just supposed to happen based on how it's defined in native_functions.yaml? Somehow the code didn't work without this change.
 // Additionally:
 //   - First input should be the only tensor input
 //   - has a bool keepdim argument
 static const register_formula_for dim_reduce_ops_with_integer_upcast{
In these register_formula_for entries it looks like the code explicitly parses each argument of the registered function. It seems that since I've added a dtype argument, I also need to parse that. I can guess at what the code should look like, but I don't have a failing test and I'm not sure how to write one.
   node->output()->setType(tp->withSizesStrides(sizes, tp->strides()));
   return true;
-} else if (node->matches("aten::sum(Tensor self) -> Tensor")) {
+} else if (node->matches("aten::sum(Tensor self, *, int? dtype) -> Tensor")) {
Similarly, here in PropagateCompleteShapeOnNode the code explicitly handles each argument. I think I need to add similar code to extract the dtype, but I don't have a failing test and I'm not sure how to write one that will reach this code.
Here is an example of how to write shape propagation tests: https://github.com/pytorch/pytorch/blob/master/test/test_jit.py#L6726
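As a rough, hedged sketch (not the exact helper used in test_jit.py), a shape-propagation check might look something like this, assuming the graph_for debugging helper used throughout the JIT tests of that era:

```python
import torch

@torch.jit.script
def reduce_with_dtype(x):
    return x.sum(dtype=torch.float64)

# Inspect the optimized graph for a concrete input and check what type
# information shape propagation attached to the output.
graph = reduce_with_dtype.graph_for(torch.rand(2, 3))
print(graph)
```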
Jun 04 16:33:39 ERROR: test_backwards (__main__.TestMSNPUTensor)
Jun 04 16:33:39 ----------------------------------------------------------------------
Jun 04 16:33:39 Traceback (most recent call last):
Jun 04 16:33:39   File "test_cpp_extensions.py", line 670, in test_backwards
Jun 04 16:33:39     d = c.sum()
Jun 04 16:33:39 RuntimeError: No function registered for schema: sum(Tensor self, ScalarType dtype) -> Tensor
 register_extension_backend_op(
   Backend::MSNPU,
-  "sum(Tensor self) -> Tensor", &sum_override);
+  "sum(Tensor self, ScalarType dtype) -> Tensor", &sum_override);
Technically this isn't correct since dtype should be optional. That's due to an existing bug, filed here to be fixed separately:
eellison
left a comment
A few comments on the shape propagation. You can check in #18813 while you're working on this as well (and maybe I should land it).
@@ -0,0 +1,13 @@
graph(%a : Tensor,
      %b : Tensor):
Could you find another way to test this other than expect files? We do not use them. Consider using our FileCheck tool as a way of comparing expected textual output.
Removed the expect test and added FileCheck tests.
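A small sketch of the FileCheck style being suggested (illustrative, not the exact tests that were added):

```python
import torch
from torch.testing import FileCheck

@torch.jit.script
def fn(x):
    return x.sum(dtype=torch.float64)

# Verify the expected op shows up in the textual dump of the graph,
# instead of comparing against a checked-in expect file.
FileCheck().check("aten::sum").run(str(fn.graph))
```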
| "aten::log2(Tensor self) -> Tensor", | ||
| "aten::log_sigmoid(Tensor self) -> Tensor", | ||
| "aten::log_softmax(Tensor self, int dim) -> Tensor", | ||
| "aten::log_softmax(Tensor self, int dim, *, int? dtype) -> Tensor", |
The comment above this set of ops requires that the scalar type of the input is preserved. Now that dtype is an argument, it is no longer valid to have these ops in this set.
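A quick eager-mode illustration of why "scalar type preserved" no longer holds once dtype is an argument:

```python
import torch

x = torch.ones(4, dtype=torch.float32)
print(x.log_softmax(0).dtype)                       # torch.float32 (preserved)
print(x.log_softmax(0, dtype=torch.float64).dtype)  # torch.float64 (not preserved)
```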
 {
-  "aten::sum(Tensor self) -> Tensor",
   "aten::prod(Tensor self) -> Tensor",
+  "aten::sum(Tensor self, *, int? dtype) -> Tensor",
Same here
 static const register_formula_for dim_reduce_ops_with_integer_upcast{
   {
-    "aten::prod(Tensor self, int dim, bool keepdim) -> Tensor",
+    "aten::prod(Tensor self, int dim, bool keepdim, *, int? dtype) -> Tensor",
Same here
 {
   "aten::logsumexp(Tensor self, int[] dim, bool keepdim) -> Tensor",
-  "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
+  "aten::mean(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor",
here as well
   node->output()->setType(tp->withSizesStrides(sizes, tp->strides()));
   return true;
-} else if (node->matches("aten::sum(Tensor self) -> Tensor")) {
+} else if (node->matches("aten::sum(Tensor self, *, int? dtype) -> Tensor")) {
Here is an example of how to write shape propagation tests: https://github.com/pytorch/pytorch/blob/master/test/test_jit.py#L6726
gchanan
left a comment
approving non-JIT changes.
eellison
left a comment
Looks good to me, thanks for the effort in updating the JIT. I know it's a tricky part of the codebase to navigate.
Please update the test for pairs of dtypes before landing if that comment applies.
 // Requirements:
 //   dims : preserved
 //   scalar type : preserved unless specified.
`scalar type : preserved unless specified` is inaccurate, since integer_upcast is true.
right, fixed
 // Requirements:
-//   dims : preserved if keepdim == false, 1 smaller otherwise
+//   dims : preserved if keepdim == false, dim.size() smaller otherwise
Is this equivalent to saying `preserved if keepdim == false, 0 otherwise`?
No, because dims comes from the dim parameter, not from self. E.g., a 3-dimensional tensor reduced over two dims gives a 1-dim tensor:
>>> torch.ones([3, 3, 3]).sum([0, 1])
tensor([9., 9., 9.])
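For comparison, with keepdim=True the number of dims is preserved:
>>> torch.ones([3, 3, 3]).sum([0, 1], keepdim=True).shape
torch.Size([1, 1, 3])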
if (not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point)):
    if op in ['mean', 'softmax', 'log_softmax']:
        continue
return_line = "torch.tensor({}, dtype={}).{}({}dtype={})".format(tensor_data, tensor_type, op, str_args, dtype)
Don't some of the ops have different behavior depending on what the input tensor's dtype is? If so, don't you need to iterate over all pairs of dtypes, for the torch.tensor dtype arg and the op dtype?
I'm not sure what you mean; doesn't it already? We nested-loop over dtypes twice, once as dtype and once as tensor_type.
Oops, yeah, never mind, you're doing it already. Although, could tensor_type iterate over None too?
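A hedged sketch of that suggestion: iterate over all (tensor dtype, op dtype) pairs, allowing None on either side (variable names are illustrative, not the actual test-generation code):

```python
import itertools

import torch

dtypes = [torch.float32, torch.float64, torch.int32, torch.int64]

# Include None on both sides: "no explicit tensor dtype" and "no op dtype arg".
for tensor_type, dtype in itertools.product(dtypes + [None], dtypes + [None]):
    x = torch.ones(2, 3) if tensor_type is None else torch.ones(2, 3, dtype=tensor_type)
    y = x.sum() if dtype is None else x.sum(dtype=dtype)
    # e.g. integer inputs upcast to int64 when no dtype is given.
```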
test/test_jit.py
self._test_dtype_op_shape(ops, [0, False])

ops = ['sum', 'mean']
self._test_dtype_op_shape(ops, [[0, 1], False], 4)
Nit: maybe use kwargs here so it's easier to tell what the arguments mean
updated
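For example (with hypothetical keyword names, since the helper's actual signature isn't shown here), the kwargs version might read:

```python
# Hypothetical parameter names, for illustration of the kwargs suggestion only.
self._test_dtype_op_shape(ops, args=[[0, 1], False], input_dims=4)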
@pytorchbot retest this please

@pytorchbot rebase this please
facebook-github-bot
left a comment
@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
('mean', (), NO_ARGS, 'scalar', (True,)),
('mean', (), (0,), 'scalar_dim', (True,), [0]),
('mean', (), (0, True,), 'scalar_keepdim_dim', (True,), [0]),
('mean', (S, S, S), (), 'dtype', (True,), (), (), ident, {'dtype': torch.float64}),
wanchaol
left a comment
Looks good to me, thanks for adding the test!
Might need to follow up with @ailzhang on the XLA test.
facebook-github-bot
left a comment
@nairbv has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Yeah, I will make follow-up changes on the XLA side.
Summary: This change is backwards incompatible in *C++ only* on mean(), sum(), and prod() interfaces that accepted either of:
```
Tensor sum(IntArrayRef dim, bool keepdim=false) const;
Tensor sum(IntArrayRef dim, ScalarType dtype) const;
```
but now to specify both the dim and dtype will require the keepdim parameter:
```
Tensor sum(IntArrayRef dim, bool keepdim=false, c10::optional<ScalarType> dtype=c10::nullopt) const;
```
[xla ci]
Pull Request resolved: pytorch/pytorch#21088
Reviewed By: ailzhang
Differential Revision: D15944971
Pulled By: nairbv
fbshipit-source-id: 53473c370813d9470b190aa82764d0aea767ed74
Tests are failing on master with errors like: […]. Going to revert this PR.
Summary: This is (mostly) the re-application of pytorch/pytorch#21088, which was reverted due to an issue conflicting with changes in pytorch/pytorch#22104.
Pull Request resolved: pytorch/pytorch#22237
Differential Revision: D16012838
Pulled By: nairbv
fbshipit-source-id: 35f4a73c97ab68b4e2648aca96b2176f07b5a883