
add params check when convert BatchNorm to GroupNorm#390

Closed
XiaobingSuper wants to merge 2 commits into meta-pytorch:main from XiaobingSuper:xiaobing/bn_conversion_check

Conversation

@XiaobingSuper
Contributor

This PR adds a parameter check when converting BatchNorm to GroupNorm, because GroupNorm imposes some prerequisites on its parameters at initialization.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Mar 21, 2022
@facebook-github-bot
Contributor

@JohnlNguyen has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@XiaobingSuper
Contributor Author

@JohnlNguyen, the failing case tries to convert BatchNorm2d(72, eps=0.001, momentum=0.01, affine=True, track_running_stats=True) to GroupNorm, but that conversion is invalid. How do we skip this conversion?

Comment on lines +88 to +91
if module.num_features % min(32, module.num_features) != 0:
raise UnsupportableModuleError(
"There is no equivalent GroupNorm module to replace BatchNorm with."
)
Contributor


Suggested change
if module.num_features % min(32, module.num_features) != 0:
raise UnsupportableModuleError(
"There is no equivalent GroupNorm module to replace BatchNorm with."
)

Raising an error seems very prohibitive. The default value of 32 was chosen based on the empirical results in the paper, but using a different value for num_groups is still better than disallowing conversion.
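As a concrete illustration of the failing case discussed in this thread, the min-based guard rejects a 72-channel BatchNorm because 72 is not divisible by min(32, 72). A minimal sketch (plain Python, no torch dependency):

```python
# The rejected case from this PR's failing test: 72 channels.
num_features = 72
num_groups = min(32, num_features)     # min-based choice -> 32
remainder = num_features % num_groups  # 72 % 32 == 8, so the check fires
print(num_groups, remainder)           # 32 8
```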


What was the previous behavior? You would get an error message if and only if you actually ran the module?

Contributor Author


@gchanan, yes, you only get an error when running the module before.

"There is no equivalent GroupNorm module to replace BatchNorm with."
)
return nn.GroupNorm(
min(32, module.num_features), module.num_features, affine=module.affine
Contributor


Suggested change
min(32, module.num_features), module.num_features, affine=module.affine
gcd(32, module.num_features), module.num_features, affine=module.affine

How about replacing min with gcd? This should work for any number of channels, and will distill to InstanceNorm in the extreme case.
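A minimal sketch of the gcd-based group selection (the helper name `pick_num_groups` is hypothetical, not from the PR; only the `gcd(32, num_features)` expression is from the suggested change):

```python
from math import gcd

def pick_num_groups(num_features: int, preferred: int = 32) -> int:
    # gcd(preferred, num_features) always divides num_features, so the
    # GroupNorm constraint num_channels % num_groups == 0 holds for any
    # channel count. If num_features is coprime with 32 this collapses
    # to a single group; if num_features divides 32 evenly, groups equal
    # channels, which behaves like InstanceNorm.
    return gcd(preferred, num_features)

print(pick_num_groups(72))  # 8: min(32, 72) == 32 would be invalid, gcd works
print(pick_num_groups(64))  # 32: matches the min-based default
print(pick_num_groups(7))   # 1: coprime with 32, one group
```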

Contributor Author


Thanks, the code has been changed to use the gcd method.

@facebook-github-bot
Contributor

@XiaobingSuper has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot
Contributor

@karthikprasad has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

