[nccl symm mem] don't use arg for mempool, correctly use symmetric registration in hooks (#161238)

Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161238
Approved by: https://github.com/kwen2501, https://github.com/syed-ahmed
This commit is contained in:
Natalia Gimelshein
2025-08-25 03:09:32 +00:00
committed by PyTorch MergeBot
parent 74280d0913
commit 726dce3c94
9 changed files with 78 additions and 69 deletions

View File

@ -3099,7 +3099,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
self._run_invalid_nccl_blocking_wait_env("4294967295")
class NcclRegistrationTest(MultiProcessTestCase):
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@ -3191,7 +3191,7 @@ class NcclRegistrationTest(MultiProcessTestCase):
# Use NCCL memory allocator
# enable symmetric memory usage in NCCL
pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True)
pool = torch.cuda.MemPool(backend.mem_allocator)
# allocate memory with ncclMemAlloc
# note: symmetric kernels are not available for dtypes like torch.int64
@ -3201,10 +3201,16 @@ class NcclRegistrationTest(MultiProcessTestCase):
)
# register buffers to NCCL
backend.register_mem_pool(pool)
backend.register_mem_pool(pool, symm=True)
# allreduce now should use NVIDIA Switches
pg.allreduce(tensor).wait()
# check that further allocations are also registered
with torch.cuda.use_mem_pool(pool):
tensor = torch.arange(
1024 * 1024 * 2, device=device, dtype=torch.float32
)
pg.allreduce(tensor).wait()
torch.cuda.synchronize(device=device)
# de-register buffers from NCCL
@ -3217,7 +3223,7 @@ class NcclRegistrationTest(MultiProcessTestCase):
nccl_debug_file_content = f.read()
# if buffers were registered and symmetric kernels ran, NCCL_DEBUG
# should show successful registration in debug output
self.assertRegex(nccl_debug_file_content, "[Symmetric]")
self.assertRegex(nccl_debug_file_content, "Symmetric")
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):