
Conversation

@iArunava (Contributor) commented Apr 3, 2019

Closes #18382

Please let me know if any changes are required.

@ssnl (Collaborator) left a review comment on this code:


return grad_input, grad_weight, grad_bias, None, None, None, None, None, None

@staticmethod
@ssnl (Collaborator):

This should be a classmethod; line 128 should use cls instead.
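A minimal before-and-after sketch of what this review comment asks for (the enclosing class body is elided; the signature follows the snippets in this thread):

# Before: a staticmethod has no reference to the class it is defined on.
@staticmethod
def convert_sync_batchnorm(module, process_group=None):
    ...

# After: a classmethod receives the class as cls, so the body can recurse
# through cls instead of a hard-coded module-level name.
@classmethod
def convert_sync_batchnorm(cls, module, process_group=None):
    ...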

@iArunava (Contributor, Author):

Will fix that.

@ssnl (Collaborator) left a comment:

Next time you submit or update a PR for review, please make sure that the relevant tests pass locally first.

module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
    module_output.add_module(name, convert_sync_batchnorm(child))
@ssnl (Collaborator) commented on the lines above:

This recursive call needs to be fixed.
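A hedged sketch of the fix being asked for (assuming the method ends up as the classmethod shown later in this thread): recurse through cls and keep forwarding process_group, rather than calling the old free function:

for name, child in module.named_children():
    # Recurse via the classmethod so the reference survives the move into
    # the class, and propagate process_group to nested layers.
    module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))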

@iArunava (Contributor, Author):

True, I totally missed that. Also, to make sure the relevant tests pass, I just need to run the relevant test files in the test directory, right?

A Member replied:

@iArunava yes, just run python test_torch.py etc., and it should be fine.

@iArunava (Contributor, Author):

I have made the required changes, updating all the places listed here: https://github.com/pytorch/pytorch/search?q=convert_sync_batchnorm&unscoped_q=convert_sync_batchnorm

torch.nn.convert_sync_batch_norm is replaced by torch.nn.SyncBatchNorm.convert_sync_batch_norm; since it is a @classmethod, it should be accessible that way. Is that okay?

Before pushing I wanted to run the tests, but I am unable to: I keep getting a "cannot import module istuple" error from torch/_six.py, and I can't find the reason why; everything else from _six.py imports fine. Can you please help, @fmassa? I am quite sure I am missing something here.

@iArunava (Contributor, Author) commented Apr 3, 2019

I pushed without testing, sorry about that, but I was unable to get the tests working. I tried uninstalling torch and setting up develop mode as described in CONTRIBUTING.md, but the build kept failing. Can you please help me get develop mode working? Thanks! I will fix this ASAP.

@ssnl (Collaborator) commented Apr 3, 2019

It's okay. For a small change like this, you can use the CI to fix things too. I'll tag this as WIP. When you are done and this is ready for review, feel free to remove "WIP" and/or tag us :)

For the long term and bigger changes, it would still be helpful to figure out local installation.

@ssnl changed the title from "convert_sync_batch_norm to SyncBatchNorm" to "[wip] convert_sync_batch_norm to SyncBatchNorm" on Apr 3, 2019
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None

@classmethod
def convert_sync_batchnorm(cls, module, process_group=None):
A Collaborator commented on the code above:

Hey, you should add this on the module rather than the autograd function!

@iArunava (Contributor, Author):

Hey, if I understand you correctly, you are asking me to move the convert_sync_batchnorm function into the SyncBatchNorm module here:

class SyncBatchNorm(_BatchNorm):

Please correct me if I am wrong.
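For illustration, here is a sketch of what the moved classmethod could look like on the module, reconstructed from the snippets in this thread (attribute names follow _BatchNorm; this is a hedged reconstruction, not necessarily the exact merged diff):

import torch

@classmethod
def convert_sync_batchnorm(cls, module, process_group=None):
    # Defined inside class SyncBatchNorm(_BatchNorm).
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        # Rebuild the layer as SyncBatchNorm, copying its configuration.
        module_output = torch.nn.SyncBatchNorm(
            module.num_features, module.eps, module.momentum,
            module.affine, module.track_running_stats, process_group)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
        # Hand over the running statistics unchanged.
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    # Convert children recursively via cls (see the review comment above).
    for name, child in module.named_children():
        module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
    return module_output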

@iArunava (Contributor, Author) commented Apr 6, 2019

@ssnl can you please check once? I have finally built PyTorch on my machine. I have moved convert_sync_batch_norm from the autograd function to the module and verified that it executes successfully: nn.SyncBatchNorm.convert_sync_batch_norm(module) works.
Also, I ran python test/run_test.py and it shows everything is okay.

EDIT: I don't have CUDA, so the CUDA tests were skipped.
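For reference, a minimal usage sketch of the classmethod form described above (the toy model is made up for illustration; the conversion itself does not need CUDA, although SyncBatchNorm's forward pass is intended for distributed training):

import torch.nn as nn

# A toy model containing ordinary BatchNorm layers.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Recursively replaces every BatchNorm layer with SyncBatchNorm.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(sync_model)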

@ssnl (Collaborator) commented Apr 6, 2019

@pytorchbot rebase this please

@iArunava (Contributor, Author) commented Apr 7, 2019

@ssnl Please check! :)

@soumith (Contributor) left a comment:

Looks good now. Thank you!

@soumith changed the title from "[wip] convert_sync_batch_norm to SyncBatchNorm" to "convert_sync_batch_norm to SyncBatchNorm" on Apr 7, 2019
@facebook-github-bot (Contributor) left a comment:

@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@iArunava deleted the sync_batch_norm branch on April 7, 2019 at 07:21
@facebook-github-bot (Contributor):
@soumith merged this pull request in 79533ef.

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request on May 6, 2019:
Summary:
Closes pytorch#18382

Please let me know if any changes are required.
Pull Request resolved: pytorch#18787

Differential Revision: D14821147

Pulled By: soumith

fbshipit-source-id: edd98eab1b3f4151c4ae5148146435ddb2ae678d

Development

Successfully merging this pull request may close these issues.

[FR] make torch.nn.utils.convert_sync_batchnorm a classmethod of SyncBatchNorm
