ONNX Export LayerNorm #22265
Conversation
torch/onnx/symbolic_opset9.py
Outdated
layer_norm = div(g, sub(g, input, mean), sqrt(g, add(g, variance, eps_cst)))

if not (bias is None or bias.node().mustBeNone()) and \
How should we handle it if either of them is None?
I will split this into 2 if statements.
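A minimal sketch of what that split might look like, assuming mul and add are the symbolic helpers already used in symbolic_opset9.py (the exact code in the final diff may differ):

# Apply weight and bias separately, each guarded by its own None check.
if not (weight is None or weight.node().mustBeNone()):
    layer_norm = mul(g, layer_norm, weight)
if not (bias is None or bias.node().mustBeNone()):
    layer_norm = add(g, layer_norm, bias)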
torch/onnx/symbolic_opset9.py
Outdated
mean = g.op("ReduceMean", input, axes_i=axes)
squared_input = pow(g, input, two_cst)
squared_mean = pow(g, mean, two_cst)
squared_input_mean = g.op("ReduceMean", squared_input, axes_i=axes)
Question: could this variance computation be done using ONNX::std? If yes, is that preferable in any way to this subgraph?
yes, var=std^2.
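For reference, the subgraph in this diff relies on the identity Var[x] = E[x^2] - (E[x])^2 rather than a dedicated std/variance op. A minimal sketch of the whole computation, assuming axes, two_cst, and eps_cst are constants built earlier in the function:

# variance = E[x^2] - (E[x])^2, then normalize with the usual epsilon guard
mean = g.op("ReduceMean", input, axes_i=axes)
squared_input_mean = g.op("ReduceMean", pow(g, input, two_cst), axes_i=axes)
variance = sub(g, squared_input_mean, pow(g, mean, two_cst))
layer_norm = div(g, sub(g, input, mean), sqrt(g, add(g, variance, eps_cst)))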
@spandantiwari @lara-hdr On call today. Let me know if the PR is ready to land. I can help with landing. Thanks.
@zhangguanheng66 yes, I will ping you when the CI checks complete.
@parse_args('v', 'is', 'v', 'v', 'f', 'i')
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
Could you add a check here: if the export mode is ONNX_ATEN_FALLBACK, we should still export layer norm as an ATen operator. Otherwise, the new export logic will significantly degrade performance.
Reference:
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
I can add a function _set_operator_export_type (similar to _set_opset_version) in symbolic_helper.py to save the operator_export_type and access it here.
sounds good to me.
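A rough sketch of the two pieces discussed above. The helper name _set_operator_export_type comes from the comment; the ATen node emitted in the fallback branch is an assumption about how that path could look, not necessarily the exact code that landed:

# torch/onnx/symbolic_helper.py (sketch, mirroring the existing _set_opset_version pattern)
_operator_export_type = None

def _set_operator_export_type(operator_export_type):
    global _operator_export_type
    _operator_export_type = operator_export_type

# torch/onnx/symbolic_opset9.py (sketch of the fallback check at the top of layer_norm;
# sym_help is torch.onnx.symbolic_helper)
@parse_args('v', 'is', 'v', 'v', 'f', 'i')
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", input, weight, bias,
                    normalized_shape_i=normalized_shape, eps_f=eps,
                    cudnn_enable_i=cudnn_enable, operator_s="layer_norm")
    # ...otherwise fall through to the ONNX decomposition shown earlier...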
facebook-github-bot
left a comment
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)
self.assertONNX(model, x,
                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
nit: one more empty line?
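For context, a standalone sketch of the export path this test exercises; the output file name is hypothetical, and the test itself goes through assertONNX rather than calling torch.onnx.export directly:

import torch

model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)
# With ONNX_ATEN_FALLBACK, layer_norm should be kept as an ATen op in the exported graph.
torch.onnx.export(model, x, "layer_norm.onnx",
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)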
houseroad
left a comment
Looks good. One more format nit.
@houseroad nit fixed.
facebook-github-bot
left a comment
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@houseroad merged this pull request in 7ca7edc.
Summary:
Pull Request resolved: pytorch#22265
Reviewed By: zrphercule
Differential Revision: D16076268
Pulled By: houseroad
fbshipit-source-id: 29b4ecab2fa0dc7250c9d1ad6924903181a66ab2
No description provided.