add process_group in convert_sync_batchnorm (#19240)

Summary:
At line 508, convert_sync_batchnorm is called recursively to convert each BatchNorm layer to a SyncBatchNorm, so the process_group argument must also be passed through the recursive call.
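
A minimal sketch of the behavior this change fixes (it mirrors the new test below; assumes torch.distributed.init_process_group has already been called, which is not part of this diff):

import copy
import torch
import torch.nn as nn
import torchvision

# Assumes init_process_group() has already run in this process.
process_group = torch.distributed.new_group([0])

model = torchvision.models.resnet50()
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model), process_group)

# Before this fix, the recursive call dropped process_group, so BatchNorm
# layers nested inside sub-modules were converted with process_group=None.
assert sync_model.layer1[0].bn1.process_group is process_group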
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19240

Differential Revision: D15240318

Pulled By: ezyang

fbshipit-source-id: 0fc9e856392824814991e5e9e8f9513d57f311af
Zhang Liliang
2019-05-07 06:45:12 -07:00
committed by Facebook Github Bot
parent 2356fac9a5
commit f7a7868820
2 changed files with 23 additions and 1 deletion


@@ -20,6 +20,15 @@ from common_utils import TestCase, run_tests
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
try:
import torchvision
HAS_TORCHVISION = True
except ImportError:
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
BACKEND = os.environ["BACKEND"]
TEMP_DIR = os.environ["TEMP_DIR"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
@@ -1528,6 +1537,19 @@ class _DistTestBase(object):
gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
@skipIfNoTorchVision
def test_SyncBatchNorm_process_group(self):
# When using `convert_sync_batchnorm` to convert an `nn.Module`, the `process_group`
# needs to be passed through each recursive call so that it reaches `SyncBatchNorm`
# layers nested in sub-modules or sub-sub-modules (e.g. resnet50 in torchvision.models).
process_ids = 0
process_group = torch.distributed.new_group([process_ids])
res50_model = torchvision.models.resnet50()
res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(res50_model), process_group)
process_group_sync = res50_model_sync.layer1[0].bn1.process_group
self.assertEqual(process_group_sync, process_group)
if BACKEND == "gloo" or BACKEND == "nccl":
WORLD_SIZE = os.environ["WORLD_SIZE"]


@@ -505,6 +505,6 @@ class SyncBatchNorm(_BatchNorm):
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, cls.convert_sync_batchnorm(child))
module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
del module
return module_output
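
For reference, a hedged end-to-end usage sketch of the fixed API (not part of this diff; assumes a distributed environment already initialized via init_process_group, with this process owning GPU 0):

import torch
import torch.nn as nn
import torchvision

# Assumes torch.distributed.init_process_group(...) has already run.
group = torch.distributed.new_group(ranks=[0])

model = torchvision.models.resnet50().cuda()
# process_group is now threaded through to every nested BatchNorm layer.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, group)
ddp_model = nn.parallel.DistributedDataParallel(
    model, device_ids=[0], process_group=group)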