
Conversation

@nairbv
Collaborator

@nairbv nairbv commented May 29, 2019

This change is backwards incompatible, in C++ only, on the mean(), sum(), and prod() interfaces that accepted either of:

Tensor sum(IntArrayRef dim, bool keepdim=false) const;
Tensor sum(IntArrayRef dim, ScalarType dtype) const;

Specifying both the dim and dtype now requires passing the keepdim parameter as well:

Tensor sum(IntArrayRef dim, bool keepdim=false, c10::optional<ScalarType> dtype=c10::nullopt) const;
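For illustration, the equivalent Python call sites continue to work unchanged (the break is C++ only, since dtype is a keyword argument in Python); a minimal sketch with assumed example values:

```
import torch

t = torch.ones(3, 4)
a = t.sum(0, keepdim=True)               # dim + keepdim
b = t.sum(0, dtype=torch.float64)        # dim + dtype, no keepdim needed
c = t.sum(0, True, dtype=torch.float64)  # dim + keepdim + dtype
```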

[xla ci]

@pytorchbot pytorchbot added labels oncall: jit, module: autograd, module: internals, module: onnx, module: operators, module: pybind, module: tests (May 29, 2019)
@nairbv
Collaborator Author

nairbv commented May 29, 2019

@pytorchbot rebase this please

@nairbv
Collaborator Author

nairbv commented May 30, 2019

@pytorchbot rebase this please

@nairbv
Collaborator Author

nairbv commented Jun 4, 2019

With the latest changes the existing JIT tests pass (at least test_jit.py, when I ran it pre-commit), but it seems like I still need to make additional changes to support the dtype arg in torch/csrc/jit/passes/shape_analysis.cpp. If that's correct, I'm not sure what the tests to validate those changes would look like.

@nairbv nairbv requested a review from suo June 4, 2019 19:10
@ailzhang ailzhang self-requested a review June 4, 2019 21:25
@eellison eellison self-requested a review June 4, 2019 21:26
@wanchaol wanchaol self-requested a review June 4, 2019 21:26
def mean_0(self, *, dtype: Optional[int]):
self_size = self.size()
self_numel = self.numel()
self_scalar_type = self.dtype
Collaborator

This does not correctly unwrap the dtype, I believe; self_scalar_type is still an Optional[int]. You will probably need to unwrap it using boolean refinement, e.g. if dtype is not None.
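For reference, a minimal sketch of the kind of boolean refinement being suggested (a hypothetical standalone helper, not the generated derivative code):

```
from typing import Optional

import torch


@torch.jit.script
def pick_scalar_type(dtype: Optional[int], fallback: int) -> int:
    # Inside the "is not None" branch, TorchScript refines Optional[int] to int.
    if dtype is not None:
        return dtype
    return fallback
```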

Collaborator Author

I wasn't trying to unwrap; I was just trying to be consistent with how the rest of the code is written, e.g. self.scalar_type() gets converted to self_scalar_type (REPLACEMENTS in def saved_variables in load_derivatives.py).

I think torch.mean() and to() both accept an optional dtype. If this is wrong, though, how would I write a test that would catch it? It seems like this is working.

Collaborator

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L2704 Hmm, that's pretty weird. From what I see in the schema, to does not take an optional dtype, and it also does not have default values, so I'm not sure why this is working.

For mean, it looks like you are changing the API. Can you add some tests to exercise the API in eager mode and check whether autodiff works in the JIT? https://github.com/pytorch/pytorch/blob/master/test/common_methods_invocations.py#L339
The tests right now only have one test case, which does not cover dtype.

Collaborator Author

Ah, I didn't notice this comment. I added a test: ('mean', (S, S, S), (), 'dtype', (True,), (), (), ident, {'dtype': torch.float64}). Is that what you meant?

It passed when I ran it locally with:
python test_jit.py TestJitGeneratedAutograd.test_mean_dtype

*,
dtype: Optional[int]):
self_size = self.size()
self_scalar_type = self.dtype
Collaborator

same here

}
} else if (type_ == ParameterType::SCALARTYPE) {
if (str == "None") {
if (str == "None" || str == "c10::nullopt") {
Collaborator

Why is c10::nullopt here? The Python arg parser should only consume None from Python, not nullopt.

Collaborator Author

@nairbv nairbv Jun 5, 2019

Hmm... I hadn't thought about it in detail, since this change did fix an error I was getting at this point in the code. The default value is set in C++ though, right? If the C++ function is defined like Tensor::mean(Tensor self, c10::optional<ScalarType> dtype = c10::nullopt), and the Python call is tensor.mean(), where should the parser handle the conversion from c10::nullopt to None?

Is that just supposed to happen based on how it's defined in native_functions.yaml? Somehow the code didn't work without this change.

// Additionally:
// - First input should be the only tensor input
// - has a bool keepdim argument
static const register_formula_for dim_reduce_ops_with_integer_upcast{
Collaborator Author

In these register_formula_for entries it looks like the code explicitly parses each argument of the registered function. It seems that since I've added a dtype argument, I also need to parse that. I can guess at what the code should look like, but I don't have a failing test and I'm not sure how to write one.

node->output()->setType(tp->withSizesStrides(sizes, tp->strides()));
return true;
} else if (node->matches("aten::sum(Tensor self) -> Tensor")) {
} else if (node->matches("aten::sum(Tensor self, *, int? dtype) -> Tensor")) {
Collaborator Author

Similarly here in PropagateCompleteShapeOnNode, it looks like this code explicitly handles each argument. I think I need to add similar code to extract the dtype, but I don't have a failing test and I'm not sure how to write one that will reach this code.

Contributor

Jun 04 16:33:39 ERROR: test_backwards (__main__.TestMSNPUTensor)
Jun 04 16:33:39 ----------------------------------------------------------------------
Jun 04 16:33:39 Traceback (most recent call last):
Jun 04 16:33:39   File "test_cpp_extensions.py", line 670, in test_backwards
Jun 04 16:33:39     d = c.sum()
Jun 04 16:33:39 RuntimeError: No function registered for schema: sum(Tensor self, ScalarType dtype) -> Tensor
@pytorchbot pytorchbot added the module: cpp-extensions label Jun 5, 2019
register_extension_backend_op(
Backend::MSNPU,
"sum(Tensor self) -> Tensor", &sum_override);
"sum(Tensor self, ScalarType dtype) -> Tensor", &sum_override);
Collaborator Author

Technically this isn't correct since dtype should be optional. That's due to an existing bug, filed here to be fixed separately:

#21416

Contributor

@eellison eellison left a comment

A few comments on the shape propagation. You can check in #18813 while you're working on it as well (and maybe I should land it).

@@ -0,0 +1,13 @@
graph(%a : Tensor,
%b : Tensor):
Contributor

Could you find another way to test this other than expect files? We do not use them. Consider using our FileCheck tool as a way of comparing expected textual output.

Collaborator Author

removed the expect test and added FileCheck tests
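For reference, a minimal sketch of what a FileCheck-based check can look like (not the PR's actual test; the scripted function and check string are just assumptions):

```
import torch
from torch.testing import FileCheck

@torch.jit.script
def fn(x):
    return x.sum(dtype=torch.float64)

# Assert that the expected op shows up in the textual IR of the graph.
FileCheck().check("aten::sum").run(str(fn.graph))
```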

"aten::log2(Tensor self) -> Tensor",
"aten::log_sigmoid(Tensor self) -> Tensor",
"aten::log_softmax(Tensor self, int dim) -> Tensor",
"aten::log_softmax(Tensor self, int dim, *, int? dtype) -> Tensor",
Contributor

The comment above these sets of ops requires that the scalar type of the input is preserved. Now that dtype is an argument, it is no longer valid to have these ops in this set.
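A quick illustration of why the invariant breaks once a dtype argument is accepted (example values assumed):

```
import torch

x = torch.ones(3, dtype=torch.float32)
print(x.sum().dtype)                     # torch.float32: input scalar type preserved
print(x.sum(dtype=torch.float64).dtype)  # torch.float64: not preserved
```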

{
"aten::sum(Tensor self) -> Tensor",
"aten::prod(Tensor self) -> Tensor",
"aten::sum(Tensor self, *, int? dtype) -> Tensor",
Contributor

Same here

static const register_formula_for dim_reduce_ops_with_integer_upcast{
{
"aten::prod(Tensor self, int dim, bool keepdim) -> Tensor",
"aten::prod(Tensor self, int dim, bool keepdim, *, int? dtype) -> Tensor",
Contributor

Same here

{
"aten::logsumexp(Tensor self, int[] dim, bool keepdim) -> Tensor",
"aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
"aten::mean(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor",
Contributor

here as well


Contributor

@gchanan gchanan left a comment

approving non-JIT changes.

@nairbv
Collaborator Author

nairbv commented Jun 20, 2019

@eellison / @suo / @wanchaol , can one of you approve the jit-related changes here?

Contributor

@eellison eellison left a comment

Looks good to me, thanks for the effort in updating the JIT. I know it's a tricky part of the codebase to navigate.

Please update the test for pairs of dtypes before landing if that comment applies.


// Requirements:
// dims : preserved
// scalar type : preserved unless specified.
Contributor

"scalar type : preserved unless specified" is inaccurate, since integer_upcast is true.

Collaborator Author

right, fixed


// Requirements:
// dims : preserved if keepdim == false, 1 smaller otherwise
// dims : preserved if keepdim == false, dim.size() smaller otherwise
Contributor

Is this equivalent to saying preserved if keepdim == false, 0 otherwise?

Collaborator Author

No, because dims comes from the dim parameter, not from self. E.g., a 3-dimensional tensor reduced over two dims gives a 1-dim tensor:

>>> torch.ones([3,3,3]).sum([0,1])
tensor([9., 9., 9.])

if(not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point)):
if op in ['mean', 'softmax', 'log_softmax']:
continue
return_line = "torch.tensor({}, dtype={}).{}({}dtype={})".format(tensor_data, tensor_type, op, str_args, dtype)
Contributor

Don't some of the ops have different behavior depending on what the input tensor's dtype is? If so, don't you need to iterate over all pairs of dtypes for the torch.tensor dtype arg and the op dtype?

Collaborator Author

I'm not sure what you mean; doesn't it already? We nested-loop over dtypes twice, once as dtype and once as tensor_type.

Contributor

@eellison eellison Jun 20, 2019

Oops, yeah, never mind, you're doing it already. Although, could tensor_type iterate over None too?
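For illustration, a sketch of the nested iteration being discussed, with None included in both loops (the names and dtype list are assumptions, not the PR's actual generator code):

```
import torch

dtypes = [torch.float32, torch.float64, torch.int32, torch.int64]

for tensor_type in dtypes + [None]:    # dtype of the input tensor (None = default)
    for dtype in dtypes + [None]:      # dtype argument passed to the op
        x = torch.ones(2, 3, dtype=tensor_type)
        out = x.sum(dtype=dtype)
        if dtype is not None:
            # The op's dtype argument overrides the input's scalar type.
            assert out.dtype == dtype
```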

test/test_jit.py Outdated
self._test_dtype_op_shape(ops, [0, False])

ops = ['sum', 'mean']
self._test_dtype_op_shape(ops, [[0, 1], False], 4)
Contributor

Nit: maybe use kwargs here so it's easier to tell what the arguments mean

Collaborator Author

updated

@eellison
Contributor

@ailzhang @wanchaol take a look at autograd?

@nairbv
Collaborator Author

nairbv commented Jun 21, 2019

@pytorchbot retest this please

@nairbv
Collaborator Author

nairbv commented Jun 21, 2019

@pytorchbot rebase this please

Contributor

@facebook-github-bot facebook-github-bot left a comment

@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

('mean', (), NO_ARGS, 'scalar', (True,)),
('mean', (), (0,), 'scalar_dim', (True,), [0]),
('mean', (), (0, True,), 'scalar_keepdim_dim', (True,), [0]),
('mean', (S, S, S), (), 'dtype', (True,), (), (), ident, {'dtype': torch.float64}),
Collaborator Author

adding test with dtype, @wanchaol / @ailzhang

Collaborator

@wanchaol wanchaol left a comment

looks good to me. Thanks for adding the test!

Might need to follow up with @ailzhang on the xla test

Contributor

@facebook-github-bot facebook-github-bot left a comment

@nairbv has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ailzhang
Contributor

Yeah, I will make follow-up changes on the XLA side.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 24, 2019
Summary:
This change is backwards incompatible in *C++ only* on mean(), sum(), and prod() interfaces that accepted either of:
```
Tensor sum(IntArrayRef dim, bool keepdim=false) const;
Tensor sum(IntArrayRef dim, ScalarType dtype) const;
```
but now to specify both the dim and dtype will require the keepdim parameter:
```
Tensor sum(IntArrayRef dim, bool keepdim=false, c10::optional<ScalarType> dtype=c10::nullopt) const;
```

[xla ci]
Pull Request resolved: pytorch/pytorch#21088

Reviewed By: ailzhang

Differential Revision: D15944971

Pulled By: nairbv

fbshipit-source-id: 53473c370813d9470b190aa82764d0aea767ed74
@facebook-github-bot
Contributor

@nairbv merged this pull request in 142361a.

@suo
Member

suo commented Jun 24, 2019

Tests are failing on master with errors like:

Jun 24 15:42:02 	Schema not found for node. File a bug report.
Jun 24 15:42:02 	Node: %21 : Tensor = aten::cumsum(%18, %3, %31) # /opt/conda/lib/python3.6/site-packages/torch/distributions/transforms.py:516:0

Going to revert this PR.

facebook-github-bot pushed a commit that referenced this pull request Jun 26, 2019
Summary:
This is (mostly) the re-application of:
#21088

which was reverted due to an issue conflicting with changes in:
#22104
Pull Request resolved: #22237

Differential Revision: D16012838

Pulled By: nairbv

fbshipit-source-id: 35f4a73c97ab68b4e2648aca96b2176f07b5a883
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 26, 2019
Summary:
This is (mostly) the re-application of:
pytorch/pytorch#21088

which was reverted due to an issue conflicting with changes in:
pytorch/pytorch#22104
Pull Request resolved: pytorch/pytorch#22237

Differential Revision: D16012838

Pulled By: nairbv

fbshipit-source-id: 35f4a73c97ab68b4e2648aca96b2176f07b5a883

Labels

Merged, module: autograd, module: bc-breaking, module: cpp-extensions, module: internals, module: onnx, module: pybind, module: tests, oncall: jit


9 participants