[ONNX] Update ONNX constant folding to support opset 10. #22515
Conversation
Force-pushed from 1a84d81 to ad33a87.
@houseroad - The CI failure is unrelated. Could you please review when you get a chance? Thanks.
facebook-github-bot
left a comment
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
houseroad
left a comment
Overall looks good. But in some cases we probably don't want to throw errors; instead we should just skip the optimization. We don't want users blocked by an optimization pass. Also, shall we add an explicit test for the dynamic slice as well?
test/onnx/test_utility_funs.py (outdated)

```python
class TestUtilityFuns(TestCase):
    from torch.onnx.symbolic_helper import _export_onnx_opset_version
    opset_version = _export_onnx_opset_version
```
Do we just want to explicitly use opset 9? Since we already cache opset_version here, it won't change in the concrete test cases.
Good point. Using explicit opset 9 here.
```cpp
c10::optional<at::Tensor> runTorchSlice_opset9(const Node* node,
    std::vector<at::Tensor>& inputTensorValues) {
  assert(inputTensorValues.size() == 1);
```
just return nullopt?
Sure. Fixed.
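The pattern agreed on in this thread, returning an empty optional instead of asserting when a node's inputs are unexpected, lets the folding pass degrade gracefully rather than abort the export. A minimal Python sketch of the same idea (hypothetical helper names, not the actual PR code; `None` plays the role of `c10::nullopt`):

```python
from typing import List, Optional

def run_slice_opset9(input_tensors: List[list]) -> Optional[list]:
    # Opset-9 Slice carries starts/ends/axes as node attributes, so
    # exactly one data input is expected here. Instead of asserting
    # (which would abort the whole export), return None so the caller
    # simply skips folding this node.
    if len(input_tensors) != 1:
        return None
    data = input_tensors[0]
    # ... the actual slicing computation would go here; identity
    # is used in this sketch.
    return data
```

A caller checks the optional and leaves the node unfolded when it gets back `None`, so an unsupported shape never blocks the user.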
```cpp
c10::optional<at::Tensor> runTorchSlice_opset10(const Node* node,
    std::vector<at::Tensor>& inputTensorValues) {
  assert(inputTensorValues.size() > 2 && inputTensorValues.size() < 6);
```
just return nullopt?
Agreed. Fixed.
```cpp
    std::vector<at::Tensor>& inputTensorValues) {
  assert(inputTensorValues.size() > 2 && inputTensorValues.size() < 6);
  // Checking validity of 'starts' and 'ends' input
  assert(inputTensorValues[1].sizes().size() == 1 &&
```
just return nullopt?
Agreed. Fixed.
```cpp
    return runTorchSlice_opset10(node, inputTensorValues);
  }
  else {
    throw std::runtime_error(
```
return nullopt?
Agreed. Fixed.
```cpp
for (size_t i = 0; i < inputTensorValues[1].sizes()[0]; ++i) {
  // ONNX slice accepts negative starts and ends values.
  int64_t start = starts_a[i], end = ends_a[i];
  int64_t axis = axes[i];
```
Abstract this out into a helper function for the negative-index handling logic?
Makes sense. Fixed.
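The helper suggested here can be sketched in Python (hypothetical names; ONNX Slice interprets a negative start/end relative to the end of the dimension and clamps the result into `[0, dim_size]`):

```python
def handle_negative_index(index: int, dim_size: int) -> int:
    # ONNX Slice allows negative starts/ends: a negative index is
    # offset by the dimension size, then the result is clamped into
    # the valid range [0, dim_size].
    if index < 0:
        index += dim_size
    return max(0, min(index, dim_size))

def fold_slice_1d(data, start, end):
    # Fold a single-axis slice over a constant list using the helper,
    # mimicking what the folding pass computes at export time.
    s = handle_negative_index(start, len(data))
    e = handle_negative_index(end, len(data))
    return data[s:e]
```

Centralizing this keeps the opset-9 and opset-10 code paths consistent, since both must normalize indices the same way.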
```cpp
// constant-based computations/ops into an initializer node.
void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) {
  AT_ASSERT(b->param_node());
```
We may also want to add a check on opset_version here as well. If it's not supported, skipping should be enough.
Good point. Done.
@houseroad - I agree that we should just skip the optimization instead of throwing errors. I have updated the code for that. I show a warning so that the user is not surprised when they do not see constant folding. Regarding the explicit test for dynamic slice, we are adding a bunch of tests for opset 10 Slice in …
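The skip-with-warning behavior described above could be sketched as follows (hypothetical function and set names; the real pass operates on a TorchScript graph in C++):

```python
import warnings

# Assumption for this sketch: the opsets the folding pass handles.
SUPPORTED_OPSETS = {9, 10}

def constant_fold(graph, params_dict, opset_version):
    # If the opset is not supported by the folding pass, warn and
    # return the graph unchanged rather than raising, so export is
    # never blocked by an optimization pass.
    if opset_version not in SUPPORTED_OPSETS:
        warnings.warn(
            "Constant folding not supported for opset %d; skipping."
            % opset_version)
        return graph
    # ... fold constant subgraphs into initializers here ...
    return graph
```

The warning keeps the behavior discoverable: the user learns why folding did not happen without the export failing.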
facebook-github-bot
left a comment
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
houseroad
left a comment
Still feels like too many error messages, but it's good to land.
Thanks @houseroad!

@houseroad merged this pull request in 9d11004.
Currently ONNX constant folding (the do_constant_folding=True arg in the torch.onnx.export API) supports only opset 9 of ONNX. For opset 10, it is a no-op. This change enables ONNX constant folding for opset 10. Specifically, there are three main changes:
- … the onnx::Slice op for backend computation during constant folding.
- … test/onnx/test_utility_funs.py for multiple opsets (9 and 10).
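As a toy illustration of what the pass does (a pure-Python sketch with hypothetical names, not the TorchScript implementation): a Slice whose data and indices are all compile-time constants can be precomputed once at export time and replaced by an initializer, instead of being evaluated on every inference run:

```python
def fold_constant_slice(node):
    # node: dict with constant 'data', 'starts', 'ends' (single axis).
    # When every input is a compile-time constant, the slice result
    # can be computed here and stored as an initializer.
    data = node["data"]
    start, end = node["starts"][0], node["ends"][0]
    n = len(data)
    # Normalize negative indices and clamp into [0, n], per ONNX
    # Slice semantics.
    start = max(0, min(start + n if start < 0 else start, n))
    end = max(0, min(end + n if end < 0 else end, n))
    return data[start:end]
```

In the real pass the folded result is attached to the graph as an initializer and the original onnx::Slice node is removed.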