add process_group in convert_sync_batchnorm (#19240)
Summary: At line 508, convert_sync_batchnorm is called recursively to convert nested BatchNorm layers to SyncBatchNorm, so the process_group argument should also be passed into that recursive call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19240
Differential Revision: D15240318
Pulled By: ezyang
fbshipit-source-id: 0fc9e856392824814991e5e9e8f9513d57f311af
Committed by: Facebook Github Bot
Parent: 2356fac9a5
Commit: f7a7868820
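For context, here is a minimal usage sketch of the behavior this commit fixes. It assumes a default process group has already been initialized via torch.distributed.init_process_group and that torchvision is installed; the rank list passed to new_group is illustrative only.

    import copy
    import torch
    import torch.nn as nn
    import torchvision

    # Assumes torch.distributed.init_process_group(...) has already been called.
    # Build an illustrative process group containing only rank 0.
    process_group = torch.distributed.new_group([0])

    model = torchvision.models.resnet50()
    # convert_sync_batchnorm replaces every BatchNorm layer, including those nested
    # deep inside sub-modules; with this fix the process_group is forwarded to each
    # converted layer instead of being dropped after the top level.
    sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(model), process_group)

    # Nested SyncBatchNorm layers now reference the supplied process group.
    assert sync_model.layer1[0].bn1.process_group is process_group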
@@ -20,6 +20,15 @@ from common_utils import TestCase, run_tests
 from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
 from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
 
+try:
+    import torchvision
+    HAS_TORCHVISION = True
+except ImportError:
+    HAS_TORCHVISION = False
+
+
+skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
+
 BACKEND = os.environ["BACKEND"]
 TEMP_DIR = os.environ["TEMP_DIR"]
 INIT_METHOD = os.getenv("INIT_METHOD", "env://")
@@ -1528,6 +1537,19 @@ class _DistTestBase(object):
         gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
         self._test_DistributedDataParallel_SyncBatchNorm(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
 
+    @skipIfNoTorchVision
+    def test_SyncBatchNorm_process_group(self):
+        # When using `convert_sync_batchnorm` to convert an `nn.Module`, the
+        # `process_group` needs to be passed recursively when the `SyncBatchNorm`
+        # is nested in a sub-module or sub-sub-module (e.g. resnet50 in torchvision.models).
+
+        process_ids = 0
+        process_group = torch.distributed.new_group([process_ids])
+        res50_model = torchvision.models.resnet50()
+        res50_model_sync = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(res50_model), process_group)
+        process_group_sync = res50_model_sync.layer1[0].bn1.process_group
+        self.assertEqual(process_group_sync, process_group)
+
 
 if BACKEND == "gloo" or BACKEND == "nccl":
     WORLD_SIZE = os.environ["WORLD_SIZE"]
@@ -505,6 +505,6 @@ class SyncBatchNorm(_BatchNorm):
             module_output.running_var = module.running_var
             module_output.num_batches_tracked = module.num_batches_tracked
         for name, child in module.named_children():
-            module_output.add_module(name, cls.convert_sync_batchnorm(child))
+            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
         del module
         return module_output
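To see why the one-line change above matters, here is a simplified sketch of the recursive conversion pattern; convert_recursively is a hypothetical stand-in, not PyTorch's actual implementation. Without forwarding process_group in the recursive call, every BatchNorm below the top-level module would be converted with process_group=None.

    import torch.nn as nn

    def convert_recursively(module, process_group=None):
        # Hypothetical stand-in for nn.SyncBatchNorm.convert_sync_batchnorm.
        converted = module
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            # Re-create the layer as SyncBatchNorm, carrying over its settings
            # and the supplied process group.
            converted = nn.SyncBatchNorm(module.num_features, module.eps, module.momentum,
                                         module.affine, module.track_running_stats,
                                         process_group)
        for name, child in module.named_children():
            # The fix in this commit: pass process_group along in the recursion,
            # so nested BatchNorm layers receive the same group.
            converted.add_module(name, convert_recursively(child, process_group))
        return converted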