Conversation

@zhangliliang (Contributor) commented Apr 13, 2019

In line 508, convert_sync_batchnorm is called recursively to convert BatchNorm layers to SyncBatchNorm, so process_group should also be passed into the recursive call.
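For reference, a minimal sketch of the proposed change (the surrounding code is paraphrased from the recursion described above, not quoted from the exact source):

# Inside SyncBatchNorm.convert_sync_batchnorm (sketch, not the exact source).
for name, child in module.named_children():
    # Before: the recursive call dropped process_group, so nested
    # SyncBatchNorm layers fell back to the default group.
    # module_output.add_module(name, cls.convert_sync_batchnorm(child))
    # After: forward process_group so every nested layer uses the same group.
    module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))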

@ssnl (Collaborator) left a comment

Nice catch!

@ezyang (Contributor) commented Apr 15, 2019

Thank you. Is there an easy way to test this case?

@zhangliliang (Contributor, Author) commented Apr 17, 2019

Thank you. Is there an easy way to test this case?

In my opinion, process_group is used to restrict the group within which means and variances are synchronized in batch norm. When a user calls convert_sync_batchnorm on an nn.Module, they would expect the function to set the same synchronization group on every SyncBatchNorm in the model.

Thus, process_group needs to be passed recursively through the module so that it also reaches SyncBatchNorm layers nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).

A test case could use this function to convert the BatchNorm layers in resnet50 to SyncBatchNorm, and then check whether every SyncBatchNorm has been assigned the proper process_group; a sketch of such a check follows.
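A minimal sketch of that check, assuming the model has already been converted with a valid process_group (the helper name check_process_group is illustrative, not part of the PR):

import torch.nn as nn

def check_process_group(model, process_group):
    # Every SyncBatchNorm in the converted model should carry the same group.
    for name, m in model.named_modules():
        if isinstance(m, nn.SyncBatchNorm):
            assert m.process_group is process_group, name

# Usage (after dist.init_process_group and process_group = dist.new_group(...)):
#   model = nn.SyncBatchNorm.convert_sync_batchnorm(models.resnet50(), process_group)
#   check_process_group(model, process_group)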

@ezyang (Contributor) commented Apr 19, 2019

Thanks. Do you think you would have time to add this test?

@ezyang self-requested a review April 19, 2019 15:43
@zhangliliang (Contributor, Author)

Thanks. Do you think you would have time to add this test?

Thanks for replying.

Do you mean adding a test case like class PackedSequenceTest(TestCase) in test/test_nn.py?

If so, I can try it.

@ssnl (Collaborator) commented Apr 20, 2019

@zhangliliang It would be adding a test method in test_distributed.py

@zhangliliang (Contributor, Author) commented Apr 21, 2019

@zhangliliang It would be adding a test method in test_distributed.py

@ssnl Got it. I will try.

@zhangliliang (Contributor, Author)

I wrote an example to test this case.
It exits at the second assert, because process_group is not set correctly by nn.SyncBatchNorm.convert_sync_batchnorm.

Do you think this is the right test case? @ssnl @ezyang
If so, I will reorganize it into test_distributed.py.

Thanks.

import copy

import torch
from torch import nn
import torch.distributed as dist
import torchvision.models as models


def convert_sync_batchnorm_fixed(module, process_group=None):
    # Same as nn.SyncBatchNorm.convert_sync_batchnorm, except that
    # process_group is forwarded in the recursive call below.
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module_output = torch.nn.SyncBatchNorm(module.num_features,
                                               module.eps, module.momentum,
                                               module.affine,
                                               module.track_running_stats,
                                               process_group)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        # Pass process_group down so BatchNorm layers nested in sub-modules
        # are converted with the same synchronization group.
        module_output.add_module(name, convert_sync_batchnorm_fixed(child, process_group))
    del module
    return module_output


# Single-process setup, just enough to create a process group.
world_size = 1
process_ids = 0

dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:1234",
                        world_size=world_size, rank=process_ids)

process_group = torch.distributed.new_group([process_ids])

# Convert the same resnet50 with the original helper and with the fixed one.
res50_model = models.resnet50()
res50_model_sync_ori = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(res50_model), process_group)
res50_model_sync_fixed = convert_sync_batchnorm_fixed(copy.deepcopy(res50_model), process_group)

# Inspect a SyncBatchNorm nested inside a sub-module (layer1[0].bn1).
process_group_sync_ori = res50_model_sync_ori.layer1[0].bn1.process_group
process_group_sync_fixed = res50_model_sync_fixed.layer1[0].bn1.process_group

# The fixed helper propagates the group; the original does not,
# so the second assert fails before the fix.
assert process_group_sync_fixed == process_group
assert process_group_sync_ori == process_group

@ezyang (Contributor) commented Apr 22, 2019

The test looks good to me. If you add a little explanation, like what you wrote in your comment, to the code, that would be perfect. 👍

@ssnl (Collaborator) commented Apr 22, 2019

@ezyang Do we always have torchvision installed in CI? If not, the test should probably use a custom network.

@ezyang (Contributor) commented Apr 22, 2019

Yeah, we install torchvision, and there are existing tests which use it. Actually, this is probably changing soon cc @fmassa, but for now it shouldn't be a problem.

@ssnl (Collaborator) commented May 6, 2019

Oh no, we didn't include this in 1.1!

@ssnl (Collaborator) commented May 6, 2019

@pytorchbot rebase this please

@ssnl (Collaborator) commented May 6, 2019

@pytorchbot merge this please

@pytorchbot added the merge-this-please label May 6, 2019
@zhangliliang (Contributor, Author) commented May 6, 2019

Sorry for the late reply; I have been occupied with other things these days.
I added a test case for this PR and tested it on my machine.
Please check whether the code is written appropriately.

@ssnl (Collaborator) commented May 6, 2019

@zhangliliang No worries. It's not your fault. Thanks for your contribution!

@zhangliliang (Contributor, Author)

@zhangliliang No worries. It's not your fault. Thanks for your contribution!

Thanks!

@zhangliliang (Contributor, Author) commented May 6, 2019

@ssnl
It seems that some checks failed because torchvision was not successfully imported.
Should I handle this by removing the torchvision dependency from the test case?

@ezyang (Contributor) commented May 6, 2019

I haven't looked at the PR, but we have tests which are conditionally enabled depending on whether torchvision is available; go take a look at them and copy the pattern.
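A rough sketch of that conditional-skip pattern (the helper and test names here are illustrative; the actual helpers in the test suite may be named differently):

import unittest

try:
    import torchvision  # noqa: F401
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class SyncBatchNormTest(unittest.TestCase):
    @skipIfNoTorchVision
    def test_convert_sync_batchnorm_process_group(self):
        # Convert resnet50 and verify every SyncBatchNorm carries the
        # expected process_group, as in the script above.
        ...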

@ezyang (Contributor) commented May 6, 2019

It looks like there were more changes since the merge request; I will wait for tests.

@zhangliliang (Contributor, Author)

It looks like there were more changes since the merge request; I will wait for tests.

@ezyang
It seems that one error (binary_macos_libtorch_2.7_cpu_build) still exists, but I don't know how to deal with it. Could you give me some ideas?

@zhangliliang (Contributor, Author)

All checks have passed now.

@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@ezyang merged this pull request in f7a7868.
