convert_sync_batch_norm to SyncBatchNorm #18787
Conversation
ssnl left a comment:
There are other call sites of convert_sync_batchnorm that you should fix: https://github.com/pytorch/pytorch/search?q=convert_sync_batchnorm&unscoped_q=convert_sync_batchnorm
torch/nn/modules/_functions.py (Outdated)

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None

    @staticmethod
This should be a class method; line 128 should use cls instead.
Will fix that.
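For context on why this matters: a @classmethod receives the class it is called on as cls, so subclasses inherit a helper that constructs the right type, while a @staticmethod would have to hard-code the class name. A minimal, self-contained illustration (hypothetical Converter classes, not the PyTorch code):

    class Converter:
        @staticmethod
        def make_static():
            # a staticmethod cannot see which class it was called on,
            # so the class it constructs must be hard-coded
            return Converter()

        @classmethod
        def make_class(cls):
            # a classmethod receives the calling class as `cls`,
            # so subclasses calling it get instances of themselves
            return cls()

    class SubConverter(Converter):
        pass

    print(type(SubConverter.make_static()).__name__)  # Converter
    print(type(SubConverter.make_class()).__name__)   # SubConverter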
ssnl left a comment:
Next time you submit or update a PR for review, please make sure that the relevant tests pass locally first.
torch/nn/modules/_functions.py
Outdated
| module_output.running_var = module.running_var | ||
| module_output.num_batches_tracked = module.num_batches_tracked | ||
| for name, child in module.named_children(): | ||
| module_output.add_module(name, convert_sync_batchnorm(child)) |
This recursive call needs to be fixed.
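The point here: once the helper becomes a classmethod, the recursion can no longer call a bare convert_sync_batchnorm(child); it has to go through cls. A runnable sketch using a hypothetical stand-in module (DemoSync is not the real SyncBatchNorm):

    import torch.nn as nn

    class DemoSync(nn.Module):
        # hypothetical stand-in for SyncBatchNorm, only to show the recursion
        @classmethod
        def convert(cls, module):
            module_output = module
            if isinstance(module, nn.BatchNorm2d):
                module_output = cls()
            for name, child in module.named_children():
                # the recursive call must go through `cls` once the helper is
                # a classmethod; a bare `convert(child)` would raise a
                # NameError here, since no free function of that name exists
                module_output.add_module(name, cls.convert(child))
            return module_output

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    print(DemoSync.convert(net))  # the BatchNorm2d child is replaced by DemoSync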
True, totally missed that. Also, to make sure the relevant tests pass, I just need to run the relevant test files in the test directory, right?
@iArunava Yes, just run python test_torch.py etc. and it should be fine.
I have made the required changes, in all the places found here: https://github.com/pytorch/pytorch/search?q=convert_sync_batchnorm&unscoped_q=convert_sync_batchnorm
torch.nn.convert_sync_batch_norm is replaced by torch.nn.SyncBatchNorm.convert_sync_batch_norm; since it's a @classmethod, it should be accessible that way. Is that okay?
Also, before pushing I wanted to run the tests, but I am not able to: I keep getting "cannot import istuple" from the torch/_six.py file, and I can't find the reason why; everything else from _six.py imports fine. Can you please help, @fmassa? I am quite sure I am missing something here.
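For reference, the call pattern this PR settles on looks like the following (a minimal sketch; the conversion alone runs on CPU, but actually training a SyncBatchNorm layer requires a distributed process group):

    import torch.nn as nn

    # hypothetical toy model for the demo
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    # the converter is reached through the SyncBatchNorm class, as a classmethod
    sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    print(sync_model)  # the BatchNorm2d layer is now a SyncBatchNorm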
I pushed without testing, sorry for that. But I was unable to get the tests working; I expected it would work. I tried uninstalling torch and getting into develop mode as said in …
It's okay. For a small change like this, you can use the CI to fix things too. I'll tag this as WIP. When you are done and this is ready for review, feel free to remove "WIP" and/or tag us :) For the long term and bigger changes, it would still be helpful to figure out local installation.
torch/nn/modules/_functions.py (Outdated)

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
Hey, you should add this to the module rather than the autograd function!
Hey, if I understand you correctly, you are asking to move the convert_sync_batchnorm function into the SyncBatchNorm module here:
pytorch/torch/nn/modules/batchnorm.py, line 322 in 0829ef0:

    class SyncBatchNorm(_BatchNorm):
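In other words, the suggested structure is roughly this (a simplified sketch; the import path is real, but the method body is elided):

    # torch/nn/modules/batchnorm.py -- the conversion helper lives on the module,
    # not on the autograd Function in torch/nn/modules/_functions.py
    from torch.nn.modules.batchnorm import _BatchNorm

    class SyncBatchNorm(_BatchNorm):
        @classmethod
        def convert_sync_batchnorm(cls, module, process_group=None):
            # walk `module`, replacing _BatchNorm layers with instances of cls
            ...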
@ssnl can you please check once? I have finally built PyTorch on my machine, and I have changed the … EDIT: I don't have CUDA, so the CUDA tests were skipped.
@pytorchbot rebase this please
@ssnl Please check! :)
soumith left a comment:
Looks good now. Thank you!
facebook-github-bot left a comment:
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Closes pytorch#18382. Please let me know if any changes are required.
Pull Request resolved: pytorch#18787
Differential Revision: D14821147
Pulled By: soumith
fbshipit-source-id: edd98eab1b3f4151c4ae5148146435ddb2ae678d