mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
py 3.9 has spawn. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146705 Approved by: https://github.com/colesbury
61 lines
1.7 KiB
Python
61 lines
1.7 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
|
|
# Turn TF32 off for CUDA matmuls before any test code is imported or run.
# NOTE(review): presumably so results are compared at full fp32 precision — confirm.
torch.backends.cuda.matmul.allow_tf32 = False

# Nothing in this file can run without distributed support: say so on stderr
# and leave with exit status 0.
if not dist.is_available():
    print(
        "Distributed not available, skipping tests",
        file=sys.stderr,
    )
    sys.exit(0)
|
|
|
|
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
|
from torch.testing._internal.distributed.distributed_test import (
|
|
DistributedTest,
|
|
TestDistBackend,
|
|
)
|
|
|
|
|
|
if TEST_WITH_DEV_DBG_ASAN:
    # Dev-ASAN runs are skipped wholesale: report the reason on stderr, then
    # leave with exit status 0 before any test class is defined.
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
|
|
|
|
# Backend names for which the test class below is defined.
_allowed_backends = ("gloo", "nccl", "ucc")

# All three variables are required. Collect every one that is absent so the
# error can name them instead of leaving the user to work out which was
# forgotten.
_missing_env_vars = [
    name for name in ("BACKEND", "WORLD_SIZE", "TEMP_DIR") if name not in os.environ
]
if _missing_env_vars:
    # TODO can we actually have `run_tests.py` emit the complete instructions when it prints a repro command?
    raise RuntimeError(
        "Missing expected env vars for `test_distributed_spawn.py`: "
        f"{', '.join(_missing_env_vars)}. Please ensure to specify the following:\n"
        f"'BACKEND' = one of {_allowed_backends}\n"
        "'WORLD_SIZE' = int >= 2\n"
        "'TEMP_DIR' specifying a directory containing a barrier file named 'barrier'.\n\n"
        f"e.g.\ntouch /tmp/barrier && TEMP_DIR=/tmp BACKEND='nccl' WORLD_SIZE=2 python {__file__}",
    )

# Indexing directly is safe here: the check above guarantees the key exists.
BACKEND = os.environ["BACKEND"]
|
|
|
|
if BACKEND in _allowed_backends:

    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
        """Test case combining ``TestDistBackend`` with ``DistributedTest._DistTestBase``.

        Only defined when ``BACKEND`` is one of ``_allowed_backends``.
        """

        def setUp(self):
            """Run the base set-up, spawn the processes, and apply the cuDNN flags.

            The cuDNN flags context (``enabled=True, allow_tf32=False``) is
            entered here and exited through ``addCleanup`` when the test ends.
            """
            super().setUp()
            self._spawn_processes()
            # Hold on to the context manager and pair ``__enter__`` with an
            # explicit ``__exit__``. Entering a throwaway object never runs the
            # exit step deliberately: the flags would either outlive the test
            # or be reverted whenever the temporary happens to be finalized.
            cudnn_flags = torch.backends.cudnn.flags(enabled=True, allow_tf32=False)
            cudnn_flags.__enter__()
            self.addCleanup(cudnn_flags.__exit__, None, None, None)

else:
    # No test class is defined for an unrecognized backend, so nothing will run.
    print(f"Invalid backend {BACKEND}. Tests will not be run!")
|
|
|
|
|
|
# Script entry point: when executed directly, hand control to `run_tests`.
if __name__ == "__main__":
    run_tests()
|