[ONNX] Fix bug in exporting node with multiple outputs by scripting #20256
Conversation
houseroad
left a comment
Shall we also add example_outputs to the export function?
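Since a ScriptModule is not executed during export, the exporter cannot observe the model's outputs; the `example_outputs` argument supplies their number and structure instead. The toy sketch below (pure Python, with a hypothetical `flatten_outputs` helper; not the actual torch.onnx API) illustrates the unpacking the exporter needs to perform on whatever the caller provides:

```python
# Hypothetical sketch: why `example_outputs` matters when exporting a
# ScriptModule. The exporter never runs a scripted model, so the flat
# list of graph outputs must be derived from caller-provided examples.

def flatten_outputs(example_outputs):
    """Flatten possibly-nested tuples of outputs into a flat list,
    mirroring how an exporter would unpack them."""
    flat = []

    def walk(o):
        if isinstance(o, tuple):
            for item in o:
                walk(item)
        else:
            flat.append(o)

    walk(example_outputs)
    return flat

# A model returning a nested tuple yields three flat graph outputs.
descs = flatten_outputs((1.0, (2.0, 3.0)))
print(len(descs))  # 3
```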
Can we fix `out` (e.g., by unpacking it) instead of directly rerunning the model?
I don't understand the "fix" in your comment... When exporting a ScriptModule we don't run the model, so we get out=None. The out here is only used to compare against caffe2_out, so it shouldn't hurt to run the model even if it's a ScriptModule?
This part is not related to the "fix" in the PR title; it is basically updating the current test infra to be able to test ScriptModule models. The actual fix is in _propagate_and_assign_input_and_output_shapes in torch/csrc/jit/script/init.cpp.
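The class of bug being fixed can be illustrated without the C++ JIT code: when assigning types and shapes to graph outputs, every output must be paired with its example, not just the first one. A minimal pure-Python sketch (the function name and dict-based "graph" are illustrative, not the real implementation):

```python
# Toy sketch of shape assignment for a multi-output graph: each graph
# output is paired with the shape of the corresponding example output,
# and a count mismatch is reported instead of silently dropped.

def assign_output_shapes(graph_outputs, example_outputs):
    """Map each graph output name to the shape of its example output."""
    if len(graph_outputs) != len(example_outputs):
        raise ValueError(
            "graph has %d outputs but %d example outputs were given"
            % (len(graph_outputs), len(example_outputs)))
    return {name: tuple(shape)
            for name, shape in zip(graph_outputs, example_outputs)}

shapes = assign_output_shapes(["out0", "out1"], [[2, 3], [2]])
print(shapes)  # {'out0': (2, 3), 'out1': (2,)}
```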
facebook-github-bot
left a comment
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
houseroad
left a comment
Looks good.
@houseroad merged this pull request in 28be521.
Summary:
~~This is work in progress due to its dependency on multiple pending PRs.~~
- [x] ONNX: Relax constraint on subgraph input/output type & shape check. onnx/onnx#2009
- [x] PyTorch: Add infra to test_pytorch_onnx_caffe2.py to test ScriptModule models. #20256

This PR should partially resolve #17531. However, ideally we shouldn't need to insert cast (and reshape) nodes to help the conversion of the loop condition.

- Added a cast node for condition values before entering the Loop node. The ONNX spec only accepts Bool type, while in PyTorch, if the condition value is an output from another node, it could potentially have any integral type.
- Tidied up the exported ONNX Loop subgraph input type & shape. According to the ONNX spec, input "M" is exported as a 0-d scalar tensor with type int64, and input "cond" is exported as an incomplete tensor of type Bool without shape information. This is because, throughout the iteration, the rank of the condition value is dynamic, either 0-d or 1-d, as long as it holds a single value.

Pull Request resolved: #20445
Differential Revision: D15534188
Pulled By: houseroad
fbshipit-source-id: d174e778529def05ee666afeee4b8fb27786e320
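The cast-before-Loop step described above can be sketched on a toy IR (plain dicts standing in for graph nodes; not the real exporter pass): whenever a Loop's condition input is not already boolean, a Cast node to ONNX BOOL is inserted in front of it and the Loop is rewired to consume the cast result.

```python
# Illustrative sketch of inserting a Cast-to-BOOL before ONNX Loop
# nodes whose condition value carries some other integral type.

BOOL = 9   # ONNX TensorProto.BOOL enum value
INT64 = 7  # ONNX TensorProto.INT64 enum value

def cast_loop_conditions(nodes):
    """Return a new node list with a Cast-to-BOOL inserted before
    every Loop whose condition input is not already boolean."""
    out = []
    for node in nodes:
        if node["op"] == "Loop" and node.get("cond_dtype") != BOOL:
            cast_out = node["cond"] + "_as_bool"
            out.append({"op": "Cast", "input": node["cond"],
                        "output": cast_out, "to": BOOL})
            node = dict(node, cond=cast_out, cond_dtype=BOOL)
        out.append(node)
    return out

graph = [{"op": "Loop", "cond": "keep_going", "cond_dtype": INT64}]
print([n["op"] for n in cast_loop_conditions(graph)])  # ['Cast', 'Loop']
```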
Summary:
- [x] Add tests after #20256 is merged
- Support exporting ScriptModule with inputs/outputs of arbitrarily constructed tuples.
- Moved the assigning of output shapes to after the graph conversion to ONNX is completed. By then all tuples in the IR have already been lowered by the pass `_jit_pass_lower_all_tuples`. If assigning output shapes were required to happen before that, we'd need to hand-parse the tuple structures in the graph and repeat the same logic as `_jit_pass_lower_all_tuples`. Handling inputs is easier because all tuple information is encoded within the input tensor type.
- Swapped the order of `_jit_pass_lower_all_tuples` and `_jit_pass_erase_number_types`. Ops like `prim::TupleIndex` rely on the index being a scalar, and `_jit_pass_erase_number_types` converts such scalars to tensors.

Pull Request resolved: #20784
Reviewed By: zrphercule
Differential Revision: D15484171
Pulled By: houseroad
fbshipit-source-id: 4767a84038244c929f5662758047af6cb92228d3
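The pass-ordering point above can be illustrated with a toy stand-in for tuple lowering (pure Python; the function below is hypothetical, not the real JIT pass): statically resolving a `prim::TupleIndex`-style access requires the index to still be a plain scalar, which is exactly what `_jit_pass_erase_number_types` would have replaced with a tensor if it ran first.

```python
# Toy sketch of why tuple lowering must run before number types are
# erased: resolving a tuple element statically needs an int index,
# not a 0-d tensor wrapper.

def lower_tuple_index(tup, index):
    """Statically resolve a tuple element; only works for int indices."""
    if not isinstance(index, int):
        raise TypeError("TupleIndex requires a scalar int index, got %r"
                        % type(index).__name__)
    return tup[index]

# With the scalar still intact, lowering succeeds.
print(lower_tuple_index(("a", "b", "c"), 1))  # b
```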