Rebase and re-land thread PG (#88795)

The previous PR (https://github.com/pytorch/pytorch/pull/88627) was reverted due to a failed check. After rebasing and rerunning, all checks passed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88795
Approved by: https://github.com/huydhn, https://github.com/wanchaol
Author: Charlie Yan
Date: 2022-11-15 18:03:53 +00:00
Committed by: PyTorch MergeBot
Parent: 35093fc1ab
Commit: ee05f47bdd
4 changed files with 457 additions and 26 deletions
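For context on what the re-landed "thread PG" is: the diff below (which appears to be torch/testing/_internal/common_distributed.py) adds a spawn_threads_and_init_comms decorator and a MultiThreadedTestCase runner that execute a test body on world_size threads over an in-process process group instead of spawning subprocesses. A minimal usage sketch of the decorator, not part of the commit: the test class and method names are illustrative, and it assumes the threaded group installs itself as the default c10d group so that torch.distributed queries work inside the test body.

import torch.distributed as dist
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms
from torch.testing._internal.common_utils import TestCase, run_tests

class ThreadedPGExample(TestCase):  # hypothetical test class
    @spawn_threads_and_init_comms(world_size=4)
    def test_comms_initialized(self):
        # The decorator runs this body once per rank, on 4 threads of the same
        # process, with the in-process ("threaded") group already initialized.
        self.assertEqual(dist.get_world_size(), 4)
        self.assertTrue(0 <= dist.get_rank() < 4)

if __name__ == "__main__":
    run_tests()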


@@ -2,10 +2,10 @@ import faulthandler
import logging
import multiprocessing
import os
import subprocess
import sys
import tempfile
import threading
import subprocess
import time
import traceback
import types
@@ -14,11 +14,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from functools import (
partial,
reduce,
wraps
)
from functools import partial, reduce, wraps
from io import StringIO
from typing import NamedTuple, Optional, Union
@@ -26,16 +22,17 @@ import torch
import torch.cuda.nccl
import torch.distributed as c10d
from torch.testing._internal.common_utils import (
TestCase,
TEST_WITH_ROCM,
TEST_WITH_TSAN,
FILE_SCHEMA,
find_free_port,
retry_on_connect_failures,
IS_SANDCASTLE,
sandcastle_skip_if,
retry_on_connect_failures,
sandcastle_skip,
sandcastle_skip_if,
TEST_WITH_ROCM,
TEST_WITH_TSAN,
TestCase,
)
from torch.testing._internal.distributed.multi_threaded_pg import run_with_threaded_pg
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -67,11 +64,10 @@ TEST_SKIPS = {
"generic": TestSkip(
86, "Test skipped at subprocess level, look at subprocess log for skip reason"
),
"importerror": TestSkip(
88, "Test skipped due to missing import"
),
"importerror": TestSkip(88, "Test skipped due to missing import"),
}
@dataclass
class DistTestCases:
# Backends that do not support a specific collective
@@ -93,6 +89,7 @@ class DistTestCases:
def skip_if_no_gpu(func):
"""Skips if the world size exceeds the number of GPUs, ensuring that if the
test is run, each rank has its own GPU via ``torch.cuda.device(rank)``."""
@wraps(func)
def wrapper(*args, **kwargs):
if not torch.cuda.is_available():
@@ -116,6 +113,7 @@ def skip_if_small_worldsize(func):
return wrapper
def skip_if_odd_worldsize(func):
@wraps(func)
def wrapper(*args, **kwargs):
@@ -126,6 +124,7 @@ def skip_if_odd_worldsize(func):
return wrapper
def require_n_gpus_for_nccl_backend(n, backend):
def decorator(func):
@wraps(func)
@@ -139,12 +138,17 @@ def require_n_gpus_for_nccl_backend(n, backend):
return decorator
def import_transformers_or_skip():
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
from transformers import BertConfig, AutoModelForMaskedLM # noqa: Unused
from transformers import ( # noqa: Unused
AutoModelForMaskedLM,
BertConfig,
)
return func(*args, **kwargs)
except ImportError:
sys.exit(TEST_SKIPS["importerror"].exit_code)
@@ -153,6 +157,7 @@ def import_transformers_or_skip():
return decorator
def skip_if_lt_x_gpu(x):
def decorator(func):
@wraps(func)
@@ -191,10 +196,13 @@ def verify_ddp_error_logged(model_DDP, err_substr):
logging_err = ddp_logging_data["error"]
# Remove C++ stacktrace if needed.
actual = (
err_substr if err_substr.find("\nException raised from ") == -1
err_substr
if err_substr.find("\nException raised from ") == -1
else err_substr.split("\nException raised from ")[0]
)
assert actual in logging_err, f"Did not find expected {actual} in ddp logging data error: {logging_err}"
assert (
actual in logging_err
), f"Did not find expected {actual} in ddp logging data error: {logging_err}"
def with_nccl_blocking_wait(func):
@@ -319,7 +327,7 @@ def skip_if_rocm(func):
def skip_if_win32():
return sandcastle_skip_if(
sys.platform == 'win32',
sys.platform == "win32",
"This unit test case is not supportted on Windows platform",
)
@@ -352,13 +360,14 @@ if TEST_WITH_TSAN:
# TSAN runs much slower.
TIMEOUT_DEFAULT = 500
else:
TIMEOUT_DEFAULT = int(os.getenv('DISTRIBUTED_TESTS_DEFAULT_TIMEOUT', '300'))
TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300"))
TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}
# https://github.com/pytorch/pytorch/issues/75665
if TEST_WITH_ROCM:
TIMEOUT_OVERRIDE["test_join_kwargs"] = 200
def create_device(interface=None):
if sys.platform == "win32" or interface is None:
return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1")
@@ -449,9 +458,7 @@ def init_multigpu_helper(world_size: int, backend: str):
if world_size > nGPUs:
nGPUs_per_process = nGPUs // world_size
rank_to_GPU = {
i: list(
visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process]
)
i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process])
for i in range(world_size)
}
return rank_to_GPU
@@ -482,6 +489,9 @@ def cleanup_temp_dir() -> None:
tmp_dir.cleanup()
# Most tests operate with this worldsize
DEFAULT_WORLD_SIZE = 4
# [How does MultiProcessTestCase work?]
# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
# default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an
@@ -508,7 +518,7 @@ class MultiProcessTestCase(TestCase):
@property
def world_size(self) -> int:
return 4
return DEFAULT_WORLD_SIZE
def join_or_run(self, fn):
@wraps(fn)
@@ -607,7 +617,10 @@ class MultiProcessTestCase(TestCase):
@classmethod
def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
# Enable DDP + ReplicatedTensor
from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
from torch.nn.parallel._replicated_tensor_ddp_utils import (
_set_ddp_with_replicated_tensor,
)
_set_ddp_with_replicated_tensor(True)
self = cls(test_name)
@@ -815,16 +828,20 @@ class MultiProcessTestCase(TestCase):
self.assertEqual(
first_process.exitcode,
0,
msg="Expected zero exit code but got {} for pid: {}".format(first_process.exitcode, first_process.pid)
msg="Expected zero exit code but got {} for pid: {}".format(
first_process.exitcode, first_process.pid
),
)
@property
def is_master(self) -> bool:
return self.rank == 0
# Cannot use functools.cache as it requires python 3.9
EFA_PROBE_RESULT = None
def has_efa() -> bool:
"""
If shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0 then we assume that the machine has
@@ -836,7 +853,9 @@ def has_efa() -> bool:
return EFA_PROBE_RESULT
try:
EFA_PROBE_RESULT = subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"]).returncode == 0
EFA_PROBE_RESULT = (
subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"]).returncode == 0
)
except FileNotFoundError:
EFA_PROBE_RESULT = False
return EFA_PROBE_RESULT
@@ -850,3 +869,81 @@ def tp_transports():
see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
"""
return ["shm", "uv"] if has_efa() else None
def _run_test_with_mt_pg(self, timeout, world_size, callback):
failed_ranks = run_with_threaded_pg(world_size, timeout, callback)
for rank, exc_info in failed_ranks:
print(f"Rank {rank} raised:")
for line in traceback.format_exception(*exc_info):
sys.stdout.write(line)
self.assertEqual([], failed_ranks, "Some ranks failed")
def spawn_threads_and_init_comms(
func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE
):
"""
Wrapper to use with a test method
"""
if func is None:
return partial(
spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
)
@wraps(func)
def wrapper(self, *args, **kwargs):
_run_test_with_mt_pg(
self, timeout, world_size, lambda: func(self, *args, **kwargs)
)
return wrapper
class MultiThreadedTestCase(TestCase):
"""
Simple test runner that executes all tests with the in-proc process group.
A single instance of the TestCase object for all threads.
Differences from the regular test runner:
Cannot use setUp / tearDown (must use perThreadSetUp / perThreadTearDown)
Not sure what these two would be good for though.
No global state possible
How bad of a limitation is this?
"""
def __init__(self, method_name: str = "runTest") -> None:
super().__init__(method_name)
self._test_method = getattr(self, method_name, None)
setattr(self, method_name, self.threaded_run_test)
if TestCase.setUp != type(self).setUp:
raise RuntimeError(
f"Test class {type(self)} overrides disabled method setUp. Use perThreadSetUp instead"
)
if TestCase.tearDown != type(self).tearDown:
raise RuntimeError(
f"Test class {type(self)} overrides disabled method tearDown. Use perThreadTearDown instead"
)
def threaded_run_test(self):
self.perThreadSetUp()
try:
_run_test_with_mt_pg(
self=self,
timeout=TIMEOUT_DEFAULT,
world_size=self.world_size,
callback=self._test_method,
)
finally:
self.perThreadTearDown()
def perThreadSetUp(self):
pass
def perThreadTearDown(self):
pass
@property
def world_size(self) -> int:
raise RuntimeError("world size not implemented")
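
A companion sketch of subclassing the new MultiThreadedTestCase, based only on the hooks visible above; the class and test names are illustrative, and the same default-group assumption applies. Note that setUp/tearDown must not be overridden (the constructor raises RuntimeError) and that the world_size property must be provided, since the base property raises.

import torch.distributed as dist
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import run_tests

class MyThreadedTest(MultiThreadedTestCase):  # hypothetical subclass
    @property
    def world_size(self) -> int:
        return 4  # required: the base property raises RuntimeError

    def perThreadSetUp(self):
        # Per-thread replacement for setUp, which the base class disables.
        pass

    def perThreadTearDown(self):
        pass

    def test_world_size_visible(self):
        # Each test method is rerouted through threaded_run_test and executed
        # on world_size threads over the in-process group.
        self.assertEqual(dist.get_world_size(), self.world_size)

if __name__ == "__main__":
    run_tests()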