mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[c10d] Add more tests to prevent extra context (#154174)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

Loop a bunch of sync ops and see if any of them creates extra context. Requires nvml to check the number of processes resident on a device.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154174
Approved by: https://github.com/atalman
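The check hinges on NVML's per-device process accounting: every process that has created a CUDA context on a GPU appears in that device's compute-process list, so an extra context shows up as an extra process. A minimal standalone sketch of that query (not part of the diff; the device index 0 and the print format are assumptions) could look like:

```python
# Sketch: count compute processes holding a CUDA context on device 0 via NVML.
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # assumed device index
procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
print(f"{len(procs)} compute process(es) hold a CUDA context on device 0")
pynvml.nvmlShutdown()
```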
@@ -684,6 +684,56 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
         except ModuleNotFoundError:
             self._helper_test_extra_cuda_context_by_memory()
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_extra_cuda_context_sync_ops(self):
+        # Loop a bunch of sync ops and see if any of them creates extra context.
+        # Requires nvml to check number of processes resident on a device.
+        try:
+            import pynvml
+
+            pynvml.nvmlInit()
+        except Exception:
+            self.skipTest("pynvml not available")
+
+        # Check if non-0 ranks would create extra CUDA context on device 0
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device(f"cuda:{self.rank:d}")
+        c10d.init_process_group(
+            backend="nccl",
+            store=store,
+            rank=self.rank,
+            world_size=self.world_size,
+            device_id=device,
+        )
+
+        x = torch.empty((1,), device=device)
+        y = torch.empty((self.world_size,), device=device)
+
+        c10d.all_reduce(x)
+        c10d.reduce(x, dst=0)
+        c10d.broadcast(x, src=0)
+        c10d.all_gather_into_tensor(y, x)
+        c10d.reduce_scatter_tensor(x, y)
+        c10d.barrier()
+
+        # Wait a bit for remote processes to touch my device
+        if self.rank == 0:
+            time.sleep(5)
+
+        handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank)
+        processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+        nprocs = len(processes)
+
+        # Don't exit till rank 0 is done with the nvml detection
+        c10d.barrier()
+        c10d.destroy_process_group()
+        self.assertLessEqual(
+            nprocs,
+            1,
+            f"Found {nprocs} processes creating contexts on {device}, expecting 1 at most",
+        )
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_destruct_before_terminate_pg(self):