-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[ONNX] add pass for onnx scalar type conversion #24378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This mapping maybe useful beyond this pass. We should consider abstracting it out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this one too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a method for extending type promotion beyond pairs, to a vector of types. It may be useful beyond this pass. Should we maybe move this as a overload to c10::promoteTypes? @bddppq What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
9665beb to
207cc26
Compare
|
@BowenBao and @spandantiwari i cloned the following repo mentioned in this issue https://github.com/BowenBao/pytorch/tree/onnx_implicit_cast. Still i am getting the type mismatch error. |
|
Hi @ajaysg, on my local machine I could not repro the mismatch error with this branch. Particularly, the exported model now has a cast node after Mul which should resolve this issue. Please also note that this model won't be able to export, or run in ONNX backend, until the spec & impl for |
ffc18d5 to
0562f98
Compare
|
@pytorchbot retest this please |
|
hi @BowenBao i cloned this repo using the following command |
|
hi @ajaysg , could you try |
|
Hi @BowenBao i tried the same you suggested. Still i am not able to find the cast node after mul node in the generated onnx model. Also while i am using that model in onnxrt i am getting the Type Parameter issue |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dzhulgakov has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
@pytorchbot retest this please |
|
I'm not the best person to review it probably. @nairbv or @ZolotukhinM - do you happen to have more context on type promotion logic? |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
houseroad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, this is more fundamental solution for scalar type mismatching solution.
I am not sure whether we have covered all the pytorch ops already or not, but it's good start.
Please address my inline comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not only Analysis, but Propagation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As below, the pass is not doing Propagation at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also add some comments about the purpose/motivation of this pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a bit weird we only have one function call in ScalarTypeAnalysisForONNX, consider merging them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We were wondering if in the future this is a good place to extend the pass to perform complete shape and type inferencing. Especially to support scripting, which unlike tracing, doesn't record operator outputs with complete tensor type.
There is existing shape_analysis.cpp, but is for aten operators, and it is usually more dynamic than ONNX.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Give some context why handling like this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding it to the comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may incur some duplicate values, but I guess it should be rare.
ZolotukhinM
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The IR pass part looks good!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if this works correctly if a node has the same value as input several times, e.g
%a = prim::mul(%b, %b)
I'm not sure that inputs() iterator will stay valid after we call replaceInputWith, which would replace all entries of the current input. Could you please double-check that it works as expected? A safe alternative to achieve the same and avoid iterator issues for certain would be to use replaceInput(index, newValue).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a very interesting case. I will create a test case for that, after fixing the regression #26328.
resolve & update test issues
|
After rebasing, the CI is failing on this test case #26328, which seems to be a regression since it can also repro on master. I'm currently looking at that. |
0562f98 to
42db938
Compare
* Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch#24378 * Fix UTs ghstack-source-id: 4e796b1 Pull Request resolved: pytorch#69548
…JIT pass (#69227)" * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR #24378 * Fix UTs Differential Revision: [D32994268](https://our.internmc.facebook.com/intern/diff/D32994268) [ghstack-poisoned]
* Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR #24378 * Fix UTs Differential Revision: [D32994268](https://our.internmc.facebook.com/intern/diff/D32994268) [ghstack-poisoned]
…JIT pass (#69227)" * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR #24378 * Fix UTs Differential Revision: [D32994268](https://our.internmc.facebook.com/intern/diff/D32994268) [ghstack-poisoned]
* Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR #24378 * Fix UTs Differential Revision: [D32994268](https://our.internmc.facebook.com/intern/diff/D32994268) [ghstack-poisoned]
* Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch#24378 * Fix UTs ghstack-source-id: 4e796b1 Pull Request resolved: pytorch#69548
* Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch#24378 * Fix UTs ghstack-source-id: 4e796b1 Pull Request resolved: pytorch#69548
* Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch#24378 * Fix UTs ghstack-source-id: 4e796b1 Pull Request resolved: pytorch#69548
…JIT pass (#69227)" * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR #24378 * Fix UTs Differential Revision: [D32994268](https://our.internmc.facebook.com/intern/diff/D32994268) [ghstack-poisoned]
* Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR #24378 * Fix UTs Differential Revision: [D32994268](https://our.internmc.facebook.com/intern/diff/D32994268) [ghstack-poisoned]
Summary: Pull Request resolved: #69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR #24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63
Summary: Pull Request resolved: #69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR #24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322c)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)
Summary: Pull Request resolved: pytorch/pytorch#69548 * Add Concat to Scalar type analysis pass By using scalar type analysis for Concat, the exported model can do automatic type promotion for Concat nodes, including mixed fp16 and fp32 inputs, for example. Unit tests based on the original PR pytorch/pytorch#24378 * Fix UTs Test Plan: Imported from OSS Reviewed By: msaroufim Differential Revision: D32994268 Pulled By: malfet fbshipit-source-id: 0deab88b0bb1e396770690af27730accb64fcf63 (cherry picked from commit a99322cadf7b79a4548266a9d4d3af094b89bac4)

This pass tries to resolve scalar type mismatch issues between input tensors introduced by the implicit type conversions on scalars.
e.g. #23724