
Conversation

@neginraoof
Contributor

Updated group_norm symbolic

@pytorchbot added the module: onnx (Related to torch.onnx) label on Sep 30, 2019
@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    input_sizes = input.type().sizes()
    shape = [1, input_sizes[0] * num_groups, input_sizes[1] / num_groups, -1]
Collaborator

In fact it might be possible to remove the dependency on explicit input_sizes.

  • For the first reshape, reshape to [0, num_groups, -1], the effect will be [batch, channel, h, w] => [batch, num_groups, channel / num_groups * h * w].
  • For weight and bias, consider the below comment regarding broadcasting.
  • For the second reshape, just reshape back to the original input shape: g.op('Reshape', norm_reshaped, g.op("Shape", input)).
  • For the last weight and bias, it seems we still need the rank information of input. We could remove this dependency as well by adding two more reshapes to the norm tensor, but that might not be worth the effort since rank info is usually static.

The motivation is that we won't be able to support dynamic input sizes if we depend on static input_sizes.
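
For reference, a small runnable sanity check (plain PyTorch, not the exporter symbolic; function and variable names here are illustrative) of the reshape trick described above: group norm equals instance norm applied to a [N, num_groups, -1] view, followed by a reshape back to the original input shape and a per-channel affine. Only num_groups, plus the input rank for the final affine, is needed; no static spatial sizes.

import torch
import torch.nn.functional as F

def group_norm_via_instance_norm(x, num_groups, weight, bias, eps=1e-5):
    # First reshape: [N, C, *] -> [N, G, -1]; only num_groups is needed, not H/W.
    x_grouped = x.reshape(x.shape[0], num_groups, -1)
    # Each (sample, group) slice is normalized over its last dimension.
    normed = F.instance_norm(x_grouped, eps=eps)
    # Second reshape: back to the original input shape (the g.op("Shape", input) idea).
    normed = normed.reshape(x.shape)
    # Final per-channel affine; this is the step that still needs the input rank.
    affine_shape = [1, x.shape[1]] + [1] * (x.dim() - 2)
    return normed * weight.reshape(affine_shape) + bias.reshape(affine_shape)

# Quick check against the built-in:
x = torch.randn(2, 6, 5, 7)
w, b = torch.randn(6), torch.randn(6)
assert torch.allclose(group_norm_via_instance_norm(x, 3, w, b),
                      F.group_norm(x, 3, w, b), atol=1e-4)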

Contributor Author

Thanks @BowenBao. Apparently broadcasting is not supported for InstanceNormalization, so I don't think we can avoid accessing the size for weight_ and bias_.

Collaborator

Ok, but I think we can still construct weight_ and bias_ with just num_groups, if in the first reshape we keep batch and channel axes separated.
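
A minimal sketch of that idea, assuming the usual torch.onnx symbolic conventions (g.op graph builder, parse_args); this is not necessarily the exact code in this PR. Keeping the batch and group axes separate means InstanceNormalization's scale and B inputs only need length num_groups, so they can be built from num_groups alone. The final per-channel affine with weight/bias is omitted here; it still needs the input rank, as discussed above.

import torch
from torch.onnx.symbolic_helper import parse_args

@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    # [N, C, *] -> [N, num_groups, -1]; ONNX Reshape copies dims marked 0 and infers -1.
    target_shape = g.op('Constant', value_t=torch.LongTensor([0, num_groups, -1]))
    input_reshaped = g.op('Reshape', input, target_shape)
    # scale/B built from num_groups alone: unit scale and zero bias per group
    # (dtype assumed float here for brevity).
    weight_ = g.op('Constant', value_t=torch.ones(num_groups))
    bias_ = g.op('Constant', value_t=torch.zeros(num_groups))
    norm = g.op('InstanceNormalization', input_reshaped, weight_, bias_, epsilon_f=eps)
    # Reshape back to the original (possibly dynamic) input shape.
    # The per-channel affine with weight/bias would follow and needs the input rank.
    return g.op('Reshape', norm, g.op('Shape', input))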

Contributor Author

True, I'm going to update the first two parts anyway. But I agree that it might not be worth adding two reshapes, since the rank here could be N (it's not limited in the spec). I'll leave a note in case anyone needs this op with dynamic rank in the future.

@neginraoof
Contributor Author

@pytorchbot rebase this please

Collaborator

@BowenBao left a comment


Please always remember to update the test operator expect files whenever a symbolic fn is updated.


@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    if input.isCompleteTensor():
Collaborator

isCompleteTensor can be relaxed. We only require its scalar type and rank to be known.
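
A hedged sketch of that relaxation: check only that the scalar type and rank are known, instead of requiring a fully specified shape via isCompleteTensor(). The type-query calls shown are assumptions about the JIT TensorType API of this era and may need adjusting to the torch version in use.

def _has_known_scalar_type_and_rank(value):
    # value is a torch._C.Value; its type() carries whatever shape/dtype info the
    # tracer recorded. We only need dtype and rank, not the full static shape.
    t = value.type()
    try:
        return t.scalarType() is not None and t.dim() is not None
    except RuntimeError:
        # Some TensorType variants raise instead of returning None.
        return False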

Contributor Author

Sure. I only have console access to the environment and the build is slow, so I updated the files here.

Contributor

@facebook-github-bot left a comment


@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@neginraoof
Contributor Author

cc @houseroad: tests are fixed. Please re-import. Thanks.

@neginraoof
Contributor Author

cc @houseroad for review

@neginraoof
Contributor Author

@houseroad please review updates.

@neginraoof
Contributor Author

@pytorchbot retest this please

@neginraoof
Contributor Author

cc @houseroad for review

@neginraoof
Contributor Author

@houseroad for review

@neginraoof
Contributor Author

@pytorchbot rebase this please

@neginraoof
Contributor Author

@houseroad please review the updates

Contributor

@facebook-github-bot left a comment


@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@houseroad requested a review from BowenBao on October 22, 2019, 17:25
Member

@houseroad left a comment


This broke some internal use cases; I need to investigate further.

@neginraoof
Contributor Author

Thanks @houseroad for looking into this. Is there a failure regarding group_norm?

Member

@houseroad left a comment


Let's check this in for now. However, this translation is really inefficient. We should introduce GroupNorm to ONNX so we can represent this with a single op in the model.

@houseroad
Member

@pytorchbot rebase this please

@facebook-github-bot
Contributor

@houseroad merged this pull request in 76d262d.

thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
Summary:
Updated group_norm symbolic
Pull Request resolved: pytorch#27071

Reviewed By: hl475

Differential Revision: D17792249

Pulled By: houseroad

fbshipit-source-id: 08be6071952ca2c256d2c6a0a6bbc19a8442f1fe