Enables NCCL symmetric memory kernels through mempool registration (#155134)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155134
Approved by: https://github.com/kwen2501

Co-authored-by: Ke Wen <kw2501@meta.com>
Syed Tousif Ahmed
2025-06-21 13:08:13 -07:00
committed by PyTorch MergeBot
parent 9e132b770e
commit f70c80105e
9 changed files with 146 additions and 15 deletions
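
For orientation before the diff: the sketch below reconstructs the user-facing flow that the new test exercises. The MemPool(symm_mem=True), mem_allocator, register_mem_pool, and deregister_mem_pool names are taken from the diff itself; the torchrun-style RANK/WORLD_SIZE handling and the destroy_process_group cleanup are illustrative assumptions, and the hardware gating (4 GPUs, NCCL >= 2.27, multicast support) mirrors the test's decorators.

import os

import torch
import torch.distributed as dist


def main() -> None:
    # Assumes a torchrun-style launcher that sets RANK / WORLD_SIZE
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    pg = dist.distributed_c10d._get_default_group()
    backend = pg._get_backend(device)

    # Pool backed by NCCL's allocator (ncclMemAlloc); symm_mem=True marks the
    # pool's allocations as candidates for NCCL window registration
    pool = torch.cuda.MemPool(backend.mem_allocator, symm_mem=True)
    with torch.cuda.use_mem_pool(pool):
        tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32)

    # Window-register the pool's buffers with the NCCL communicator; the
    # allreduce below is then eligible for symmetric memory kernels
    backend.register_mem_pool(pool)
    pg.allreduce(tensor).wait()
    torch.cuda.synchronize(device=device)
    backend.deregister_mem_pool(pool)

    del tensor, pool
    dist.destroy_process_group()


if __name__ == "__main__":
    main()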

test/distributed/test_c10d_nccl.py

@@ -3084,7 +3084,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
 
-class NcclUserBufferRegistrationTest(MultiProcessTestCase):
+class NcclRegistrationTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@@ -3095,7 +3095,7 @@ class NcclUserBufferRegistrationTest(MultiProcessTestCase):
         os.environ["NCCL_DEBUG"] = "INFO"
         os.environ["NCCL_DEBUG_SUBSYS"] = "NVLS"
         if torch.cuda.nccl.version() >= (2, 24, 3):
-            os.environ["NCCL_DEBUG_SUBSYS"] = "REG"
+            os.environ["NCCL_DEBUG_SUBSYS"] = "REG,TUNING"
         os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
         self._spawn_processes()
 
@@ -3151,6 +3151,48 @@ class NcclUserBufferRegistrationTest(MultiProcessTestCase):
         else:
             self.assertRegex(nccl_debug_file_content, "local-registered")
+
+    @requires_nccl()
+    @requires_nccl_version((2, 27), "Need NCCL 2.27 for window registration")
+    @skip_if_lt_x_gpu(4)
+    @requires_multicast_support()
+    def test_nccl_window_registration(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+        device = torch.device(f"cuda:{self.rank}")
+        torch.cuda.set_device(self.rank)
+        pg = c10d.distributed_c10d._get_default_group()
+        backend = pg._get_backend(torch.device(device))
+
+        # Use NCCL memory allocator
+        # enable symmetric memory usage in NCCL
+        pool = torch.cuda.MemPool(backend.mem_allocator, symm_mem=True)
+
+        # allocate memory with ncclMemAlloc
+        # note: symmetric kernels are not available for dtypes like torch.int64
+        with torch.cuda.use_mem_pool(pool):
+            tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32)
+
+        # register buffers to NCCL
+        backend.register_mem_pool(pool)
+
+        # allreduce now should use NVIDIA Switches
+        pg.allreduce(tensor).wait()
+        torch.cuda.synchronize(device=device)
+
+        # de-register buffers from NCCL
+        backend.deregister_mem_pool(pool)
+
+        # clean up memory
+        del tensor, pool
+
+        with open(os.environ["NCCL_DEBUG_FILE"]) as f:
+            nccl_debug_file_content = f.read()
+            # if buffers were registered and symmetric kernels ran, NCCL_DEBUG
+            # should show successful registration in debug output
+            self.assertRegex(nccl_debug_file_content, r"\[Symmetric\]")
 
 
 class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
     @property
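
To confirm the symmetric path the same way the test does, NCCL's debug logging can be pointed at a file before the process group is created. The NCCL_DEBUG, NCCL_DEBUG_SUBSYS, and NCCL_DEBUG_FILE values below are taken from the test's setUp; the helper name and the tempfile usage are illustrative assumptions, not part of the PR.

import os
import tempfile


def enable_nccl_registration_logging() -> str:
    # Hypothetical helper mirroring the test's setUp; it must run before the
    # NCCL communicator is created for the settings to take effect.
    debug_file = tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False)
    os.environ["NCCL_DEBUG"] = "INFO"
    # REG covers buffer/window registration, TUNING covers algorithm choice;
    # the test uses "REG,TUNING" on NCCL >= 2.24.3 and "NVLS" on older versions
    os.environ["NCCL_DEBUG_SUBSYS"] = "REG,TUNING"
    os.environ["NCCL_DEBUG_FILE"] = debug_file.name
    return debug_file.name

After the run, the log file can be searched for "[Symmetric]", as the new test asserts, to verify that window registration succeeded.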