[ONNX] export group_norm #27071
Conversation
torch/onnx/symbolic_opset9.py (Outdated)

@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    input_sizes = input.type().sizes()
    shape = [1, input_sizes[0] * num_groups, input_sizes[1] / num_groups, -1]
In fact it might be possible to remove the dependency on explicit input_sizes.
- For the first reshape, reshape to [0, num_groups, -1]; the effect will be [batch, channel, h, w] => [batch, num_groups, channel / num_groups * h * w].
- For weight and bias, consider the comment below regarding broadcasting.
- For the second reshape, just reshape back to the original input shape: g.op('Reshape', norm_reshaped, g.op("Shape", input)).
- For the last weight and bias, it seems we still need the rank information of the input. We could remove this dependency as well by adding two more reshapes to the norm tensor, but that may not be worth the effort since rank info is usually static.

The motivation is that we won't be able to support dynamic input sizes if we depend on static input_sizes.
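[Editor's note: a minimal runnable sketch, not code from this PR, illustrating the reshape trick described above — reshaping to [batch, num_groups, -1] turns group normalization into a per-group normalization that needs no static input sizes. In the actual export the same math would be expressed with Reshape and InstanceNormalization nodes.]

import torch
import torch.nn.functional as F

N, C, H, W, G = 2, 6, 4, 4, 3
x = torch.randn(N, C, H, W)
eps = 1e-5

# Reference result from PyTorch's own group_norm (no affine weight/bias).
expected = F.group_norm(x, G, eps=eps)

# Reshape [N, C, H, W] -> [N, G, C // G * H * W], normalize each group, reshape back.
reshaped = x.reshape(N, G, -1)
mean = reshaped.mean(dim=-1, keepdim=True)
var = reshaped.var(dim=-1, unbiased=False, keepdim=True)
normalized = ((reshaped - mean) / torch.sqrt(var + eps)).reshape(N, C, H, W)

assert torch.allclose(expected, normalized, atol=1e-5)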
Thanks @BowenBao, apparently broadcasting is not supported for InstanceNormalization.
So I don't think we can avoid accessing the size for weight_ and bias_.
Ok, but I think we can still construct weight_ and bias_ with just num_groups, if in the first reshape we keep batch and channel axes separated.
True, I'll update the first two parts anyway. But I agree that it might not be worth adding two reshapes, since the rank here could be N (it's not limited by the spec). I'll leave a note in case anyone needs this op with dynamic rank in the future.
@pytorchbot rebase this please
BowenBao left a comment
Please always remember to update the test operator expect files if the symbolic fn is updated.
torch/onnx/symbolic_opset9.py (Outdated)

@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    if input.isCompleteTensor():
isCompleteTensor can be relaxed. We only require its scalar type and rank to be known.
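[Editor's note: a rough sketch of the relaxed check the reviewer describes, not the merged code; it assumes the TorchScript TensorType accessors scalarType() and dim(), which return None when unknown.]

def has_known_scalar_type_and_rank(value):
    # `value` is a torch._C.Value from the traced graph.
    tensor_type = value.type()
    return tensor_type.scalarType() is not None and tensor_type.dim() is not None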
Sure. I only have console access to the environment and the build is slow, so I updated the files here.
facebook-github-bot left a comment
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
cc @houseroad tests are fixed. Please re-import. Thanks.

cc @houseroad for review
@houseroad please review updates.

@pytorchbot retest this please

cc @houseroad for review

@houseroad for review

@pytorchbot rebase this please

@houseroad please review the updates
facebook-github-bot left a comment
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
houseroad left a comment
This broke some internal use cases; I need to investigate further.
Thanks @houseroad for looking into this. Is there a failure regarding group_norm?
houseroad left a comment
Let's check this in for now. However, this translation is really inefficient. We should introduce GroupNorm to ONNX, so we can use one op to represent the model.
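[Editor's note: a quick way to see the decomposition houseroad refers to, assuming a PyTorch build containing this change and the onnx package installed; this is an illustration, not part of the PR.]

import io
import torch
import onnx

model = torch.nn.GroupNorm(num_groups=3, num_channels=6)
dummy = torch.randn(2, 6, 4, 4)

buf = io.BytesIO()
torch.onnx.export(model, dummy, buf, opset_version=9)
graph = onnx.load_from_string(buf.getvalue()).graph

# GroupNorm is lowered to several nodes (Reshape, InstanceNormalization, Mul, Add, ...)
# rather than a single GroupNorm op, which is the inefficiency mentioned above.
print([node.op_type for node in graph.node])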
@pytorchbot rebase this please
@houseroad merged this pull request in 76d262d.
Summary: Updated group_norm symbolic
Pull Request resolved: pytorch#27071
Reviewed By: hl475
Differential Revision: D17792249
Pulled By: houseroad
fbshipit-source-id: 08be6071952ca2c256d2c6a0a6bbc19a8442f1fe