Rebase and re-land thread PG (#88795)
The previous PR (https://github.com/pytorch/pytorch/pull/88627) was reverted due to a failed check. After rebasing and re-running, all checks passed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88795
Approved by: https://github.com/huydhn, https://github.com/wanchaol
committed by: PyTorch MergeBot
parent: 35093fc1ab
commit: ee05f47bdd
@@ -2,10 +2,10 @@ import faulthandler
 import logging
 import multiprocessing
 import os
+import subprocess
 import sys
 import tempfile
 import threading
-import subprocess
 import time
 import traceback
 import types
@@ -14,11 +14,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from functools import (
-    partial,
-    reduce,
-    wraps
-)
+from functools import partial, reduce, wraps
 from io import StringIO
 from typing import NamedTuple, Optional, Union
 
@@ -26,16 +22,17 @@ import torch
 import torch.cuda.nccl
 import torch.distributed as c10d
 from torch.testing._internal.common_utils import (
-    TestCase,
-    TEST_WITH_ROCM,
-    TEST_WITH_TSAN,
     FILE_SCHEMA,
     find_free_port,
-    retry_on_connect_failures,
     IS_SANDCASTLE,
-    sandcastle_skip_if,
+    retry_on_connect_failures,
     sandcastle_skip,
+    sandcastle_skip_if,
+    TEST_WITH_ROCM,
+    TEST_WITH_TSAN,
+    TestCase,
 )
+from torch.testing._internal.distributed.multi_threaded_pg import run_with_threaded_pg
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -67,11 +64,10 @@ TEST_SKIPS = {
     "generic": TestSkip(
         86, "Test skipped at subprocess level, look at subprocess log for skip reason"
     ),
-    "importerror": TestSkip(
-        88, "Test skipped due to missing import"
-    ),
+    "importerror": TestSkip(88, "Test skipped due to missing import"),
 }
 
+
 @dataclass
 class DistTestCases:
     # Backends that do not support a specific collective
@@ -93,6 +89,7 @@ class DistTestCases:
 def skip_if_no_gpu(func):
     """Skips if the world size exceeds the number of GPUs, ensuring that if the
     test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
+
     @wraps(func)
     def wrapper(*args, **kwargs):
         if not torch.cuda.is_available():
@@ -116,6 +113,7 @@ def skip_if_small_worldsize(func):
 
     return wrapper
 
+
 def skip_if_odd_worldsize(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
@@ -126,6 +124,7 @@ def skip_if_odd_worldsize(func):
 
     return wrapper
 
+
 def require_n_gpus_for_nccl_backend(n, backend):
     def decorator(func):
         @wraps(func)
@@ -139,12 +138,17 @@ def require_n_gpus_for_nccl_backend(n, backend):
 
     return decorator
 
+
 def import_transformers_or_skip():
     def decorator(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
             try:
-                from transformers import BertConfig, AutoModelForMaskedLM  # noqa: Unused
+                from transformers import (  # noqa: Unused
+                    AutoModelForMaskedLM,
+                    BertConfig,
+                )
+
                 return func(*args, **kwargs)
             except ImportError:
                 sys.exit(TEST_SKIPS["importerror"].exit_code)
@@ -153,6 +157,7 @@ def import_transformers_or_skip():
 
     return decorator
 
+
 def skip_if_lt_x_gpu(x):
     def decorator(func):
         @wraps(func)
@@ -191,10 +196,13 @@ def verify_ddp_error_logged(model_DDP, err_substr):
     logging_err = ddp_logging_data["error"]
     # Remove C++ stacktrace if needed.
     actual = (
-        err_substr if err_substr.find("\nException raised from ") == -1
+        err_substr
+        if err_substr.find("\nException raised from ") == -1
         else err_substr.split("\nException raised from ")[0]
     )
-    assert actual in logging_err, f"Did not find expected {actual} in ddp logging data error: {logging_err}"
+    assert (
+        actual in logging_err
+    ), f"Did not find expected {actual} in ddp logging data error: {logging_err}"
 
 
 def with_nccl_blocking_wait(func):
@@ -319,7 +327,7 @@ def skip_if_rocm(func):
 
 def skip_if_win32():
     return sandcastle_skip_if(
-        sys.platform == 'win32',
+        sys.platform == "win32",
         "This unit test case is not supportted on Windows platform",
     )
 
@@ -352,13 +360,14 @@ if TEST_WITH_TSAN:
     # TSAN runs much slower.
     TIMEOUT_DEFAULT = 500
 else:
-    TIMEOUT_DEFAULT = int(os.getenv('DISTRIBUTED_TESTS_DEFAULT_TIMEOUT', '300'))
+    TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300"))
 TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}
 
 # https://github.com/pytorch/pytorch/issues/75665
 if TEST_WITH_ROCM:
     TIMEOUT_OVERRIDE["test_join_kwargs"] = 200
 
+
 def create_device(interface=None):
     if sys.platform == "win32" or interface is None:
         return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1")
@@ -449,9 +458,7 @@ def init_multigpu_helper(world_size: int, backend: str):
     if world_size > nGPUs:
         nGPUs_per_process = nGPUs // world_size
     rank_to_GPU = {
-        i: list(
-            visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process]
-        )
+        i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process])
         for i in range(world_size)
     }
     return rank_to_GPU
@@ -482,6 +489,9 @@ def cleanup_temp_dir() -> None:
         tmp_dir.cleanup()
 
 
+# Most tests operate with this worldsize
+DEFAULT_WORLD_SIZE = 4
+
 # [How does MultiProcessTestCase work?]
 # Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
 # default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an
@@ -508,7 +518,7 @@ class MultiProcessTestCase(TestCase):
 
     @property
    def world_size(self) -> int:
-        return 4
+        return DEFAULT_WORLD_SIZE
 
     def join_or_run(self, fn):
         @wraps(fn)
@@ -607,7 +617,10 @@ class MultiProcessTestCase(TestCase):
     @classmethod
     def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
         # Enable DDP + ReplicatedTensor
-        from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
+        from torch.nn.parallel._replicated_tensor_ddp_utils import (
+            _set_ddp_with_replicated_tensor,
+        )
+
         _set_ddp_with_replicated_tensor(True)
 
         self = cls(test_name)
@@ -815,16 +828,20 @@ class MultiProcessTestCase(TestCase):
         self.assertEqual(
             first_process.exitcode,
             0,
-            msg="Expected zero exit code but got {} for pid: {}".format(first_process.exitcode, first_process.pid)
+            msg="Expected zero exit code but got {} for pid: {}".format(
+                first_process.exitcode, first_process.pid
+            ),
         )
 
     @property
     def is_master(self) -> bool:
         return self.rank == 0
 
+
 # Cannot use functools.cache as it requires python 3.9
 EFA_PROBE_RESULT = None
 
+
 def has_efa() -> bool:
     """
     If shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0 then we assume that the machine has
@@ -836,7 +853,9 @@ def has_efa() -> bool:
         return EFA_PROBE_RESULT
 
     try:
-        EFA_PROBE_RESULT = subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"]).returncode == 0
+        EFA_PROBE_RESULT = (
+            subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"]).returncode == 0
+        )
     except FileNotFoundError:
         EFA_PROBE_RESULT = False
     return EFA_PROBE_RESULT
@@ -850,3 +869,81 @@ def tp_transports():
     see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
     """
     return ["shm", "uv"] if has_efa() else None
+
+
+def _run_test_with_mt_pg(self, timeout, world_size, callback):
+    failed_ranks = run_with_threaded_pg(world_size, timeout, callback)
+    for rank, exc_info in failed_ranks:
+        print(f"Rank {rank} raised:")
+        for line in traceback.format_exception(*exc_info):
+            sys.stdout.write(line)
+    self.assertEqual([], failed_ranks, "Some ranks failed")
+
+
+def spawn_threads_and_init_comms(
+    func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE
+):
+    """
+    Wrapper to use with a test method
+    """
+    if func is None:
+        return partial(
+            spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
+        )
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        _run_test_with_mt_pg(
+            self, timeout, world_size, lambda: func(self, *args, **kwargs)
+        )
+
+    return wrapper
+
+
+class MultiThreadedTestCase(TestCase):
+    """
+    Simple test runner that executes all tests with the in-proc process group.
+
+    A single instance of the TestCase object for all threads.
+
+    Difference from regular test runner:
+        Cannot use setUp / tearDown (must use perThreadSetup / perThreadShutdown)
+            Not sure what these two would be good for though.
+        No global state possible
+            How bad of a limitation is this?
+    """
+
+    def __init__(self, method_name: str = "runTest") -> None:
+        super().__init__(method_name)
+        self._test_method = getattr(self, method_name, None)
+        setattr(self, method_name, self.threaded_run_test)
+        if TestCase.setUp != type(self).setUp:
+            raise RuntimeError(
+                f"Test class {type(self)} overrides disabled method setUp. Use perThreadSetUp instead"
+            )
+        if TestCase.tearDown != type(self).tearDown:
+            raise RuntimeError(
+                f"Test class {type(self)} overrides disabled method tearDown. Use perThreadTearDown instead"
+            )
+
+    def threaded_run_test(self):
+        self.perThreadSetUp()
+        try:
+            _run_test_with_mt_pg(
+                self=self,
+                timeout=TIMEOUT_DEFAULT,
+                world_size=self.world_size,
+                callback=self._test_method,
+            )
+        finally:
+            self.perThreadTearDown()
+
+    def perThreadSetUp(self):
+        pass
+
+    def perThreadTearDown(self):
+        pass
+
+    @property
+    def world_size(self) -> int:
+        raise RuntimeError("world size not implemented")
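
Note (not part of the commit): below is a minimal usage sketch of the helpers re-landed above, assuming they are importable from the modified module (torch.testing._internal.common_distributed). The test class and method names are hypothetical, and the sketch assumes the collectives used are among those implemented by the in-proc threaded backend.

    import torch
    import torch.distributed as dist
    from torch.testing._internal.common_distributed import (
        MultiThreadedTestCase,
        spawn_threads_and_init_comms,
    )
    from torch.testing._internal.common_utils import TestCase, run_tests


    class TestWithWrapper(TestCase):
        # The decorator runs the test body once per rank, each rank in its own
        # thread, against the in-proc process group (no subprocess spawn).
        @spawn_threads_and_init_comms(world_size=4)
        def test_all_reduce(self):
            t = torch.ones(2) * dist.get_rank()
            dist.all_reduce(t)  # sums the tensor across the 4 threaded ranks
            self.assertEqual(t, torch.ones(2) * sum(range(4)))


    class TestWithBaseClass(MultiThreadedTestCase):
        # setUp / tearDown are disabled for this base class; per-thread hooks
        # are perThreadSetUp / perThreadTearDown instead.
        @property
        def world_size(self) -> int:
            return 4

        def test_broadcast(self):
            t = torch.arange(3.0) if dist.get_rank() == 0 else torch.zeros(3)
            dist.broadcast(t, src=0)
            self.assertEqual(t, torch.arange(3.0))


    if __name__ == "__main__":
        run_tests()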