mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[c10d] Add more tests to prevent extra context (#154174)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): Loop a bunch of sync ops and see if any of them creates extra context. Requires nvml to check number of processes resident on a device. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154174 Approved by: https://github.com/atalman
This commit is contained in:
@@ -684,6 +684,56 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
except ModuleNotFoundError:
|
||||
self._helper_test_extra_cuda_context_by_memory()
|
||||
|
||||
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_extra_cuda_context_sync_ops(self):
        """Run a set of synchronous c10d collectives and verify that none of
        them leaves an extra CUDA context on this rank's device.

        The check counts compute processes resident on the GPU via NVML:
        at most one process (this rank's own) should hold a context on
        ``cuda:{self.rank}``; a second process would indicate that a remote
        rank spawned a context here. Skips when pynvml is unavailable or
        NVML cannot be initialized.
        """
        # Loop a bunch of sync ops and see if any of them creates extra context.
        # Requires nvml to check number of processes resident on a device.
        try:
            import pynvml

            pynvml.nvmlInit()
        except Exception:
            # Broad catch is deliberate: pynvml may be missing entirely
            # (ImportError) or nvmlInit may raise an NVML-specific error
            # (e.g. driver mismatch); either way the check cannot run.
            self.skipTest("pynvml not available")

        # Check if non-0 ranks would create extra CUDA context on device 0
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank:d}")
        # NOTE(review): passing device_id binds the PG to this device at init
        # time — presumably so communicators are set up eagerly; confirm
        # against the other tests in this class.
        c10d.init_process_group(
            backend="nccl",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
            device_id=device,
        )

        x = torch.empty((1,), device=device)
        y = torch.empty((self.world_size,), device=device)

        # The sync collectives under test; each of these blocks the caller
        # and must complete without touching another rank's device context.
        c10d.all_reduce(x)
        c10d.reduce(x, dst=0)
        c10d.broadcast(x, src=0)
        c10d.all_gather_into_tensor(y, x)
        c10d.reduce_scatter_tensor(x, y)
        c10d.barrier()

        # Wait a bit for remote processes to touch my device
        if self.rank == 0:
            time.sleep(5)

        # Count compute processes currently holding a context on this
        # rank's GPU (every rank performs its own local check).
        handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank)
        processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
        nprocs = len(processes)

        # Don't exit till rank 0 is done with the nvml detection
        c10d.barrier()
        c10d.destroy_process_group()
        self.assertLessEqual(
            nprocs,
            1,
            f"Found {nprocs} processes creating contexts on {device}, expecting 1 at most",
        )
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_destruct_before_terminate_pg(self):
|
||||
|
Reference in New Issue
Block a user