
Conversation

@lara-hdr (Contributor)

No description provided.

@pytorchbot added the module: onnx (Related to torch.onnx) label — Jun 26, 2019
@ailzhang requested a review from houseroad — Jun 27, 2019 04:06
@ailzhang added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label — Jun 27, 2019

layer_norm = div(g, sub(g, input, mean), sqrt(g, add(g, variance, eps_cst)))

if not (bias is None or bias.node().mustBeNone()) and \

Review comment: How should we handle the case where either of them is None?

@lara-hdr (Contributor, Author): I will split this into two if statements.
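For illustration, the split might look like the following fragment; mul and add here are assumed to be the opset-9 symbolic helpers, analogous to the div, sub, and sqrt calls quoted above:

# Sketch of the two separate checks (helper names are assumptions).
if not (weight is None or weight.node().mustBeNone()):
    layer_norm = mul(g, layer_norm, weight)  # apply elementwise scale
if not (bias is None or bias.node().mustBeNone()):
    layer_norm = add(g, layer_norm, bias)    # then elementwise shift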

mean = g.op("ReduceMean", input, axes_i=axes)
squared_input = pow(g, input, two_cst)
squared_mean = pow(g, mean, two_cst)
squared_input_mean = g.op("ReduceMean", squared_input, axes_i=axes)

Review comment: Question: Could this variance computation be done using ONNX::std? If so, is that preferable in any way to this subgraph?

@lara-hdr (Contributor, Author): Yes, var = std^2.
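For reference, the subgraph above computes the variance via the identity Var[x] = E[x^2] - (E[x])^2 rather than through a separate std op. A quick NumPy check of that identity (illustrative only; not part of the PR):

import numpy as np

x = np.random.randn(20, 5, 10, 10)
axes = (-2, -1)  # the normalized axes, matching normalized_shape=[10, 10]

mean = x.mean(axis=axes, keepdims=True)
squared_input_mean = (x ** 2).mean(axis=axes, keepdims=True)
variance = squared_input_mean - mean ** 2  # E[x^2] - (E[x])^2

# Agrees with the direct definition E[(x - E[x])^2] up to float error.
assert np.allclose(variance, x.var(axis=axes, keepdims=True))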

@zhangguanheng66 (Contributor): @spandantiwari @lara-hdr On call today. Let me know if the PR is ready to land. I can help with landing. Thanks.

@lara-hdr (Contributor, Author) commented Jul 1, 2019:

@zhangguanheng66 Yes, I will ping you when the CI checks complete. Thanks!



@parse_args('v', 'is', 'v', 'v', 'f', 'i')
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
Review comment (Member): Could you add a check here: if the export mode is ONNX_ATEN_FALLBACK, we should still export layer_norm as an ATen operator. Otherwise, the new export logic will significantly degrade performance.

Reference:

operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
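For context, that flag is passed at export time. A minimal usage sketch (the model, input shape, and file name here are placeholders, not from the PR):

import torch

model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)
torch.onnx.export(model, x, "layer_norm.onnx",
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)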

@lara-hdr (Contributor, Author): I can add a function _set_operator_export_type (similar to _set_opset_version) in symbolic_helper.py to save the operator_export_type and access it here.

Reply (Member): Sounds good to me.
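A minimal sketch of what that helper and the resulting fallback check might look like, modeled on _set_opset_version; the names and module-level state below are assumptions, not the merged implementation:

from torch.onnx import OperatorExportTypes

# In torch/onnx/symbolic_helper.py (sketch): module-level state plus a
# setter, mirroring how _set_opset_version stores the opset version.
_operator_export_type = None

def _set_operator_export_type(operator_export_type):
    global _operator_export_type
    _operator_export_type = operator_export_type

# In the layer_norm symbolic (sketch): consult the saved mode and emit a
# single ATen node instead of building the decomposed subgraph.
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
    if _operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", input, weight, bias,
                    normalized_shape_i=normalized_shape, eps_f=eps,
                    cudnn_enable_i=cudnn_enable, operator_s="layer_norm")
    # ... otherwise build the ReduceMean/pow/sub/div subgraph shown above ...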

@facebook-github-bot (Contributor): @houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)
self.assertONNX(model, x,
                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
Review comment (Member): Nit: one more empty line?

@houseroad (Member): Looks good. One more formatting nit.

@lara-hdr (Contributor, Author) commented Jul 1, 2019:

@houseroad Nit fixed.

@facebook-github-bot (Contributor): @houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor): @houseroad merged this pull request in 7ca7edc.

xzhu1900 pushed a commit to xzhu1900/pytorch that referenced this pull request Jul 5, 2019
Summary: Pull Request resolved: pytorch#22265

Reviewed By: zrphercule

Differential Revision: D16076268

Pulled By: houseroad

fbshipit-source-id: 29b4ecab2fa0dc7250c9d1ad6924903181a66ab2

Labels

Merged · module: onnx (Related to torch.onnx) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


9 participants