diff --git a/test/test_distributed.py b/test/test_distributed.py
index 2665459e4191..e25da8591a73 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -20,6 +20,15 @@ from common_utils import TestCase, run_tests
 from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
 from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
 
+try:
+    import torchvision
+    HAS_TORCHVISION = True
+except ImportError:
+    HAS_TORCHVISION = False
+
+
+skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
+
 BACKEND = os.environ["BACKEND"]
 TEMP_DIR = os.environ["TEMP_DIR"]
 INIT_METHOD = os.getenv("INIT_METHOD", "env://")
@@ -1528,6 +1537,19 @@ class _DistTestBase(object):
         gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
         self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
 
+    @skipIfNoTorchVision
+    def test_SyncBatchNorm_process_group(self):
+        # `convert_sync_batchnorm` must recursively pass `process_group` down to
+        # every child module, since `SyncBatchNorm` layers may be nested several
+        # levels deep in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).
+
+        process_ids = 0
+        process_group = torch.distributed.new_group([process_ids])
+        res50_model = torchvision.models.resnet50()
+        res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(res50_model), process_group)
+        process_group_sync = res50_model_sync.layer1[0].bn1.process_group
+        self.assertEqual(process_group_sync, process_group)
+
 
 if BACKEND == "gloo" or BACKEND == "nccl":
     WORLD_SIZE = os.environ["WORLD_SIZE"]
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index cf360bfe90d9..0428ffeeeac3 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -505,6 +505,6 @@ class SyncBatchNorm(_BatchNorm):
         module_output.running_var = module.running_var
         module_output.num_batches_tracked = module.num_batches_tracked
         for name, child in module.named_children():
-            module_output.add_module(name, cls.convert_sync_batchnorm(child))
+            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
         del module
         return module_output
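
For reference, here is a minimal standalone sketch of the behavior the patch fixes: before this change, `convert_sync_batchnorm(child)` dropped `process_group` on the recursive call, so only top-level batch-norm layers received the group. This is not part of the patch; the toy nested `Sequential` model and the single-process gloo/file init are illustrative stand-ins for the resnet50 used in the actual test.

    import copy
    import os
    import tempfile

    import torch.distributed as dist
    import torch.nn as nn

    # Single-process setup so that new_group() is usable; in a real job the
    # launcher calls init_process_group once per rank.
    init_file = os.path.join(tempfile.mkdtemp(), "init")
    dist.init_process_group(
        backend="gloo", init_method="file://" + init_file, rank=0, world_size=1
    )
    process_group = dist.new_group([0])

    # A BatchNorm layer nested one level deep, standing in for the nested
    # bn layers of torchvision.models.resnet50 in the test above.
    model = nn.Sequential(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
    sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(
        copy.deepcopy(model), process_group
    )

    # With the fix, the nested layer is converted AND carries the group along;
    # without it, process_group here would be None.
    assert isinstance(sync_model[0][1], nn.SyncBatchNorm)
    assert sync_model[0][1].process_group is process_group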