mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[nccl symm mem] don't use arg for mempool, correctly use symmetric registration in hooks (#161238)
Per title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161238
Approved by: https://github.com/kwen2501, https://github.com/syed-ahmed
This commit is contained in:
committed by
PyTorch MergeBot
parent
74280d0913
commit
726dce3c94
@ -3099,7 +3099,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||
self._run_invalid_nccl_blocking_wait_env("4294967295")
|
||||
|
||||
|
||||
class NcclRegistrationTest(MultiProcessTestCase):
|
||||
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
|
||||
@ -3191,7 +3191,7 @@ class NcclRegistrationTest(MultiProcessTestCase):
|
||||
|
||||
# Use NCCL memory allocator
|
||||
# enable symmetric memory usage in NCCL
|
||||
pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True)
|
||||
pool = torch.cuda.MemPool(backend.mem_allocator)
|
||||
|
||||
# allocate memory with ncclMemAlloc
|
||||
# note: symmetric kernels are not available for dtypes like torch.int64
|
||||
@ -3201,10 +3201,16 @@ class NcclRegistrationTest(MultiProcessTestCase):
|
||||
)
|
||||
|
||||
# register buffers to NCCL
|
||||
backend.register_mem_pool(pool)
|
||||
backend.register_mem_pool(pool, symm=True)
|
||||
|
||||
# allreduce now should use NVIDIA Switches
|
||||
pg.allreduce(tensor).wait()
|
||||
# check that further allocations are also registered
|
||||
with torch.cuda.use_mem_pool(pool):
|
||||
tensor = torch.arange(
|
||||
1024 * 1024 * 2, device=device, dtype=torch.float32
|
||||
)
|
||||
pg.allreduce(tensor).wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
|
||||
# de-register buffers from NCCL
|
||||
@ -3217,7 +3223,7 @@ class NcclRegistrationTest(MultiProcessTestCase):
|
||||
nccl_debug_file_content = f.read()
|
||||
# if buffers were registered and symmetric kernels ran, NCCL_DEBUG
|
||||
# should show successful registration in debug output
|
||||
self.assertRegex(nccl_debug_file_content, "[Symmetric]")
|
||||
self.assertRegex(nccl_debug_file_content, "Symmetric")
|
||||
|
||||
|
||||
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
|
||||
|
Reference in New Issue
Block a user