Conversation

@nairbv
Collaborator

@nairbv nairbv commented Jun 26, 2019

Improve handling of mixed-type tensor operations.

This PR affects the arithmetic (add, sub, mul, and div) operators implemented via TensorIterator (so dense but not sparse tensor ops).

For these operators, we will now promote to reasonable types where possible, following the rules defined in #9515, and error in cases where the cast would require a floating point -> integral or non-boolean -> boolean downcast.

The details of the promotion rules are described here:
https://github.com/nairbv/pytorch/blob/promote_types_strict/docs/source/tensor_attributes.rst

Some specific backwards incompatible examples:

  • Now `int_tensor * float` will result in a float tensor, whereas previously the floating point operand was first cast to an int. Previously `torch.tensor(10) * 1.9` => `tensor(10)` because the 1.9 was downcast to `1`. Now the result will be the more intuitive `tensor(19)`.
  • Now `int_tensor *= float` will error, since the floating point result of this operation can't be cast into the in-place integral result type.

See more examples/detail in the original issue (#9515), in the above linked tensor_attributes.rst doc, or in the test_type_promotion.py tests added in this PR:
https://github.com/nairbv/pytorch/blob/promote_types_strict/test/test_type_promotion.py
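The rule pair above (promote out-of-place, error on unsafe in-place casts) can be sketched in pure Python. This is a simplified model of the rules in #9515, not the actual ATen implementation; the three-level kind lattice and helper names are illustrative only:

```python
# A minimal pure-Python sketch of the casting rule this PR enforces.
# The names and the simplified kind lattice are illustrative, not the
# actual ATen implementation.

PROMOTION_ORDER = ["bool", "int", "float"]  # simplified kind lattice

def promote(kind_a, kind_b):
    """Result kind of a mixed-type op: the 'wider' of the two kinds."""
    return max(kind_a, kind_b, key=PROMOTION_ORDER.index)

def can_cast(from_kind, to_kind):
    """A cast is unsafe if it narrows float -> int or non-bool -> bool."""
    if to_kind == "bool" and from_kind != "bool":
        return False
    if from_kind == "float" and to_kind == "int":
        return False
    return True

# Out-of-place: int * float promotes to float, so 10 * 1.9 keeps the 1.9.
assert promote("int", "float") == "float"

# In-place: int_tensor *= float must error, because the promoted float
# result cannot be written back into the integral destination.
assert not can_cast(promote("int", "float"), "int")
```

The in-place case errors rather than promoting because an in-place op cannot change the destination tensor's dtype.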

@pytorchbot pytorchbot added the module: cuda, module: internals, module: operators, and module: tests labels Jun 26, 2019
@nairbv nairbv added the module: bc-breaking label Jun 26, 2019
@nairbv nairbv changed the title from "Promote tensor types" to "Promote tensor types without unsafe operations" Jun 26, 2019
Member

@colesbury colesbury left a comment


@gchanan's suggestion, which I agree with, is to make int_tensor *= float an error and then see if it triggers any errors in the PyTorch examples and whatever Will is using to test the variable/tensor merge. If there are cases that trigger the error, we should support the "legacy" behavior with a deprecation warning.
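The legacy fallback suggested here could be sketched as follows. The helper name and signature are hypothetical, not code from this PR; the point is keeping the old downcasting result for the in-place case while emitting a DeprecationWarning instead of raising immediately:

```python
import warnings

def legacy_inplace_mul(dst_is_integral, operand, dst_value):
    """Hypothetical sketch of the suggested fallback: for an integral
    destination and a float operand, keep the legacy truncate-then-multiply
    behavior but warn that it is deprecated."""
    if dst_is_integral and isinstance(operand, float):
        warnings.warn(
            "implicit float -> integral downcast in an in-place op is "
            "deprecated and will become an error",
            DeprecationWarning,
        )
        return dst_value * int(operand)  # legacy behavior: 10 *= 1.9 -> 10
    return dst_value * operand
```

Running PyTorch examples against this would surface every call site that still relies on the legacy downcast, before the warning is turned into a hard error.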

@pytorchbot pytorchbot added the oncall: jit, module: autograd, and module: cpu labels Jul 3, 2019
@nairbv
Collaborator Author

nairbv commented Jul 8, 2019

@pytorchbot rebase this please

@pytorchbot
Collaborator

Sorry, I can't merge this because there are conflicts. To merge this yourself, run the commands below:

git fetch origin master
git fetch [email protected]:nairbv/pytorch.git promote_types_strict
git checkout FETCH_HEAD
git merge origin/master
git push [email protected]:nairbv/pytorch.git HEAD:promote_types_strict

(To learn more about this bot, see Bot commands.)

@nairbv
Collaborator Author

nairbv commented Jul 9, 2019

@wanchaol / @eellison any thoughts on the shape analysis changes here?

Also @ailzhang, this PR would cause a break in XLA, is it anything that would be difficult to fix?

@ailzhang
Contributor

ailzhang commented Jul 9, 2019

@nairbv Thanks for letting us know! It sounds like we're changing the returned types via type promotion, and it shouldn't be hard for XLA to follow this behavior. cc @dlibenzi @asuhan
@nairbv To make it easier to understand, would you mind adding a summary of what exactly this PR does, e.g. under what conditions which type promotion is triggered?
I think it'll be super helpful for XLA to follow the changes and understand our overall plan on this. Thanks!

@asuhan
Contributor

asuhan commented Jul 9, 2019

@ailzhang @nairbv I don't see major issues on the XLA end with doing fewer implicit conversions. If anything, it was more effort for us to add implicit-conversion logic forgiving enough to match the other backends (CPU or CUDA); XLA doesn't do implicit conversions on its own.

As far as I can tell, the errors in our integration are in the test framework itself, when we're comparing whether output and expected are close. We can adjust that quite easily.

@nairbv
Collaborator Author

nairbv commented Jul 9, 2019

would you mind adding a summary of what exactly this PR does, e.g. under what conditions which type promotion is triggered?

Updated the description to contain more info from the linked PR and issue, and linked directly to the issue.

This is still a work in progress; we're ironing out details of the change and doing more testing, but we'll let you know again before merging.

@nairbv
Collaborator Author

nairbv commented Sep 3, 2019

@pytorchbot rebase this please

Contributor

@facebook-github-bot facebook-github-bot left a comment


@nairbv has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@nairbv
Collaborator Author

nairbv commented Sep 3, 2019

@pytorchbot retest this please

@nairbv
Collaborator Author

nairbv commented Sep 3, 2019

@pytorchbot rebase this please

Contributor

@facebook-github-bot facebook-github-bot left a comment


@nairbv has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@pytorchbot pytorchbot added the module: cpp label Sep 4, 2019
Contributor

@facebook-github-bot facebook-github-bot left a comment


@nairbv has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@nairbv
Collaborator Author

nairbv commented Sep 5, 2019

@pytorchbot rebase this please

@pytorchbot
Collaborator

There's nothing to do! This branch is already up to date with master (a3d0abf).

(To learn more about this bot, see Bot commands.)

Contributor

@facebook-github-bot facebook-github-bot left a comment


@nairbv has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Sep 6, 2019
Summary:
Improve handling of mixed-type tensor operations.

This PR affects the arithmetic (add, sub, mul, and div) operators implemented via TensorIterator (so dense but not sparse tensor ops).

For these operators, we will now promote to reasonable types where possible, following the rules defined in #9515, and error in cases where the cast would require floating point -> integral or non-boolean to boolean downcasts.

The details of the promotion rules are described here:
https://github.com/nairbv/pytorch/blob/promote_types_strict/docs/source/tensor_attributes.rst

Some specific backwards incompatible examples:
* now `int_tensor * float` will result in a float tensor, whereas previously the floating point operand was first cast to an int. Previously `torch.tensor(10) * 1.9` => `tensor(10)` because the 1.9 was downcast to `1`. Now the result will be the more intuitive `tensor(19)`
* Now `int_tensor *= float` will error, since the floating point result of this operation can't be cast into the in-place integral type result.

See more examples/detail in the original issue (#9515), in the above linked tensor_attributes.rst doc, or in the test_type_promotion.py tests added in this PR:
https://github.com/nairbv/pytorch/blob/promote_types_strict/test/test_type_promotion.py
Pull Request resolved: #22273

Reviewed By: gchanan

Differential Revision: D16582230

Pulled By: nairbv

fbshipit-source-id: 4029cca891908cdbf4253e4513c617bba7306cb3
zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 6, 2019
Pull Request resolved: pytorch/pytorch#22273

Reviewed By: gchanan

Differential Revision: D16582230

Pulled By: nairbv

fbshipit-source-id: 4029cca891908cdbf4253e4513c617bba7306cb3
facebook-github-bot pushed a commit that referenced this pull request Sep 20, 2019
Summary:
test_wrapped_number was calling torch.set_default_tensor_type('torch.FloatTensor'), which was setting the default tensor types for all following tests until a class boundary (with unittest) or until end of file (with pytest). Tests that don't expect the default tensor type to be set this way were then failing if run afterwards.

This fixes the issue by copying the default_tensor_type decorator from test_nn and using that instead with test_wrapped_number. The decorator correctly resets the default tensor type after the test has run.

This fixes the many errors encountered when running pytest test_jit.py.

Note: test_wrapped_number was introduced in #22273.
Pull Request resolved: #26523

Differential Revision: D17495283

Pulled By: mruberry

fbshipit-source-id: ab518c78b7706af7cb1c2d1c17823d311178996d
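The save-and-restore pattern that fix relies on can be sketched as a decorator. The `_default` holder and the get/set helpers below are stand-ins for `torch.set_default_tensor_type` so the sketch is self-contained; they are not the actual test_nn code:

```python
import functools

# Stand-in for the global default tensor type state (illustrative only).
_default = {"type": "torch.DoubleTensor"}

def set_default_tensor_type(type_name):
    _default["type"] = type_name

def get_default_tensor_type():
    return _default["type"]

def default_tensor_type(type_name):
    """Run the decorated test with `type_name` as the default tensor
    type, restoring the previous default even if the test raises."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            old = get_default_tensor_type()
            set_default_tensor_type(type_name)
            try:
                return fn(*args, **kwargs)
            finally:
                set_default_tensor_type(old)
        return wrapper
    return decorator

@default_tensor_type("torch.FloatTensor")
def test_wrapped_number_sketch():
    # Inside the test, the overridden default is visible.
    return get_default_tensor_type()
```

The `try`/`finally` is the key difference from a bare `set_default_tensor_type` call at the top of the test: the previous default is restored even when the test fails, so later tests are unaffected.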