mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[c10d] Fix test test_nccl_user_buffer_registration (#160497)
Fixed `test_nccl_user_buffer_registration ` due to https://github.com/pytorch/pytorch/pull/160145, somehow CI didn't capture it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160497 Approved by: https://github.com/ngimel
This commit is contained in:
@ -3127,10 +3127,14 @@ class NcclRegistrationTest(MultiProcessTestCase):
|
||||
@requires_multicast_support()
|
||||
def test_nccl_user_buffer_registration(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
c10d.init_process_group(
|
||||
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
|
||||
)
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
c10d.init_process_group(
|
||||
backend="nccl",
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
store=store,
|
||||
device_id=device,
|
||||
)
|
||||
torch.cuda.set_device(self.rank)
|
||||
pg = c10d.distributed_c10d._get_default_group()
|
||||
backend = pg._get_backend(torch.device(device))
|
||||
|
Reference in New Issue
Block a user