
Conversation

@spandantiwari (Author)

Currently, ONNX constant folding (the do_constant_folding=True argument in the torch.onnx.export API) supports only opset 9 of ONNX; for opset 10 it is a no-op. This change enables ONNX constant folding for opset 10. Specifically, there are three main changes:

  1. Turn on the constant folding ONNX pass for opset 10.
  2. Update support for the opset 10 version of the onnx::Slice op for backend computation during constant folding.
  3. Enable constant folding tests in test/onnx/test_utility_funs.py for multiple opsets (9 and 10).

@pytorchbot added the oncall: jit and module: onnx labels on Jul 3, 2019.
@spandantiwari requested a review from @houseroad on July 3, 2019.
@zou3519 added the triaged label on Jul 8, 2019.
@spandantiwari force-pushed the spandantiwari/constant_folding_opset10 branch from 1a84d81 to ad33a87 on July 8, 2019.
@spandantiwari (Author)

@houseroad - The CI failure is unrelated. Could you please review when you get a chance? Thanks.

@facebook-github-bot (Contributor)

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@houseroad (Member)

Overall looks good. But in some cases we probably don't want to throw errors; instead, we should just skip the optimization. We don't want users blocked by an optimization pass. Also, shall we add an explicit test for the dynamic slice as well?


class TestUtilityFuns(TestCase):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    opset_version = _export_onnx_opset_version
Member:

Do we just want to explicitly use opset 9? We already cache opset_version here, and it won't change in the concrete test cases.

Author:

Good point. Using explicit opset 9 here.

c10::optional<at::Tensor> runTorchSlice_opset9(
    const Node* node,
    std::vector<at::Tensor>& inputTensorValues) {
  assert(inputTensorValues.size() == 1);
Member:

Just return nullopt?

Author:

Sure. Fixed.
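For context, a minimal sketch of the pattern this thread converged on: replace the hard assert with a guarded early return of c10::nullopt, so an unexpected graph shape makes the pass skip folding that node instead of aborting the export. The header path and the elided body are illustrative assumptions, not the exact PR code.

#include <vector>
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/ir.h>  // header path varies across PyTorch versions

c10::optional<at::Tensor> runTorchSlice_opset9(
    const torch::jit::Node* node,
    std::vector<at::Tensor>& inputTensorValues) {
  (void)node;  // the node's attributes are read in the elided body
  // Guard instead of assert: unexpected arity means "skip folding
  // this node", not "crash the exporter".
  if (inputTensorValues.size() != 1) {
    return c10::nullopt;
  }
  // ... read starts/ends/axes from the node's attributes and slice
  // inputTensorValues[0]; this sketch elides the actual computation ...
  return c10::nullopt;
}

The same guard-and-return-nullopt shape applies to the opset 10 variant and the validity checks in the threads below.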


c10::optional<at::Tensor> runTorchSlice_opset10(
    const Node* node,
    std::vector<at::Tensor>& inputTensorValues) {
  assert(inputTensorValues.size() > 2 && inputTensorValues.size() < 6);
Member:

Just return nullopt?

Author:

Agreed. Fixed.

    std::vector<at::Tensor>& inputTensorValues) {
  assert(inputTensorValues.size() > 2 && inputTensorValues.size() < 6);
  // Checking validity of 'starts' and 'ends' input
  assert(inputTensorValues[1].sizes().size() == 1 &&
Member:

Just return nullopt?

Author:

Agreed. Fixed.

    return runTorchSlice_opset10(node, inputTensorValues);
  } else {
    throw std::runtime_error(
Member:

Return nullopt?

Author:

Agreed. Fixed.

for (size_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
  // ONNX slice accepts negative starts and ends values.
  int64_t start = starts_a[i], end = ends_a[i];
  int64_t axis = axes[i];
Member:

Abstract this out into a helper function for the negative-index handling logic?

Author:

Makes sense. Fixed.
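For reference, here is one shape such a helper could take; the name normalizeSliceIndex and the exact clamping behavior are illustrative assumptions, not necessarily what the PR landed. Per ONNX Slice semantics, a negative start/end counts back from the end of the dimension.

#include <algorithm>
#include <cstdint>

// Hypothetical helper: map a possibly-negative ONNX Slice index
// onto the valid [0, dimSize] range.
int64_t normalizeSliceIndex(int64_t index, int64_t dimSize) {
  if (index < 0) {
    index += dimSize;  // e.g. -1 on a dim of size 4 becomes 3
  }
  // Clamp so the downstream slicing never sees an out-of-range bound.
  return std::min(std::max<int64_t>(index, 0), dimSize);
}

With that factored out, the loop body above reduces to calls such as normalizeSliceIndex(starts_a[i], dimSize) for both the start and the end bound.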

// constant-based computations/ops into an initializer node.
-void ConstantFoldONNX(Block* b, ParamMap& paramsDict) {
+void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
 AT_ASSERT(b->param_node());
Member:

We may also want to add a check on opset_version here as well. If it's not supported, skipping should be enough.

Author:

Good point. Done.
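A minimal sketch of what that guard could look like, assuming the pass supports only opsets 9 and 10 at this point and that a TORCH_WARN-style macro is available; the message text and exact constants are illustrative, and the types (Block, ParamMap) come from the surrounding file as in the diff above:

void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
  if (opset_version != 9 && opset_version != 10) {
    // Unsupported opset: warn and skip the whole pass rather than
    // failing the export.
    TORCH_WARN(
        "Constant folding is only supported for opsets 9 and 10; ",
        "skipping for opset ", opset_version, ".");
    return;
  }
  AT_ASSERT(b->param_node());
  // ... fold constant subgraphs into initializers as before ...
}

Skipping here keeps torch.onnx.export usable with newer opsets even before folding support catches up.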

@spandantiwari (Author)

> Overall looks good. But in some cases we probably don't want to throw errors; instead, we should just skip the optimization. We don't want users blocked by an optimization pass. Also, shall we add an explicit test for the dynamic slice as well?

@houseroad - I agree that we should just skip the optimization instead of throwing errors, and I have updated the code accordingly. A warning is shown so that users are not surprised when constant folding does not happen. Regarding the explicit test for dynamic slice: we are adding a set of tests for the opset 10 Slice in test_pytorch_onnx_onnxruntime.py, and the explicit test will be added in that PR, which is coming soon.

@facebook-github-bot (Contributor)

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@houseroad (Member)

I still feel there are too many error messages, but it's good to land already.

@spandantiwari (Author)

Thanks, @houseroad!

@facebook-github-bot (Contributor)

@houseroad merged this pull request in 9d11004.
