[c10d] Add more tests to prevent extra context (#154174)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

Loop a bunch of sync ops and see if any of them creates extra context.
Requires nvml to check number of processes resident on a device.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154174
Approved by: https://github.com/atalman
This commit is contained in:
Ke Wen
2025-05-22 17:54:44 -07:00
committed by PyTorch MergeBot
parent ba5d45d22e
commit 25149cd173

View File

@ -684,6 +684,56 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
except ModuleNotFoundError:
self._helper_test_extra_cuda_context_by_memory()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_extra_cuda_context_sync_ops(self):
# Loop a bunch of sync ops and see if any of them creates extra context.
# Requires nvml to check number of processes resident on a device.
try:
import pynvml
pynvml.nvmlInit()
except Exception:
self.skipTest("pynvml not available")
# Check if non-0 ranks would create extra CUDA context on device 0
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device(f"cuda:{self.rank:d}")
c10d.init_process_group(
backend="nccl",
store=store,
rank=self.rank,
world_size=self.world_size,
device_id=device,
)
x = torch.empty((1,), device=device)
y = torch.empty((self.world_size,), device=device)
c10d.all_reduce(x)
c10d.reduce(x, dst=0)
c10d.broadcast(x, src=0)
c10d.all_gather_into_tensor(y, x)
c10d.reduce_scatter_tensor(x, y)
c10d.barrier()
# Wait a bit for remote processes to touch my device
if self.rank == 0:
time.sleep(5)
handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank)
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
nprocs = len(processes)
# Don't exit till rank 0 is done with the nvml detection
c10d.barrier()
c10d.destroy_process_group()
self.assertLessEqual(
nprocs,
1,
f"Found {nprocs} processes creating contexts on {device}, expecting 1 at most",
)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_destruct_before_terminate_pg(self):