Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-11 22:34:53 +08:00)
[BE][lint] fix PYFMT for PT-D code under torch.testing._internal, add them to the lint list (#153114)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153114
Approved by: https://github.com/cyyever, https://github.com/fegin, https://github.com/H-Huang, https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent commit: 2926dd4d8e
This commit: 0f9821d0e3
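The diff below is mechanical: the listed files are removed from the PYFMT `exclude_patterns` list in the lint configuration, and the files themselves are reformatted so the linter passes. As a rough, hedged sketch (not taken from the PR itself), this is the kind of rewrite the formatter applies throughout -- double quotes instead of single quotes, sorted imports, and long calls split one argument per line. The snippet is illustrative only, adapted from the timeout handling further down in the diff:

# Illustrative sketch of the formatting convention enforced by this commit.
import os

# Before (style found in the previously excluded files):
# timeout = int(os.getenv('DISTRIBUTED_TESTS_DEFAULT_TIMEOUT', '300'))

# After PYFMT-style formatting (double quotes; a call this short stays on one
# line, longer calls are exploded one argument per line):
timeout = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300"))
print(timeout)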
@@ -1327,7 +1327,6 @@ exclude_patterns = [
'torch/testing/_internal/codegen/__init__.py',
'torch/testing/_internal/codegen/random_topo_test.py',
'torch/testing/_internal/common_cuda.py',
'torch/testing/_internal/common_distributed.py',
'torch/testing/_internal/common_jit.py',
'torch/testing/_internal/common_methods_invocations.py',
'torch/testing/_internal/common_modules.py',
@@ -1344,38 +1343,6 @@ exclude_patterns = [
'torch/testing/_internal/data/network1.py',
'torch/testing/_internal/data/network2.py',
'torch/testing/_internal/dist_utils.py',
'torch/testing/_internal/distributed/__init__.py',
'torch/testing/_internal/distributed/_shard/__init__.py',
'torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py',
'torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py',
'torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py',
'torch/testing/_internal/distributed/_shard/test_common.py',
'torch/testing/_internal/distributed/_tensor/__init__.py',
'torch/testing/_internal/distributed/_tensor/common_dtensor.py',
'torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py',
'torch/testing/_internal/distributed/distributed_test.py',
'torch/testing/_internal/distributed/distributed_utils.py',
'torch/testing/_internal/distributed/fake_pg.py',
'torch/testing/_internal/distributed/multi_threaded_pg.py',
'torch/testing/_internal/distributed/nn/__init__.py',
'torch/testing/_internal/distributed/nn/api/__init__.py',
'torch/testing/_internal/distributed/nn/api/remote_module_test.py',
'torch/testing/_internal/distributed/rpc/__init__.py',
'torch/testing/_internal/distributed/rpc/dist_autograd_test.py',
'torch/testing/_internal/distributed/rpc/dist_optimizer_test.py',
'torch/testing/_internal/distributed/rpc/examples/__init__.py',
'torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py',
'torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py',
'torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py',
'torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py',
'torch/testing/_internal/distributed/rpc/jit/__init__.py',
'torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py',
'torch/testing/_internal/distributed/rpc/jit/rpc_test.py',
'torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py',
'torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py',
'torch/testing/_internal/distributed/rpc/rpc_test.py',
'torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py',
'torch/testing/_internal/distributed/rpc_utils.py',
'torch/testing/_internal/generated/__init__.py',
'torch/testing/_internal/hypothesis_utils.py',
'torch/testing/_internal/inductor_utils.py',
@@ -5,6 +5,7 @@ import faulthandler
import itertools
import logging
import multiprocessing
import operator
import os
import queue
import subprocess
@@ -21,37 +22,37 @@ from datetime import timedelta
from enum import Enum
from functools import partial, reduce, wraps
from io import StringIO
from typing import NamedTuple, Optional, Union, Any, Callable
from typing import Any, Callable, NamedTuple, Optional, Union
from unittest.mock import patch

from torch._logging._internal import trace_log
import torch
import torch._dynamo.test_case
import torch.cuda.nccl
import torch.distributed as c10d
import torch.nn as nn
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
import torch.nn as nn
from torch._logging._internal import trace_log
from torch.testing._internal.common_utils import (
FILE_SCHEMA,
find_free_port,
IS_SANDCASTLE,
retry_on_connect_failures,
run_tests,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_HPU,
TEST_WITH_ROCM,
TEST_WITH_TSAN,
TestCase,
run_tests,
TEST_HPU,
TEST_XPU,
TestCase,
)
from torch.testing._internal.distributed.multi_threaded_pg import (
_install_threaded_pg,
_uninstall_threaded_pg,
ProcessLocalGroup,
)
import operator


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -178,10 +179,7 @@ def import_transformers_or_skip():
@wraps(func)
def wrapper(*args, **kwargs):
try:
from transformers import (  # noqa: F401
AutoModelForMaskedLM,
BertConfig,
)
from transformers import AutoModelForMaskedLM, BertConfig  # noqa: F401

return func(*args, **kwargs)
except ImportError:
@@ -344,12 +342,14 @@ def requires_nccl():
"c10d was not compiled with the NCCL backend",
)


def requires_ucc():
return skip_but_pass_in_sandcastle_if(
not c10d.is_ucc_available(),
"c10d was not compiled with the UCC backend",
)


def requires_mpi():
return skip_but_pass_in_sandcastle_if(
not c10d.is_mpi_available(),
@@ -425,7 +425,12 @@ def create_tcp_store(
)
else:
return c10d.TCPStore(
addr, port, world_size, is_master, wait_for_workers=wait_for_workers, use_libuv=use_libuv
addr,
port,
world_size,
is_master,
wait_for_workers=wait_for_workers,
use_libuv=use_libuv,
)

@@ -433,7 +438,7 @@ if TEST_WITH_TSAN:
# TSAN runs much slower.
TIMEOUT_DEFAULT = 500
else:
TIMEOUT_DEFAULT = int(os.getenv('DISTRIBUTED_TESTS_DEFAULT_TIMEOUT', '300'))
TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300"))
TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}


@@ -444,9 +449,13 @@ if TEST_WITH_ROCM:

def create_device(interface=None, lazy_init: bool = False):
if sys.platform == "win32" or interface is None:
return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1", lazy_init=lazy_init)
return c10d.ProcessGroupGloo.create_device(
hostname="127.0.0.1", lazy_init=lazy_init
)
else:
return c10d.ProcessGroupGloo.create_device(interface=interface, lazy_init=lazy_init)
return c10d.ProcessGroupGloo.create_device(
interface=interface, lazy_init=lazy_init
)


def get_timeout(test_id) -> int:
@@ -612,7 +621,9 @@ class MultiProcessTestCase(TestCase):
# Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses,
# or run the underlying test function.
def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None:
def __init__(
self, method_name: str = "runTest", methodName: str = "runTest"
) -> None:
# methodName is the correct naming in unittest and testslide uses keyword arguments.
# So we need to use both to 1) not break BC and, 2) support testslide.
if methodName != "runTest":
@@ -622,10 +633,12 @@ class MultiProcessTestCase(TestCase):
fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn))
except AttributeError as e:
if methodName != 'runTest':
if methodName != "runTest":
# we allow instantiation with no explicit method name
# but not an *incorrect* or missing method name
raise ValueError(f"no such test method in {self.__class__}: {methodName}") from e
raise ValueError(
f"no such test method in {self.__class__}: {methodName}"
) from e

def setUp(self) -> None:
super().setUp()
@ -660,7 +673,7 @@ class MultiProcessTestCase(TestCase):
|
||||
args=(rank, self._current_test_name(), self.file_name, child_conn),
|
||||
kwargs={
|
||||
"fake_pg": getattr(self, "fake_pg", False),
|
||||
}
|
||||
},
|
||||
)
|
||||
process.start()
|
||||
logger.info("Started process %s with pid %s", rank, process.pid)
|
||||
@ -681,10 +694,10 @@ class MultiProcessTestCase(TestCase):
|
||||
ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])
|
||||
|
||||
if parent_pipe in ready_pipes:
|
||||
|
||||
if parent_pipe.closed:
|
||||
logger.info(
|
||||
"Pipe closed for process %s, stopping event listener thread", rank
|
||||
"Pipe closed for process %s, stopping event listener thread",
|
||||
rank,
|
||||
)
|
||||
return
|
||||
|
||||
@ -706,7 +719,9 @@ class MultiProcessTestCase(TestCase):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
|
||||
def _run(
|
||||
cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
|
||||
) -> None:
|
||||
self = cls(test_name)
|
||||
self.rank = rank
|
||||
self.file_name = file_name
|
||||
@ -734,14 +749,18 @@ class MultiProcessTestCase(TestCase):
|
||||
getattr(self, test_name)()
|
||||
except unittest.SkipTest as se:
|
||||
logger.info(
|
||||
"Process %s skipping test %s for following reason: %s", self.rank, test_name, str(se)
|
||||
"Process %s skipping test %s for following reason: %s",
|
||||
self.rank,
|
||||
test_name,
|
||||
str(se),
|
||||
)
|
||||
sys.exit(TEST_SKIPS["generic"].exit_code)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Caught exception: \n%s exiting "
|
||||
"process %s with exit code: %s",
|
||||
traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE
|
||||
"Caught exception: \n%s exiting " "process %s with exit code: %s",
|
||||
traceback.format_exc(),
|
||||
self.rank,
|
||||
MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
|
||||
)
|
||||
# Send error to parent process.
|
||||
parent_pipe.send(traceback.format_exc())
|
||||
@ -773,7 +792,9 @@ class MultiProcessTestCase(TestCase):
|
||||
pipes.append((i, pipe))
|
||||
except ConnectionError as e:
|
||||
logger.error(
|
||||
"Encountered error while trying to get traceback for process %s: %s", i, e
|
||||
"Encountered error while trying to get traceback for process %s: %s",
|
||||
i,
|
||||
e,
|
||||
)
|
||||
|
||||
# Wait for results.
|
||||
@ -783,7 +804,8 @@ class MultiProcessTestCase(TestCase):
|
||||
if pipe.poll(5):
|
||||
if pipe.closed:
|
||||
logger.info(
|
||||
"Pipe closed for process %s, cannot retrieve traceback", rank
|
||||
"Pipe closed for process %s, cannot retrieve traceback",
|
||||
rank,
|
||||
)
|
||||
continue
|
||||
|
||||
@ -797,7 +819,9 @@ class MultiProcessTestCase(TestCase):
|
||||
)
|
||||
except ConnectionError as e:
|
||||
logger.error(
|
||||
"Encountered error while trying to get traceback for process %s: %s", rank, e
|
||||
"Encountered error while trying to get traceback for process %s: %s",
|
||||
rank,
|
||||
e,
|
||||
)
|
||||
|
||||
def _join_processes(self, fn) -> None:
|
||||
@ -807,7 +831,7 @@ class MultiProcessTestCase(TestCase):
|
||||
try:
|
||||
while True:
|
||||
# check to see if any subprocess exited with an error early.
|
||||
for (i, p) in enumerate(self.processes):
|
||||
for i, p in enumerate(self.processes):
|
||||
# This is the exit code processes exit with if they
|
||||
# encountered an exception.
|
||||
if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
|
||||
@ -866,7 +890,9 @@ class MultiProcessTestCase(TestCase):
|
||||
"""
|
||||
# If no processes are spawned, there is nothing to check.
|
||||
if not self.processes:
|
||||
logger.warning("Note: no subprocesses were spawned, test was likely skipped.")
|
||||
logger.warning(
|
||||
"Note: no subprocesses were spawned, test was likely skipped."
|
||||
)
|
||||
return
|
||||
|
||||
first_process = self.processes[0]
|
||||
@ -912,7 +938,9 @@ class MultiProcessTestCase(TestCase):
|
||||
# is some follow-up needed. Instead just "pass" the test
|
||||
# with an appropriate message.
|
||||
logger.info(
|
||||
"Skipping %s on sandcastle for the following reason: %s", self.id(), skip.message
|
||||
"Skipping %s on sandcastle for the following reason: %s",
|
||||
self.id(),
|
||||
skip.message,
|
||||
)
|
||||
return
|
||||
else:
|
||||
@ -927,13 +955,13 @@ class MultiProcessTestCase(TestCase):
|
||||
def is_master(self) -> bool:
|
||||
return self.rank == 0
|
||||
|
||||
|
||||
# Utility base class for distributed Multi Process Test cases
|
||||
# This abstracts the PG creation and deletion, the backends are selected based
|
||||
# on device type. The tests functions can be instantiated per device type using
|
||||
# common_device_type.instantiate_device_type_tests
|
||||
# other backends can add entry in backend() function
|
||||
class DistributedTestBase(MultiProcessTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
@ -947,11 +975,11 @@ class DistributedTestBase(MultiProcessTestCase):
|
||||
def backend(self, device) -> str:
|
||||
if "cuda" in device:
|
||||
return "nccl"
|
||||
elif "hpu" in device : # intel gaudi
|
||||
elif "hpu" in device: # intel gaudi
|
||||
return "hccl"
|
||||
elif "xpu" in device:
|
||||
return "xccl"
|
||||
else :
|
||||
else:
|
||||
return "gloo"
|
||||
|
||||
def create_pg(self, device):
|
||||
@ -961,7 +989,7 @@ class DistributedTestBase(MultiProcessTestCase):
|
||||
backend=self.backend(device),
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store
|
||||
store=store,
|
||||
)
|
||||
if "nccl" in self.backend(device) or "xccl" in self.backend(device):
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
@ -971,6 +999,7 @@ class DistributedTestBase(MultiProcessTestCase):
|
||||
num_visible_devices = torch.get_device_module(device).device_count()
|
||||
return {i: [i % num_visible_devices] for i in range(self.world_size)}
|
||||
|
||||
|
||||
def run_subtests(
|
||||
cls_inst,
|
||||
subtest_config: dict[str, list[Any]],
|
||||
@ -1021,7 +1050,10 @@ def has_efa() -> bool:
|
||||
|
||||
try:
|
||||
EFA_PROBE_RESULT = (
|
||||
subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False).returncode == 0
|
||||
subprocess.run(
|
||||
["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False
|
||||
).returncode
|
||||
== 0
|
||||
)
|
||||
except FileNotFoundError:
|
||||
EFA_PROBE_RESULT = False
|
||||
@ -1049,7 +1081,6 @@ def spawn_threads_and_init_comms(
|
||||
spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
|
||||
)
|
||||
|
||||
|
||||
def _run_test_method_with_multi_threads(world_size, callback):
|
||||
world = _install_threaded_pg()
|
||||
global_store = c10d.HashStore()
|
||||
@ -1066,7 +1097,9 @@ def spawn_threads_and_init_comms(
|
||||
except BaseException as ex:
|
||||
# Exceptions are handled in MultiThreadedTestCase
|
||||
MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
|
||||
ProcessLocalGroup.exception_handle(ex) # trigger _terminate event and awaken worker threads
|
||||
ProcessLocalGroup.exception_handle(
|
||||
ex
|
||||
) # trigger _terminate event and awaken worker threads
|
||||
finally:
|
||||
if world_is_valid():
|
||||
c10d.destroy_process_group()
|
||||
@ -1079,13 +1112,14 @@ def spawn_threads_and_init_comms(
|
||||
|
||||
return threads
|
||||
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
# TODO: get test name from kwargs
|
||||
torch._C._distributed_c10d._set_thread_isolation_mode(True)
|
||||
try:
|
||||
threads = _run_test_method_with_multi_threads(world_size, lambda: func(self, *args, **kwargs))
|
||||
threads = _run_test_method_with_multi_threads(
|
||||
world_size, lambda: func(self, *args, **kwargs)
|
||||
)
|
||||
# join and error handling
|
||||
MultiThreadedTestCase._join_threads(threads, func)
|
||||
finally:
|
||||
@ -1108,6 +1142,7 @@ class MultiThreadedTestCase(TestCase):
|
||||
No global state possible
|
||||
How bad of a limitation is this?
|
||||
"""
|
||||
|
||||
exception_queue = queue.Queue()
|
||||
|
||||
MAIN_THREAD_RANK = -1
|
||||
@ -1122,7 +1157,9 @@ class MultiThreadedTestCase(TestCase):
|
||||
|
||||
return types.MethodType(wrapper, self)
|
||||
|
||||
def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None:
|
||||
def __init__(
|
||||
self, method_name: str = "runTest", methodName: str = "runTest"
|
||||
) -> None:
|
||||
# methodName is the correct naming in unittest and testslide uses keyword arguments.
|
||||
# So we need to use both to 1) not break BC and, 2) support testslide.
|
||||
if methodName != "runTest":
|
||||
@ -1132,10 +1169,12 @@ class MultiThreadedTestCase(TestCase):
|
||||
fn = getattr(self, method_name)
|
||||
setattr(self, method_name, self.join_or_run(fn))
|
||||
except AttributeError as e:
|
||||
if methodName != 'runTest':
|
||||
if methodName != "runTest":
|
||||
# we allow instantiation with no explicit method name
|
||||
# but not an *incorrect* or missing method name
|
||||
raise ValueError(f"no such test method in {self.__class__}: {methodName}") from e
|
||||
raise ValueError(
|
||||
f"no such test method in {self.__class__}: {methodName}"
|
||||
) from e
|
||||
|
||||
def perThreadSetUp(self):
|
||||
# super().setUp() # TestCase.setUp() calls torch.manual_seed()
|
||||
@ -1180,7 +1219,9 @@ class MultiThreadedTestCase(TestCase):
|
||||
raise RuntimeError("Invalid world")
|
||||
|
||||
for rank in range(self.world_size):
|
||||
t = threading.Thread(target=self.__class__._run, args=(test_name, rank, self.world_size))
|
||||
t = threading.Thread(
|
||||
target=self.__class__._run, args=(test_name, rank, self.world_size)
|
||||
)
|
||||
t.start()
|
||||
self.threads.append(t)
|
||||
|
||||
@ -1205,7 +1246,10 @@ class MultiThreadedTestCase(TestCase):
|
||||
Run the current test associated with `test_name` using the threaded process group.
|
||||
"""
|
||||
c10d.init_process_group(
|
||||
backend="threaded", rank=rank, world_size=world_size, store=self.__class__.global_store
|
||||
backend="threaded",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
store=self.__class__.global_store,
|
||||
)
|
||||
self.perThreadSetUp()
|
||||
|
||||
@ -1213,12 +1257,13 @@ class MultiThreadedTestCase(TestCase):
|
||||
getattr(self, test_name)()
|
||||
except BaseException as ex:
|
||||
self.exception_queue.put((rank, sys.exc_info()))
|
||||
ProcessLocalGroup.exception_handle(ex) # trigger _terminate event and awaken worker threads
|
||||
ProcessLocalGroup.exception_handle(
|
||||
ex
|
||||
) # trigger _terminate event and awaken worker threads
|
||||
finally:
|
||||
c10d.destroy_process_group()
|
||||
self.perThreadTearDown()
|
||||
|
||||
|
||||
@classmethod
|
||||
def _join_threads(cls, threads, fn):
|
||||
timeout = TIMEOUT_DEFAULT
|
||||
@ -1262,7 +1307,10 @@ class MultiThreadedTestCase(TestCase):
|
||||
exc = exc_info[1]
|
||||
if isinstance(exc, unittest.SkipTest):
|
||||
logger.info(
|
||||
"Thread %s skipping test %s for following reason: %s", rank, fn, str(exc)
|
||||
"Thread %s skipping test %s for following reason: %s",
|
||||
rank,
|
||||
fn,
|
||||
str(exc),
|
||||
)
|
||||
if skip_code < 0:
|
||||
skip_code = TEST_SKIPS["generic"].exit_code
|
||||
@ -1272,12 +1320,8 @@ class MultiThreadedTestCase(TestCase):
|
||||
raise RuntimeError(msg)
|
||||
elif isinstance(exc, Exception):
|
||||
msg = "".join(traceback.format_exception(*exc_info))
|
||||
logger.error(
|
||||
"Caught exception: \n%s exiting thread %s", msg, rank
|
||||
)
|
||||
error_msg += (
|
||||
f"Thread {rank} exited with exception:\n{msg}\n"
|
||||
)
|
||||
logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
|
||||
error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
|
||||
elif isinstance(exc, SystemExit):
|
||||
if type(exc.code) == int and skip_code < 0:
|
||||
skip_code = exc.code
|
||||
@ -1292,7 +1336,9 @@ class MultiThreadedTestCase(TestCase):
|
||||
if IS_SANDCASTLE:
|
||||
# "pass" the test with an appropriate message.
|
||||
logger.info(
|
||||
"Skipping %s on sandcastle for the following reason: %s", fn, skip.message
|
||||
"Skipping %s on sandcastle for the following reason: %s",
|
||||
fn,
|
||||
skip.message,
|
||||
)
|
||||
return
|
||||
else:
|
||||
@ -1352,14 +1398,15 @@ class SaveForwardInputsModel(nn.Module):
|
||||
self.forward_inputs[self] = x
|
||||
return self.c2(self.c1(x))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
|
||||
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
|
||||
# Just manually implement the most important part of the dynamo behavior to reset/clear.
|
||||
if not fake_pg:
|
||||
torch.accelerator.set_device_index(rank)
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '6789'
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "6789"
|
||||
if init_pg:
|
||||
if fake_pg:
|
||||
store = torch.testing._internal.distributed.fake_pg.FakeStore()
|
||||
@ -1424,6 +1471,7 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
|
||||
Prefer MultiThreadedTestCase for most tests. Perhaps use this one
|
||||
sparingly for integration tests.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._spawn_processes()
|
||||
@ -1440,7 +1488,9 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
|
||||
return torch.cuda.device_count()
|
||||
|
||||
@classmethod
|
||||
def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
|
||||
def _run(
|
||||
cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
|
||||
) -> None:
|
||||
trace_log.addHandler(logging.NullHandler())
|
||||
|
||||
# The rest is copypasta from MultiProcessTestCase._run
|
||||
@ -1461,7 +1511,6 @@ class MultiProcContinousTest(TestCase):
|
||||
# timeout configured per class
|
||||
timeout: timedelta = timedelta(seconds=120)
|
||||
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def backend_str(cls) -> str:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
import sys
|
||||
from functools import wraps, partial
|
||||
from functools import partial, wraps
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -12,8 +12,10 @@ from torch.testing._internal.common_distributed import (
|
||||
tp_transports,
|
||||
)
|
||||
|
||||
|
||||
TEST_GPU_NUM = 4
|
||||
|
||||
|
||||
class ShardedTensorTestBase(MultiProcessTestCase):
|
||||
@property
|
||||
def world_size(self):
|
||||
@ -34,9 +36,10 @@ class ShardedTensorTestBase(MultiProcessTestCase):
|
||||
if backend == "nccl":
|
||||
torch.cuda.set_device(self.rank)
|
||||
|
||||
|
||||
def init_rpc(self):
|
||||
rpc_backend_options = rpc.TensorPipeRpcBackendOptions(_transports=tp_transports())
|
||||
rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
|
||||
_transports=tp_transports()
|
||||
)
|
||||
rpc_backend_options.init_method = f"file://{self.file_name}"
|
||||
for rank in range(self.world_size):
|
||||
rpc_backend_options.set_device_map(
|
||||
@ -79,6 +82,7 @@ class ShardedTensorTestBase(MultiProcessTestCase):
|
||||
self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
|
||||
self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))
|
||||
|
||||
|
||||
# wrapper to initialize comms (processgroup + rpc)
|
||||
def with_comms(func=None, init_rpc=True, backend="nccl"):
|
||||
if func is None:
|
||||
@ -95,4 +99,5 @@ def with_comms(func=None, init_rpc=True, backend="nccl"):
|
||||
self.init_comms(init_rpc=init_rpc, backend=backend)
|
||||
func(self, *args, **kwargs)
|
||||
self.destroy_comms(destroy_rpc=init_rpc)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -121,16 +121,17 @@ def clone_module_parameter(module, param_name):
|
||||
tensor = getattr(module, param_name)
|
||||
return torch.nn.Parameter(tensor.detach().clone())
|
||||
|
||||
def gen_binary_op_func(python_op, inplace=False):
|
||||
src_lines = ['def f(lhs, rhs):']
|
||||
if "torch" in python_op:
|
||||
src_lines.append(f' return {python_op}(lhs, rhs)\n')
|
||||
elif inplace:
|
||||
src_lines.append(f' lhs {python_op}= rhs\n return lhs\n')
|
||||
else:
|
||||
src_lines.append(f' return lhs {python_op} rhs\n')
|
||||
|
||||
code_str = '\n'.join(src_lines)
|
||||
g = {'torch': torch}
|
||||
def gen_binary_op_func(python_op, inplace=False):
|
||||
src_lines = ["def f(lhs, rhs):"]
|
||||
if "torch" in python_op:
|
||||
src_lines.append(f" return {python_op}(lhs, rhs)\n")
|
||||
elif inplace:
|
||||
src_lines.append(f" lhs {python_op}= rhs\n return lhs\n")
|
||||
else:
|
||||
src_lines.append(f" return lhs {python_op} rhs\n")
|
||||
|
||||
code_str = "\n".join(src_lines)
|
||||
g = {"torch": torch}
|
||||
builtins.exec(code_str, g)
|
||||
return g["f"]
|
||||
|
||||
@ -2,12 +2,11 @@
|
||||
|
||||
import copy
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard import sharded_tensor
|
||||
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
|
||||
|
||||
from torch.distributed._shard.sharding_spec import (
|
||||
ChunkShardingSpec,
|
||||
)
|
||||
|
||||
PLACEMENTS = [
|
||||
"rank:0/cuda:0",
|
||||
@ -31,13 +30,9 @@ def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
|
||||
)
|
||||
return spec_list
|
||||
|
||||
|
||||
class MyShardedModel2(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
spec=None,
|
||||
group=None,
|
||||
init_rrefs=True
|
||||
) -> None:
|
||||
def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
|
||||
super().__init__()
|
||||
if spec is not None:
|
||||
self.sharded_tensor2 = sharded_tensor.rand(
|
||||
@ -49,12 +44,7 @@ class MyShardedModel2(torch.nn.Module):
|
||||
|
||||
|
||||
class MyShardedModel1(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
spec=None,
|
||||
group=None,
|
||||
init_rrefs=True
|
||||
) -> None:
|
||||
def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
|
||||
super().__init__()
|
||||
if spec is not None:
|
||||
self.sharded_tensor1 = sharded_tensor.rand(
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
|
||||
@ -4,22 +4,16 @@
|
||||
|
||||
import itertools
|
||||
import sys
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial, wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
TypeVar,
|
||||
Union,
|
||||
Optional,
|
||||
)
|
||||
from collections.abc import Iterator, Sequence
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch._utils import _get_device_module
|
||||
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
|
||||
from torch.distributed._tensor.placement_types import Placement
|
||||
from torch.distributed.tensor.parallel import (
|
||||
@ -29,21 +23,16 @@ from torch.distributed.tensor.parallel import (
|
||||
RowwiseParallel,
|
||||
SequenceParallel,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
TEST_HPU,
|
||||
TEST_CUDA,
|
||||
TEST_XPU
|
||||
)
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
MultiThreadedTestCase,
|
||||
skip_if_lt_x_gpu,
|
||||
run_subtests,
|
||||
skip_if_lt_x_gpu,
|
||||
TEST_SKIPS,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
|
||||
from torch._utils import _get_device_module
|
||||
|
||||
|
||||
if TEST_CUDA:
|
||||
DEVICE_TYPE = "cuda"
|
||||
@ -232,20 +221,31 @@ class Transformer(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def parallelize(
|
||||
module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool, local_output_for_attn: bool = False
|
||||
module: "Transformer",
|
||||
device_mesh: DeviceMesh,
|
||||
use_seq_parallel: bool,
|
||||
local_output_for_attn: bool = False,
|
||||
) -> nn.Module:
|
||||
assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
|
||||
# Parallelize the root submodules.
|
||||
if use_seq_parallel:
|
||||
root_plan = {
|
||||
"tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
|
||||
"pos_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(0)),
|
||||
"tok_embeddings": RowwiseParallel(
|
||||
input_layouts=Replicate(), output_layouts=Shard(1)
|
||||
),
|
||||
"pos_embeddings": RowwiseParallel(
|
||||
input_layouts=Replicate(), output_layouts=Shard(0)
|
||||
),
|
||||
"norm": SequenceParallel(),
|
||||
}
|
||||
else:
|
||||
root_plan = {
|
||||
"tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
|
||||
"pos_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
|
||||
"tok_embeddings": RowwiseParallel(
|
||||
input_layouts=Replicate(), output_layouts=Replicate()
|
||||
),
|
||||
"pos_embeddings": RowwiseParallel(
|
||||
input_layouts=Replicate(), output_layouts=Replicate()
|
||||
),
|
||||
}
|
||||
|
||||
module_tp = parallelize_module(module, device_mesh, root_plan)
|
||||
@ -260,9 +260,15 @@ class Transformer(nn.Module):
|
||||
# shard the RMSNorms
|
||||
layer_parallelize_plan["attention_norm"] = SequenceParallel()
|
||||
layer_parallelize_plan["ffn_norm"] = SequenceParallel()
|
||||
layer_parallelize_plan["attention.wq"] = ColwiseParallel(use_local_output=local_output_for_attn)
|
||||
layer_parallelize_plan["attention.wk"] = ColwiseParallel(use_local_output=local_output_for_attn)
|
||||
layer_parallelize_plan["attention.wv"] = ColwiseParallel(use_local_output=local_output_for_attn)
|
||||
layer_parallelize_plan["attention.wq"] = ColwiseParallel(
|
||||
use_local_output=local_output_for_attn
|
||||
)
|
||||
layer_parallelize_plan["attention.wk"] = ColwiseParallel(
|
||||
use_local_output=local_output_for_attn
|
||||
)
|
||||
layer_parallelize_plan["attention.wv"] = ColwiseParallel(
|
||||
use_local_output=local_output_for_attn
|
||||
)
|
||||
layer_parallelize_plan["attention.wo"] = (
|
||||
RowwiseParallel(output_layouts=Shard(1))
|
||||
if use_seq_parallel
|
||||
@ -297,7 +303,9 @@ class Transformer(nn.Module):
|
||||
|
||||
if local_output_for_attn:
|
||||
for layer in module_tp.layers:
|
||||
layer.attention.n_heads = module_tp.model_args.n_heads // device_mesh.size()
|
||||
layer.attention.n_heads = (
|
||||
module_tp.model_args.n_heads // device_mesh.size()
|
||||
)
|
||||
|
||||
# Manually set output.weight so that parameters and gradients are shared.
|
||||
if module_tp.model_args.weight_tying:
|
||||
@ -327,7 +335,10 @@ class DTensorTestBase(MultiProcessTestCase):
|
||||
@property
|
||||
def device_type(self) -> str:
|
||||
# if enough GPU we can use GPU, otherwise we fallback to CPU
|
||||
if not (TEST_CUDA or TEST_XPU) or torch.accelerator.device_count() < self.world_size:
|
||||
if (
|
||||
not (TEST_CUDA or TEST_XPU)
|
||||
or torch.accelerator.device_count() < self.world_size
|
||||
):
|
||||
return "cpu"
|
||||
else:
|
||||
return DEVICE_TYPE
|
||||
@ -344,7 +355,14 @@ class DTensorTestBase(MultiProcessTestCase):
|
||||
if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
|
||||
|
||||
if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "hccl", "xccl"]:
|
||||
if self.backend not in [
|
||||
"nccl",
|
||||
"gloo",
|
||||
"mpi",
|
||||
"cpu:gloo,cuda:nccl",
|
||||
"hccl",
|
||||
"xccl",
|
||||
]:
|
||||
raise RuntimeError(f"Backend {self.backend} not supported!")
|
||||
|
||||
device_id = None
|
||||
@ -352,7 +370,9 @@ class DTensorTestBase(MultiProcessTestCase):
|
||||
# set device for nccl pg for collectives
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
# we only need to set device_id for nccl backend with eager init
|
||||
device_id = torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
|
||||
device_id = (
|
||||
torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
|
||||
)
|
||||
# For nccl backend, bind the device to the process if device_id is not None
|
||||
# so the nccl communicator is immediately formed and we can use `ncclCommSplit`
|
||||
# for form subgroup to avoid unnecesssary overhead.
|
||||
@ -371,7 +391,9 @@ class DTensorTestBase(MultiProcessTestCase):
|
||||
# FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
|
||||
# test_dtensor.py -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
|
||||
if device_id is None:
|
||||
device_id = torch.cuda.current_device() if self.device_type == "cuda" else self.rank
|
||||
device_id = (
|
||||
torch.cuda.current_device() if self.device_type == "cuda" else self.rank
|
||||
)
|
||||
dist.barrier(device_ids=[device_id])
|
||||
dist.destroy_process_group()
|
||||
|
||||
@ -398,14 +420,11 @@ TestFunc = Callable[[...], object]
|
||||
|
||||
# wrapper to initialize comms (processgroup)
|
||||
def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc:
|
||||
|
||||
def decorator(func, eager_init: bool = False):
|
||||
|
||||
@wraps(func) # pyre-ignore[6]
|
||||
def wrapper(
|
||||
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
|
||||
) -> None:
|
||||
|
||||
self.init_pg(eager_init)
|
||||
|
||||
try:
|
||||
@ -418,7 +437,11 @@ def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc:
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator(func=eager_init) if callable(eager_init) else partial(decorator, eager_init=eager_init)
|
||||
return (
|
||||
decorator(func=eager_init)
|
||||
if callable(eager_init)
|
||||
else partial(decorator, eager_init=eager_init)
|
||||
)
|
||||
|
||||
|
||||
class DTensorOpTestBase(MultiThreadedTestCase):
|
||||
@ -459,10 +482,16 @@ class DTensorConverter:
|
||||
self.flatten_kwargs: list[object] = flatten_kwargs
|
||||
self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
|
||||
|
||||
choices_for_args = [self.gen_sharding_choices_for_arg(arg) for arg in self.flatten_args if isinstance(arg, torch.Tensor)]
|
||||
choices_for_args = [
|
||||
self.gen_sharding_choices_for_arg(arg)
|
||||
for arg in self.flatten_args
|
||||
if isinstance(arg, torch.Tensor)
|
||||
]
|
||||
|
||||
choices_for_args.extend(
|
||||
self.gen_sharding_choices_for_arg(arg) for arg in self.flatten_kwargs if isinstance(arg, torch.Tensor)
|
||||
self.gen_sharding_choices_for_arg(arg)
|
||||
for arg in self.flatten_kwargs
|
||||
if isinstance(arg, torch.Tensor)
|
||||
)
|
||||
|
||||
self.sharding_combs: Iterator[Sequence[Placement]] = iter(
|
||||
|
||||
@ -20,7 +20,7 @@ from torch.testing._internal.common_distributed import (
|
||||
skip_if_lt_x_gpu,
|
||||
skip_if_rocm_multiprocess,
|
||||
)
|
||||
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
|
||||
from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE
|
||||
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
|
||||
RpcAgentTestFixture,
|
||||
)
|
||||
@ -68,7 +68,7 @@ gLogger = init_logger()
|
||||
|
||||
|
||||
class FeatureSet(NamedTuple):
|
||||
""" A feature set has 2 types of features"""
|
||||
"""A feature set has 2 types of features"""
|
||||
|
||||
dense_features: torch.Tensor
|
||||
sparse_features: torch.LongTensor
|
||||
@ -210,7 +210,8 @@ class Trainer:
|
||||
gLogger.info(
|
||||
"Succeeded in creating a HybridModel instance with "
|
||||
"%s ddp params and %s other local params.",
|
||||
len(self.ddp_params), len(self.non_ddp_params)
|
||||
len(self.ddp_params),
|
||||
len(self.non_ddp_params),
|
||||
)
|
||||
|
||||
def destroy_pg(self):
|
||||
@ -246,7 +247,8 @@ class Trainer:
|
||||
gLogger.info(
|
||||
"Trainer reduced input patches from %s "
|
||||
"to %s to simulate uneven inputs.",
|
||||
len(batches), len(input_batches)
|
||||
len(batches),
|
||||
len(input_batches),
|
||||
)
|
||||
else:
|
||||
input_batches = batches
|
||||
@ -260,7 +262,11 @@ class Trainer:
|
||||
grads_dict = dist_autograd.get_gradients(context_id)
|
||||
gLogger.info(
|
||||
"Loss is %s for mini batch: %s. "
|
||||
"Grads dict has %s entries: %s", loss, mini_batch, len(grads_dict), grads_dict
|
||||
"Grads dict has %s entries: %s",
|
||||
loss,
|
||||
mini_batch,
|
||||
len(grads_dict),
|
||||
grads_dict,
|
||||
)
|
||||
return (
|
||||
tuple(grads_dict[param] for param in self.ddp_params),
|
||||
@ -348,7 +354,9 @@ class DdpUnderDistAutogradTest(RpcAgentTestFixture):
|
||||
def _trainer_process(self, rank: int):
|
||||
gLogger.info("Running the trainer #%s...", rank)
|
||||
gLogger.info(
|
||||
"Initing trainer process group by trainer #%s with ranks %s", rank, TRAINER_RANKS
|
||||
"Initing trainer process group by trainer #%s with ranks %s",
|
||||
rank,
|
||||
TRAINER_RANKS,
|
||||
)
|
||||
dist.init_process_group(
|
||||
backend="gloo",
|
||||
@ -534,7 +542,9 @@ class DdpComparisonTest(CommonDdpComparisonTest):
|
||||
inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]
|
||||
|
||||
if simulate_uneven_inputs:
|
||||
gLogger.info("Rank %s training with %s inputs.", self.rank, len(inputs_list))
|
||||
gLogger.info(
|
||||
"Rank %s training with %s inputs.", self.rank, len(inputs_list)
|
||||
)
|
||||
|
||||
# Use distributed autograd. The gradients will be in RPC context map.
|
||||
grads_dict = {}
|
||||
|
||||
@ -1,96 +1,93 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
import copy
|
||||
import json
|
||||
import itertools
|
||||
import json
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from collections import namedtuple, OrderedDict, defaultdict
|
||||
import unittest
|
||||
from collections import defaultdict, namedtuple, OrderedDict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import reduce
|
||||
from typing import Union, NamedTuple, Callable, Any
|
||||
import unittest
|
||||
from typing import Any, Callable, NamedTuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.algorithms.model_averaging.averagers as averagers
|
||||
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
|
||||
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
|
||||
import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
|
||||
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch._utils_internal import (
|
||||
TEST_MASTER_ADDR as MASTER_ADDR,
|
||||
TEST_MASTER_PORT as MASTER_PORT,
|
||||
)
|
||||
from torch.autograd import DeviceType
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.distributed.algorithms.ddp_comm_hooks import (
|
||||
default_hooks as default,
|
||||
post_localSGD_hook as post_localSGD,
|
||||
powerSGD_hook as powerSGD,
|
||||
default_hooks as default,
|
||||
quantization as quantization_hooks,
|
||||
)
|
||||
from torch.distributed.optim import _apply_optimizer_in_backward
|
||||
|
||||
from torch.distributed.distributed_c10d import (
|
||||
get_world_size,
|
||||
_get_default_group,
|
||||
_get_pg_config,
|
||||
get_world_size,
|
||||
)
|
||||
from torch.distributed.optim import _apply_optimizer_in_backward
|
||||
from torch.distributed.utils import (
|
||||
_verify_param_shape_across_processes,
|
||||
_sync_module_states,
|
||||
_verify_param_shape_across_processes,
|
||||
)
|
||||
from torch.profiler import (
|
||||
ExecutionTraceObserver,
|
||||
ProfilerActivity,
|
||||
)
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision
|
||||
from torch.profiler import ExecutionTraceObserver, ProfilerActivity
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
TEST_SKIPS,
|
||||
captured_output,
|
||||
cleanup_temp_dir,
|
||||
DistTestCases,
|
||||
init_multigpu_helper,
|
||||
initialize_temp_directories,
|
||||
cleanup_temp_dir,
|
||||
simple_sparse_reduce_tests,
|
||||
skip_if_rocm_multiprocess,
|
||||
skip_if_small_worldsize,
|
||||
skip_if_odd_worldsize,
|
||||
skip_if_lt_x_gpu,
|
||||
MultiProcessTestCase,
|
||||
nccl_skip_if_lt_x_gpu,
|
||||
skip_if_no_gpu,
|
||||
require_n_gpus_for_nccl_backend,
|
||||
requires_nccl_version,
|
||||
captured_output,
|
||||
with_nccl_blocking_wait,
|
||||
with_dist_debug_levels,
|
||||
simple_sparse_reduce_tests,
|
||||
skip_if_lt_x_gpu,
|
||||
skip_if_no_gpu,
|
||||
skip_if_odd_worldsize,
|
||||
skip_if_rocm_multiprocess,
|
||||
skip_if_small_worldsize,
|
||||
TEST_SKIPS,
|
||||
verify_ddp_error_logged,
|
||||
DistTestCases,
|
||||
with_dist_debug_levels,
|
||||
with_nccl_blocking_wait,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
IS_MACOS,
|
||||
IS_WINDOWS,
|
||||
FILE_SCHEMA,
|
||||
instantiate_parametrized_tests,
|
||||
IS_FBCODE,
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
skip_but_pass_in_sandcastle,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
)
|
||||
|
||||
import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
|
||||
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
import operator
|
||||
|
||||
|
||||
try:
|
||||
import torchvision
|
||||
@ -201,19 +198,19 @@ def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
|
||||
else profiler.function_events
|
||||
)
|
||||
return [
|
||||
event for event in event_list
|
||||
event
|
||||
for event in event_list
|
||||
if (
|
||||
(event.name.endswith(event_name) or event.name.startswith(event_name))
|
||||
and (not dedup_gpu_user_annotation or event.device_type != DeviceType.CUDA)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_profiler_nccl_meta(prof):
|
||||
"""Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
|
||||
We will need to test metadata obtained from profiler here"""
|
||||
tf = tempfile.NamedTemporaryFile(
|
||||
mode="w+t", suffix=".json", delete=False
|
||||
)
|
||||
tf = tempfile.NamedTemporaryFile(mode="w+t", suffix=".json", delete=False)
|
||||
tf.close()
|
||||
trace_file = tf.name
|
||||
|
||||
@ -227,6 +224,7 @@ def get_profiler_nccl_meta(prof):
|
||||
|
||||
return [e for e in events if e.get("name") == "record_param_comms"]
|
||||
|
||||
|
||||
# Base error message substring on unfinished reductions.
|
||||
ddp_prev_reduction_unfinished_str = (
|
||||
"Expected to have finished reduction in the prior iteration"
|
||||
@ -424,6 +422,7 @@ CUSTOM_PG_TIMEOUT = {
|
||||
"test_ddp_has_finalized": 5,
|
||||
}
|
||||
|
||||
|
||||
def require_backend_is_available(backends):
|
||||
def check(backend):
|
||||
if backend == dist.Backend.GLOO:
|
||||
@ -672,7 +671,7 @@ class DistributedTest:
|
||||
# Verify buffers across ranks.
|
||||
m1_buffers = list(m1.buffers())
|
||||
m2_buffers = list(m2.buffers())
|
||||
for (buf1, buf2) in zip(m1_buffers, m2_buffers):
|
||||
for buf1, buf2 in zip(m1_buffers, m2_buffers):
|
||||
gathered_bufs = [
|
||||
torch.empty_like(buf1) for _ in range(dist.get_world_size())
|
||||
]
|
||||
@ -1218,8 +1217,8 @@ class DistributedTest:
|
||||
|
||||
subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
|
||||
subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
|
||||
real_group_ranks_res1 = _get_pg_config(subgroup1)['ranks']
|
||||
real_group_ranks_res2 = _get_pg_config(subgroup2)['ranks']
|
||||
real_group_ranks_res1 = _get_pg_config(subgroup1)["ranks"]
|
||||
real_group_ranks_res2 = _get_pg_config(subgroup2)["ranks"]
|
||||
|
||||
expect_group_ranks_res1 = (
|
||||
rank // subgroup_size1 * subgroup_size1
|
||||
@ -1269,7 +1268,7 @@ class DistributedTest:
|
||||
@skip_if_no_gpu
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
|
||||
"Coalescing manager currently tests with NCCL only; internal test flaky"
|
||||
"Coalescing manager currently tests with NCCL only; internal test flaky",
|
||||
)
|
||||
def test_coalescing_manager(self):
|
||||
self._barrier()
|
||||
@ -1294,7 +1293,7 @@ class DistributedTest:
|
||||
for i in range(num_colls):
|
||||
self.assertEqual(
|
||||
small_tensors[i],
|
||||
big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
|
||||
big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
|
||||
)
|
||||
|
||||
self._barrier()
|
||||
@ -1303,7 +1302,7 @@ class DistributedTest:
|
||||
@skip_if_no_gpu
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
|
||||
"Coalescing manager currently tests with NCCL only; internal test flaky"
|
||||
"Coalescing manager currently tests with NCCL only; internal test flaky",
|
||||
)
|
||||
def test_coalescing_manager_async(self):
|
||||
self._barrier()
|
||||
@ -1329,7 +1328,7 @@ class DistributedTest:
|
||||
for i in range(num_colls):
|
||||
self.assertEqual(
|
||||
small_tensors[i],
|
||||
big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
|
||||
big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
|
||||
)
|
||||
|
||||
self._barrier()
|
||||
@ -1585,7 +1584,9 @@ class DistributedTest:
|
||||
backend = dist.get_backend()
|
||||
if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
|
||||
for event_name in [f"{backend}:send", f"{backend}:recv"]:
|
||||
events = get_profiling_event(event_name, prof, dedup_gpu_user_annotation=True)
|
||||
events = get_profiling_event(
|
||||
event_name, prof, dedup_gpu_user_annotation=True
|
||||
)
|
||||
self.assertTrue(events)
|
||||
# Event order is not deterministic, so simply assert their shape
|
||||
# is found in the following list.
|
||||
@ -1595,7 +1596,6 @@ class DistributedTest:
|
||||
for event in events:
|
||||
self.assertTrue(event.input_shapes in expected_shapes)
|
||||
|
||||
|
||||
@skip_if_no_gpu
|
||||
@skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
|
||||
@requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
|
||||
@ -2745,9 +2745,7 @@ class DistributedTest:
|
||||
]
|
||||
_group, group_id, _rank = self._init_global_test()
|
||||
for unsupported_op in unsupported_ops:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "all_reduce does not support"
|
||||
):
|
||||
with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
|
||||
dist.all_reduce(
|
||||
_build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
|
||||
)
|
||||
@ -4373,9 +4371,7 @@ class DistributedTest:
|
||||
self.net2 = nn.Linear(10, 0)
|
||||
|
||||
model = ToyModel().to(self.rank)
|
||||
nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[self.rank]
|
||||
)
|
||||
nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
|
||||
|
||||
@skip_but_pass_in_sandcastle_if(BACKEND == "nccl", "Gloo-only test")
|
||||
def test_ddp_create_graph(self):
|
||||
@ -4521,7 +4517,7 @@ class DistributedTest:
|
||||
models_to_test.append(
|
||||
(torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda())
|
||||
)
|
||||
for (model, inp) in models_to_test:
|
||||
for model, inp in models_to_test:
|
||||
# Enable determinism in cudnn operators
|
||||
with torch.backends.cudnn.flags(
|
||||
enabled=True, deterministic=True, benchmark=False
|
||||
@ -4721,11 +4717,11 @@ class DistributedTest:
|
||||
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
|
||||
model, params_to_ignore
|
||||
)
|
||||
torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[self.rank]
|
||||
)
|
||||
dp_params = torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
|
||||
model, named_params=True
|
||||
torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
|
||||
dp_params = (
|
||||
torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
|
||||
model, named_params=True
|
||||
)
|
||||
)
|
||||
for name, _ in dp_params:
|
||||
self.assertNotEqual(f"module.{params_to_ignore[0]}", name)
|
||||
@ -4734,7 +4730,11 @@ class DistributedTest:
|
||||
# no of parameters.
|
||||
num_ddp_params = len(list(model.parameters())) - 1
|
||||
count = 0
|
||||
dp_params = torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(model, named_params=False)
|
||||
dp_params = (
|
||||
torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
|
||||
model, named_params=False
|
||||
)
|
||||
)
|
||||
for _ in dp_params:
|
||||
count += 1
|
||||
self.assertEqual(count, num_ddp_params)
|
||||
@ -4902,7 +4902,8 @@ class DistributedTest:
|
||||
# Parameters to ignore are in the format {module_name}.{param_name}
|
||||
to_ignore = ["a.weight", "buffer"]
|
||||
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
|
||||
model, to_ignore,
|
||||
model,
|
||||
to_ignore,
|
||||
)
|
||||
mp_config = self._get_fp16_config()
|
||||
net = torch.nn.parallel.DistributedDataParallel(
|
||||
@ -4915,11 +4916,11 @@ class DistributedTest:
|
||||
expected_ignored = len(to_ignore)
|
||||
n_ignored = 0
|
||||
# ignored params should not have _mp_param or _fp_param fields.
|
||||
for (n, p) in itertools.chain(net.named_parameters(), net.named_buffers()):
|
||||
for n, p in itertools.chain(net.named_parameters(), net.named_buffers()):
|
||||
if n in to_ignore:
|
||||
n_ignored += 1
|
||||
self.assertFalse(hasattr(p, '_mp_param'))
|
||||
self.assertFalse(hasattr(p, '_fp_param'))
|
||||
self.assertFalse(hasattr(p, "_mp_param"))
|
||||
self.assertFalse(hasattr(p, "_fp_param"))
|
||||
else:
|
||||
self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
|
||||
self.assertEqual(torch.float32, p._fp_param.dtype)
|
||||
@ -4940,10 +4941,8 @@ class DistributedTest:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.m = torch.nn.Linear(1, 5)
|
||||
self.register_buffer('buffer', torch.randn(1, 2))
|
||||
self.p = torch.nn.Parameter(
|
||||
torch.randn(10, 5), requires_grad=False
|
||||
)
|
||||
self.register_buffer("buffer", torch.randn(1, 2))
|
||||
self.p = torch.nn.Parameter(torch.randn(10, 5), requires_grad=False)
|
||||
|
||||
def forward(self_, x): # noqa: B902
|
||||
params = self_.m.parameters()
|
||||
@ -4978,7 +4977,7 @@ class DistributedTest:
|
||||
for n, param in net.named_parameters():
|
||||
self.assertEqual(param.dtype, torch.float32)
|
||||
if param.grad is None:
|
||||
assert n == 'module.p' # Only param that doesn't require grad
|
||||
assert n == "module.p" # Only param that doesn't require grad
|
||||
else:
|
||||
self.assertEqual(param.grad.dtype, torch.float32)
|
||||
tensor_list = [
|
||||
@ -4994,7 +4993,9 @@ class DistributedTest:
|
||||
net.zero_grad(set_to_none=set_grad_to_none)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(self):
|
||||
def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(
|
||||
self,
|
||||
):
|
||||
self._test_ddp_native_mixed_precision(
|
||||
gradient_as_bucket_view=False,
|
||||
set_grad_to_none=False,
|
||||
@ -5014,7 +5015,9 @@ class DistributedTest:
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(self):
|
||||
def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(
|
||||
self,
|
||||
):
|
||||
self._test_ddp_native_mixed_precision(
|
||||
gradient_as_bucket_view=True, set_grad_to_none=True
|
||||
)
|
||||
@ -6079,7 +6082,9 @@ class DistributedTest:
|
||||
model = copy.deepcopy(BN_NET)
|
||||
model = model.half()
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
model = nn.parallel.DistributedDataParallel(model.cuda(rank), device_ids=[rank])
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model.cuda(rank), device_ids=[rank]
|
||||
)
|
||||
inp = torch.randn(2, 2, dtype=torch.float16, device=torch.device(rank))
|
||||
# Check that forward/backward do not error with dtype mismatch
|
||||
out = model(inp)
|
||||
@ -6858,7 +6863,9 @@ class DistributedTest:
|
||||
loss.backward()
|
||||
|
||||
all_reduce_event_name = f"{dist.get_backend()}:all_reduce"
|
||||
events = get_profiling_event(all_reduce_event_name, prof, dedup_gpu_user_annotation=True)
|
||||
events = get_profiling_event(
|
||||
all_reduce_event_name, prof, dedup_gpu_user_annotation=True
|
||||
)
|
||||
event_count = sum(e.count for e in events)
|
||||
self.assertEqual(event_count, num_iters)
|
||||
for event in events:
|
||||
@ -6866,7 +6873,9 @@ class DistributedTest:
|
||||
self.assertEqual(event.name, all_reduce_event_name)
|
||||
|
||||
broadcast_event_name = f"{dist.get_backend()}:broadcast"
|
||||
broadcast_events = get_profiling_event(broadcast_event_name, prof, dedup_gpu_user_annotation=True)
|
||||
broadcast_events = get_profiling_event(
|
||||
broadcast_event_name, prof, dedup_gpu_user_annotation=True
|
||||
)
|
||||
event_count = sum(e.count for e in broadcast_events)
|
||||
# Broadcast is called during rebuild_buckets
|
||||
self.assertGreaterEqual(event_count, 1)
|
||||
@ -6889,7 +6898,9 @@ class DistributedTest:
|
||||
loss = net(inp).sum()
|
||||
loss.backward()
|
||||
|
||||
events = get_profiling_event(all_reduce_event_name, prof, dedup_gpu_user_annotation=True)
|
||||
events = get_profiling_event(
|
||||
all_reduce_event_name, prof, dedup_gpu_user_annotation=True
|
||||
)
|
||||
self.assertGreaterEqual(len(events), 1)
|
||||
self.assertGreaterEqual(events[0].count, 1)
|
||||
self.assertEqual(events[0].name, all_reduce_event_name)
|
||||
@ -6949,9 +6960,13 @@ class DistributedTest:
|
||||
"""
|
||||
with open(et_file) as f:
|
||||
et = json.load(f)
|
||||
pg_cfg_node = [n for n in et["nodes"] if n["name"] == "## process_group:init ##"]
|
||||
pg_cfg_node = [
|
||||
n for n in et["nodes"] if n["name"] == "## process_group:init ##"
|
||||
]
|
||||
self.assertGreaterEqual(len(pg_cfg_node), 1)
|
||||
nccl_meta_nodes = [n for n in et["nodes"] if n["name"] == "record_param_comms"]
|
||||
nccl_meta_nodes = [
|
||||
n for n in et["nodes"] if n["name"] == "record_param_comms"
|
||||
]
|
||||
self.assertEqual(len(nccl_meta_nodes), 3)
|
||||
per_coll_meta = defaultdict(list)
|
||||
|
||||
@ -6969,7 +6984,7 @@ class DistributedTest:
|
||||
if collname in {"wait"}:
|
||||
continue
|
||||
|
||||
self.assertEqual(attrs["pg_name"], "0") # yes this is a string
|
||||
self.assertEqual(attrs["pg_name"], "0") # yes this is a string
|
||||
self.assertEqual(attrs["pg_desc"], "default_pg")
|
||||
self.assertEqual(attrs["pg_size"], 2)
|
||||
|
||||
@ -6992,7 +7007,6 @@ class DistributedTest:
|
||||
self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}")
|
||||
self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
|
||||
|
||||
|
||||
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
|
||||
@ -7016,7 +7030,7 @@ class DistributedTest:
|
||||
# collect ET in second profiler pass
|
||||
torch_profiler_ctx2 = torch.profiler.profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
execution_trace_observer=et
|
||||
execution_trace_observer=et,
|
||||
)
|
||||
self._test_ddp_profiling(
|
||||
profiler_ctx=torch_profiler_ctx1,
|
||||
@ -7026,7 +7040,6 @@ class DistributedTest:
|
||||
print(f"Execution trace saved at {fp.name}")
|
||||
self._validate_execution_trace_nccl(et_file)
|
||||
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
@ -7404,7 +7417,9 @@ class DistributedTest:
|
||||
for num_early_join_ranks in num_uneven_ranks:
|
||||
for baseline_iter in baseline_num_iters:
|
||||
for offset in iteration_offsets:
|
||||
mapping = dict.fromkeys(range(0, num_early_join_ranks), baseline_iter)
|
||||
mapping = dict.fromkeys(
|
||||
range(0, num_early_join_ranks), baseline_iter
|
||||
)
|
||||
# if num_early_join_ranks > 1, ranks > 0 that will join early
|
||||
# iterate offset//2 more times than rank 0, to test nodes
|
||||
# depleting inputs at different times.
|
||||
@ -7413,11 +7428,14 @@ class DistributedTest:
|
||||
if rank > 0:
|
||||
mapping[rank] += offset // 2
|
||||
mapping.update(
|
||||
dict.fromkeys(range(num_early_join_ranks, dist.get_world_size()), baseline_iter + offset)
|
||||
dict.fromkeys(
|
||||
range(num_early_join_ranks, dist.get_world_size()),
|
||||
baseline_iter + offset,
|
||||
)
|
||||
)
|
||||
iteration_mappings.append(mapping)
|
||||
|
||||
for (test_case, iteration_mapping) in itertools.product(
|
||||
for test_case, iteration_mapping in itertools.product(
|
||||
models_to_test, iteration_mappings
|
||||
):
|
||||
if self.rank == 0:
|
||||
@ -7570,7 +7588,9 @@ class DistributedTest:
|
||||
int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
|
||||
)
|
||||
@with_dist_debug_levels(levels=["DETAIL"])
|
||||
@unittest.skip("Test is failing, see https://github.com/pytorch/pytorch/pull/113620")
|
||||
@unittest.skip(
|
||||
"Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
|
||||
)
|
||||
def test_broadcast_object_list(self):
|
||||
return self._test_broadcast_object_list()
|
||||
|
||||
@ -7605,7 +7625,7 @@ class DistributedTest:
|
||||
|
||||
device_id = self.rank
|
||||
# Ensure the test works for both find_unused_parameter and broadcast_buffer settings.
|
||||
for (find_unused, broadcast_buffers) in itertools.product(
|
||||
for find_unused, broadcast_buffers in itertools.product(
|
||||
[False, True], [False, True]
|
||||
):
|
||||
model = TestModel(self.rank).float().to(device_id)
|
||||
@ -7954,16 +7974,16 @@ class DistributedTest:
|
||||
context = nullcontext
|
||||
|
||||
with context():
|
||||
input = torch.rand((1, ))
|
||||
input = torch.rand((1,))
|
||||
output = model.forward(input)
|
||||
target = torch.rand((1, ))
|
||||
target = torch.rand((1,))
|
||||
|
||||
loss = mse_loss(output, target)
|
||||
loss.backward()
|
||||
|
||||
self.assertTrue(
|
||||
not any(p.grad is None for p in model.parameters()),
|
||||
"Gradients can't be None for any model parameter."
|
||||
"Gradients can't be None for any model parameter.",
|
||||
)
|
||||
grads = torch.cat([p.grad.view(-1) for p in model.parameters()])
|
||||
|
||||
@ -7978,7 +7998,7 @@ class DistributedTest:
|
||||
for g in gathered_grads[1:]:
|
||||
self.assertTrue(
|
||||
torch.allclose(gathered_grads[0], g),
|
||||
"Gradients are not the same for all ranks."
|
||||
"Gradients are not the same for all ranks.",
|
||||
)
|
||||
|
||||
@with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
|
||||
@ -8290,9 +8310,7 @@ class DistributedTest:
|
||||
tensors_sparse, [400], logger=net.logger
|
||||
)
|
||||
else:
|
||||
dist._compute_bucket_assignment_by_size(
|
||||
tensors_sparse, [400]
|
||||
)
|
||||
dist._compute_bucket_assignment_by_size(tensors_sparse, [400])
|
||||
if use_logger:
|
||||
verify_ddp_error_logged(net, expected_err)
|
||||
|
||||
@ -8399,9 +8417,7 @@ class DistributedTest:
|
||||
# Creates network with different sized embedding table on different
|
||||
# ranks. This should throw an error during DDP init.
|
||||
net = EmbeddingNetDifferentParams(self.rank)
|
||||
self._run_test_ddp_model_with_diff_params(
|
||||
net, group_to_use, group_gloo
|
||||
)
|
||||
self._run_test_ddp_model_with_diff_params(net, group_to_use, group_gloo)
|
||||
|
||||
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
@ -8516,7 +8532,7 @@ class DistributedTest:
|
||||
self.assertEqual(local_net.a.weight.grad, saved_a_local_grad)
|
||||
|
||||
# Verify grads are the same
|
||||
for (local_param, dist_param) in zip(
|
||||
for local_param, dist_param in zip(
|
||||
local_net.parameters(), net.parameters()
|
||||
):
|
||||
local_grad = local_param.grad
|
||||
@ -8935,7 +8951,9 @@ class DistributedTest:
|
||||
if ignore_sparse:
|
||||
for module_name, module in model.named_modules():
|
||||
if module == model.sub_module.embedding_net.embedding:
|
||||
for parameter_name, _param in module.named_parameters(recurse=False):
|
||||
for parameter_name, _param in module.named_parameters(
|
||||
recurse=False
|
||||
):
|
||||
fqn = f"{module_name}.{parameter_name}"
|
||||
sparse_embedding_fqns.append(fqn)
|
||||
|
||||
@ -9079,7 +9097,9 @@ class DistributedTest:
|
||||
f"The {BACKEND} backend does not support DistributedDataParallel",
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@unittest.skip("Test is failing, see https://github.com/pytorch/pytorch/pull/113620")
|
||||
@unittest.skip(
|
||||
"Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
|
||||
)
|
||||
def test_ddp_sync_bn_training_vs_eval(self):
|
||||
rank = self.rank
|
||||
torch.cuda.set_device(rank)
|
||||
@ -9228,7 +9248,7 @@ class DistributedTest:
|
||||
loss_static = get_loss(out_static)
|
||||
loss_static.backward()
|
||||
self._model_step(model_static_graph)
|
||||
for (p, p_static) in zip(
|
||||
for p, p_static in zip(
|
||||
model.parameters(), model_static_graph.parameters()
|
||||
):
|
||||
self.assertEqual(p, p_static)
|
||||
@ -9257,7 +9277,7 @@ class DistributedTest:
|
||||
|
||||
model = MyModel().to(self.rank)
|
||||
inp = torch.randn(1, 10, device=self.rank)
|
||||
for (find_unused, static_graph) in itertools.product(
|
||||
for find_unused, static_graph in itertools.product(
|
||||
[True, False], [True, False]
|
||||
):
|
||||
ddp = DistributedDataParallel(
|
||||
@ -9526,7 +9546,6 @@ class DistributedTest:
|
||||
f"The {BACKEND} backend does not support DistributedDataParallel",
|
||||
)
|
||||
def test_ddp_remove_autograd_hooks(self):
|
||||
|
||||
class SimulateError(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
@ -9548,7 +9567,6 @@ class DistributedTest:
|
||||
else:
|
||||
return self.fc1(inp)
|
||||
|
||||
|
||||
# Run with error to trigger backward pass that marks fc1 as being marked
|
||||
# ready. If we don't remove autograd hooks before running below it would
|
||||
# fail on the old autograd hook.
|
||||
@ -9578,9 +9596,10 @@ class DistributedTest:
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
f"The {BACKEND} backend does not support DistributedDataParallel",
|
||||
)
|
||||
@unittest.skip("Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751")
|
||||
@unittest.skip(
|
||||
"Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751"
|
||||
)
|
||||
def test_ddp_has_finalized(self):
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
obj: torch.Tensor
|
||||
@ -9615,10 +9634,16 @@ class DistributedTest:
|
||||
(out1.sum() + out2.sum()).backward()
|
||||
|
||||
if self.rank == 0:
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected to have finished reduction in the prior iteration"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected to have finished reduction in the prior iteration",
|
||||
):
|
||||
ddp._check_reducer_finalized()
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected to have finished reduction in the prior iteration"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected to have finished reduction in the prior iteration",
|
||||
):
|
||||
ddp(input)
|
||||
else:
|
||||
ddp._check_reducer_finalized()
|
||||
@ -9901,14 +9926,11 @@ class DistributedTest:
|
||||
p.grad.data = p.grad / iters
|
||||
|
||||
for p_ddp, p_local in zip(
|
||||
model.parameters(),
|
||||
local_model.parameters()
|
||||
model.parameters(), local_model.parameters()
|
||||
):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
p_ddp.grad, p_local.grad
|
||||
),
|
||||
f"{p_ddp.grad} vs {p_local.grad}"
|
||||
torch.allclose(p_ddp.grad, p_local.grad),
|
||||
f"{p_ddp.grad} vs {p_local.grad}",
|
||||
)
|
||||
|
||||
dist.barrier()
|
||||
@ -10155,11 +10177,8 @@ class DistributedTest:
|
||||
f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
|
||||
)
|
||||
@skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
True, "Skipped due to flakiness"
|
||||
)
|
||||
@skip_but_pass_in_sandcastle_if(True, "Skipped due to flakiness")
|
||||
def test_ddp_hook_pickling_powerSGD(self):
|
||||
|
||||
hook = powerSGD.powerSGD_hook
|
||||
powersgd_state = powerSGD.PowerSGDState(
|
||||
process_group=None,
|
||||
@ -10177,17 +10196,21 @@ class DistributedTest:
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
device_mesh = init_device_mesh("cuda", (world_size,))
|
||||
|
||||
pg = _get_default_group()
|
||||
|
||||
torch.cuda.set_device(self.rank)
|
||||
model = TwoLinLayerNet().cuda()
|
||||
ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_mesh=device_mesh)
|
||||
ddp_model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_mesh=device_mesh
|
||||
)
|
||||
self.assertEqual(ddp_model.device_mesh, device_mesh)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Cannot specify both process_group and device_mesh arguments."
|
||||
RuntimeError,
|
||||
"Cannot specify both process_group and device_mesh arguments.",
|
||||
):
|
||||
ddp_model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, process_group=pg, device_mesh=device_mesh
|
||||
@ -10201,7 +10224,6 @@ class DistributedTest:
|
||||
model, device_mesh=device_mesh
|
||||
)
|
||||
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@require_world_size(2)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
@ -10217,9 +10239,7 @@ class DistributedTest:
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
ddp_static = torch.nn.parallel.DistributedDataParallel(
|
||||
model_clone,
|
||||
device_ids=[self.rank],
|
||||
static_graph=True
|
||||
model_clone, device_ids=[self.rank], static_graph=True
|
||||
)
|
||||
ddp = torch.compile(ddp)
|
||||
ddp_static = torch.compile(ddp_static)
|
||||
@ -10271,9 +10291,11 @@ class DistributedTest:
|
||||
with OpPatcher():
|
||||
ddp(input).sum().backward()
|
||||
|
||||
|
||||
def _test_skip_all_reduce_unused_parameters(
|
||||
self, find_unused_parameters=False, static_graph=False, skip_all_reduce_unused_params=False,
|
||||
self,
|
||||
find_unused_parameters=False,
|
||||
static_graph=False,
|
||||
skip_all_reduce_unused_params=False,
|
||||
):
|
||||
class LargeNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -10312,11 +10334,15 @@ class DistributedTest:
|
||||
test_model_1 = self._test_skip_all_reduce_unused_parameters(
|
||||
find_unused_parameters=True,
|
||||
static_graph=False,
|
||||
skip_all_reduce_unused_params=True
|
||||
skip_all_reduce_unused_params=True,
|
||||
)
|
||||
|
||||
self.assertEqual(base_model._get_ddp_logging_data().get("num_buckets_reduced"), 2)
|
||||
self.assertEqual(test_model_1._get_ddp_logging_data().get("num_buckets_reduced"), 1)
|
||||
self.assertEqual(
|
||||
base_model._get_ddp_logging_data().get("num_buckets_reduced"), 2
|
||||
)
|
||||
self.assertEqual(
|
||||
test_model_1._get_ddp_logging_data().get("num_buckets_reduced"), 1
|
||||
)
|
||||
|
||||
for i, j in zip(base_model.parameters(), test_model_1.parameters()):
|
||||
self.assertEqual(i, j)
|
||||
|
||||
@@ -2,26 +2,26 @@

from contextlib import contextmanager
from datetime import timedelta
from functools import (
partial,
wraps,
)
from functools import partial, wraps

import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d

class MockProcessGroup(dist.ProcessGroup):

class MockProcessGroup(dist.ProcessGroup):
def __init__(self, rank, world):
super().__init__(rank, world)

def getBackendName(self):
return "mock_process_group"


def create_mock_pg(prefix_store, rank, world_size, timeout):
return MockProcessGroup(rank, world_size)

dist.Backend.register_backend('mock_process_group', create_mock_pg)

dist.Backend.register_backend("mock_process_group", create_mock_pg)


def mock_init_dist(rank, world_size):
# !!! WARNING !!!
@@ -38,7 +38,9 @@ def mock_init_dist(rank, world_size):
world_size=world_size,
store=store,
group_name="fake",
timeout=timedelta(seconds=1))
timeout=timedelta(seconds=1),
)


@contextmanager
def with_dist(rank=0, world_size=2):
@@ -51,6 +53,7 @@ def with_dist(rank=0, world_size=2):
finally:
dist.destroy_process_group()


def with_fake_comms(func=None, rank=0, world_size=2):
"""
Function wrapper that inits a fake process group designed for testing.
@@ -63,4 +66,5 @@ def with_fake_comms(func=None, rank=0, world_size=2):
def wrapper(self, *args, **kwargs):
with with_dist(rank, world_size):
func(self, *args, **kwargs)

return wrapper

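For orientation, the decorator reformatted above is normally applied to test methods roughly as in the sketch below; the test class and method are invented for illustration:

# Hypothetical usage sketch of with_fake_comms; names are made up.
import torch.distributed as dist
from torch.testing._internal.distributed.distributed_utils import with_fake_comms


class MyShardingTest:
    @with_fake_comms(rank=1, world_size=4)
    def test_sees_mock_world(self):
        # The wrapper initializes the mock process group, so rank and
        # world-size queries work without any real communication.
        assert dist.get_rank() == 1
        assert dist.get_world_size() == 4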
@@ -1,10 +1,7 @@
# mypy: allow-untyped-defs

import torch.distributed as dist

from torch._C._distributed_c10d import (
FakeProcessGroup,
)
from torch._C._distributed_c10d import FakeProcessGroup


class FakeStore(dist.Store):
@@ -28,4 +25,4 @@ def _create_fake_pg(prefix_store, rank, world_size, timeout):
return FakeProcessGroup(rank, world_size)


dist.Backend.register_backend("fake", _create_fake_pg, devices=['cpu', 'cuda'])
dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda"])

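Since this file is what registers the "fake" backend, a quick reminder of how it is typically exercised (single process, no real peers; the world size below is arbitrary):

# Sketch: spin up a process group on the "fake" backend registered above.
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=8, store=store)

t = torch.ones(4)
dist.all_reduce(t)  # returns immediately; no data is actually exchanged
dist.destroy_process_group()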
@@ -2,13 +2,13 @@

import sys
import threading
import weakref
from dataclasses import dataclass
from typing import Optional, Union
from functools import partial, reduce
from typing import Optional, Union

import torch
import torch.distributed as dist
import weakref
from torch._C._distributed_c10d import (
_create_work_from_future,
AllgatherOptions,
@@ -16,15 +16,16 @@ from torch._C._distributed_c10d import (
AllToAllOptions,
BarrierOptions,
BroadcastOptions,
ReduceOp,
ReduceScatterOptions,
ScatterOptions,
Store,
ReduceOp,
)
from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp
from torch.futures import Future
from torch.utils import _pytree as pytree


"""
|
||||
TODO:
|
||||
Lots of missing collectives.
|
||||
@ -45,6 +46,7 @@ def ret_work(ret):
|
||||
fut.set_result(ret)
|
||||
return _create_work_from_future(fut)
|
||||
|
||||
|
||||
def binop_reduce(tensors, op):
|
||||
res = op(torch.stack(tensors), dim=0)
|
||||
if isinstance(res, torch.Tensor):
|
||||
@ -52,9 +54,11 @@ def binop_reduce(tensors, op):
|
||||
# min/max return a namedtuple
|
||||
return res.values
|
||||
|
||||
|
||||
def bitwise_reduce(tensors, op):
|
||||
return reduce(op, tensors)
|
||||
|
||||
|
||||
_reduce_ops = {
|
||||
ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
|
||||
ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
|
||||
@ -66,6 +70,7 @@ _reduce_ops = {
|
||||
ReduceOp.BXOR: partial(bitwise_reduce, op=torch.bitwise_xor),
|
||||
}
|
||||
|
||||
|
||||
class AllToAll:
|
||||
@torch.no_grad()
|
||||
def work(self, data):
|
||||
@ -76,6 +81,7 @@ class AllToAll:
|
||||
_, input_tensor_list = data[src_rank]
|
||||
output_tensor_list[src_rank].copy_(input_tensor_list[dest_rank])
|
||||
|
||||
|
||||
class AllToAllBase:
|
||||
@torch.no_grad()
|
||||
def work(self, data):
|
||||
@ -83,35 +89,45 @@ class AllToAllBase:
|
||||
for dest_rank in range(world_size):
|
||||
output_buffer, _, output_split_sizes, _ = data[dest_rank]
|
||||
|
||||
output_indexes = self._size_cumsum(output_buffer.size(0), output_split_sizes, world_size)
|
||||
output_indexes = self._size_cumsum(
|
||||
output_buffer.size(0), output_split_sizes, world_size
|
||||
)
|
||||
|
||||
for src_rank in range(world_size):
|
||||
_, input_buffer, _, input_split_sizes = data[src_rank]
|
||||
input_indexes = self._size_cumsum(input_buffer.size(0), input_split_sizes, world_size)
|
||||
|
||||
output_buffer[output_indexes[src_rank]:output_indexes[src_rank + 1]].copy_(
|
||||
input_buffer[input_indexes[dest_rank]:input_indexes[dest_rank + 1]]
|
||||
input_indexes = self._size_cumsum(
|
||||
input_buffer.size(0), input_split_sizes, world_size
|
||||
)
|
||||
|
||||
def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, list[int], None], world_size: int) -> torch.Tensor:
|
||||
output_buffer[
|
||||
output_indexes[src_rank] : output_indexes[src_rank + 1]
|
||||
].copy_(
|
||||
input_buffer[
|
||||
input_indexes[dest_rank] : input_indexes[dest_rank + 1]
|
||||
]
|
||||
)
|
||||
|
||||
def _size_cumsum(
|
||||
self,
|
||||
buf_size: int,
|
||||
sizes: Union[torch.Tensor, list[int], None],
|
||||
world_size: int,
|
||||
) -> torch.Tensor:
|
||||
if sizes is None or len(sizes) == 0:
|
||||
sizes = torch.full(
|
||||
(world_size,), buf_size // world_size, dtype=torch.int64
|
||||
)
|
||||
sizes = torch.full((world_size,), buf_size // world_size, dtype=torch.int64)
|
||||
if not isinstance(sizes, torch.Tensor):
|
||||
sizes = torch.tensor(sizes, dtype=torch.int64)
|
||||
assert sizes.dtype == torch.int64
|
||||
sizes = torch.cumsum(
|
||||
torch.cat(
|
||||
(
|
||||
torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes
|
||||
),
|
||||
dim=0
|
||||
(torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes),
|
||||
dim=0,
|
||||
),
|
||||
dim=0
|
||||
dim=0,
|
||||
)
|
||||
return sizes
|
||||
|
||||
|
||||
class AllReduce:
|
||||
def __init__(self, op):
|
||||
if op.op not in _reduce_ops:
|
||||
@ -127,7 +143,9 @@ class AllReduce:
|
||||
rank_0_device = data[0][i].device
|
||||
# collect all data to the list and make them
|
||||
# all on rank 0 device
|
||||
tensors = [data[src_rank][i].to(rank_0_device) for src_rank in range(0, len(data))]
|
||||
tensors = [
|
||||
data[src_rank][i].to(rank_0_device) for src_rank in range(0, len(data))
|
||||
]
|
||||
|
||||
# now mimic reduce across all ranks
|
||||
res = _reduce_ops[self.op](tensors)
|
||||
@ -186,6 +204,7 @@ class Gather:
|
||||
dest_tensor = out_tensor_list[rank]
|
||||
dest_tensor.copy_(src_in_tensor_list[0])
|
||||
|
||||
|
||||
class ReduceScatter:
|
||||
def __init__(self, op):
|
||||
if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG:
|
||||
@ -254,7 +273,8 @@ class Collective:
|
||||
|
||||
if rank == 0:
|
||||
self._start_cond.wait_for(
|
||||
lambda: self._count == self._world_size or self._pg._terminate.is_set()
|
||||
lambda: self._count == self._world_size
|
||||
or self._pg._terminate.is_set()
|
||||
)
|
||||
# SystemExit is not a subclass of Exception but BaseException
|
||||
# and can be distinguished from normal exception raised from program errors
|
||||
@ -265,7 +285,9 @@ class Collective:
|
||||
with self._done_cond:
|
||||
# wait for rank 0 to finish
|
||||
if rank > 0:
|
||||
self._done_cond.wait_for(lambda: self._done or self._pg._terminate.is_set())
|
||||
self._done_cond.wait_for(
|
||||
lambda: self._done or self._pg._terminate.is_set()
|
||||
)
|
||||
if self._pg._terminate.is_set():
|
||||
sys.exit("Test termination event occurs.")
|
||||
else:
|
||||
@ -287,14 +309,19 @@ class ProcessLocalGroup(dist.ProcessGroup):
|
||||
with cls._coll_lock:
|
||||
# pg_name is unique, we use that to record the mapping between pg and collective
|
||||
if pg.pg_name not in cls._cur_coll_on_pgs:
|
||||
cls._cur_coll_on_pgs[pg.pg_name] = Collective(pg.size(), collective, cls)
|
||||
cls._cur_coll_on_pgs[pg.pg_name] = Collective(
|
||||
pg.size(), collective, cls
|
||||
)
|
||||
return cls._cur_coll_on_pgs[pg.pg_name]
|
||||
|
||||
@classmethod
|
||||
def _end_coll(cls, collective, pg):
|
||||
# This is racily called by all ranks, so only one will work
|
||||
with cls._coll_lock:
|
||||
if pg.pg_name in cls._cur_coll_on_pgs and cls._cur_coll_on_pgs[pg.pg_name] == collective:
|
||||
if (
|
||||
pg.pg_name in cls._cur_coll_on_pgs
|
||||
and cls._cur_coll_on_pgs[pg.pg_name] == collective
|
||||
):
|
||||
cls._cur_coll_on_pgs.pop(pg.pg_name)
|
||||
|
||||
@classmethod
|
||||
@ -318,10 +345,13 @@ class ProcessLocalGroup(dist.ProcessGroup):
|
||||
input_buffer: torch.Tensor,
|
||||
output_split_sizes: Optional[list[int]],
|
||||
input_split_sizes: Optional[list[int]],
|
||||
opts=AllToAllOptions()
|
||||
opts=AllToAllOptions(),
|
||||
) -> torch.Tensor:
|
||||
coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
|
||||
res = coll.join(self._rank, (output_buffer, input_buffer, output_split_sizes, input_split_sizes))
|
||||
res = coll.join(
|
||||
self._rank,
|
||||
(output_buffer, input_buffer, output_split_sizes, input_split_sizes),
|
||||
)
|
||||
ProcessLocalGroup._end_coll(coll, self)
|
||||
return res
|
||||
|
||||
@ -380,21 +410,26 @@ class ProcessLocalGroup(dist.ProcessGroup):
|
||||
ProcessLocalGroup._end_coll(coll, self)
|
||||
return res
|
||||
|
||||
def _reduce_scatter_base(self, output_tensor, input_tensor, opts=ReduceScatterOptions()):
|
||||
def _reduce_scatter_base(
|
||||
self, output_tensor, input_tensor, opts=ReduceScatterOptions()
|
||||
):
|
||||
tensor_list = list(torch.chunk(input_tensor, self._world_size))
|
||||
return self.reduce_scatter([output_tensor], [tensor_list], opts)
|
||||
|
||||
def reduce_scatter_tensor_coalesced(self, output_tensors, input_tensors, opts=ReduceScatterOptions()):
|
||||
def reduce_scatter_tensor_coalesced(
|
||||
self, output_tensors, input_tensors, opts=ReduceScatterOptions()
|
||||
):
|
||||
works = [
|
||||
self._reduce_scatter_base(output_tensor, input_tensor, opts)
|
||||
for output_tensor, input_tensor
|
||||
in zip(output_tensors, input_tensors)
|
||||
for output_tensor, input_tensor in zip(output_tensors, input_tensors)
|
||||
]
|
||||
for work in works[:-1]:
|
||||
work.wait()
|
||||
return works[-1]
|
||||
|
||||
def allgather_into_tensor_coalesced(self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()):
|
||||
def allgather_into_tensor_coalesced(
|
||||
self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()
|
||||
):
|
||||
res = None
|
||||
for o_t, i_t in zip(output_tensor_list, input_tensor_list):
|
||||
res = self._allgather_base(o_t, i_t)
|
||||
@ -470,7 +505,9 @@ class ThreadLocalWorld:
|
||||
|
||||
def _get_world(self) -> WorldData:
|
||||
if not hasattr(ThreadLocalWorld._world, "world"):
|
||||
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {})
|
||||
ThreadLocalWorld._world.world = WorldData(
|
||||
None, {}, {}, {}, {}, 0, {}, {}, {}
|
||||
)
|
||||
return ThreadLocalWorld._world.world
|
||||
|
||||
@property
|
||||
|
||||
@ -5,11 +5,13 @@ import enum
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
import torch.testing._internal.dist_utils as dist_utils
|
||||
from torch import Tensor, nn
|
||||
from torch import nn, Tensor
|
||||
from torch._jit_internal import Future
|
||||
from torch.distributed.nn import RemoteModule
|
||||
from torch.distributed.nn.api.remote_module import _REMOTE_MODULE_PICKLED_ATTRIBUTES
|
||||
from torch.distributed.nn.api.remote_module import _RemoteModule
|
||||
from torch.distributed.nn.api.remote_module import (
|
||||
_REMOTE_MODULE_PICKLED_ATTRIBUTES,
|
||||
_RemoteModule,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM
|
||||
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
|
||||
@ -35,16 +37,19 @@ def remote_module_attributes(remote_module):
|
||||
def remote_forward(remote_module, args):
|
||||
return remote_module.forward(*args)
|
||||
|
||||
|
||||
# RPC handler for running forward_async on the destination worker.
|
||||
def remote_forward_async(remote_module, args):
|
||||
# Since future cannot be pickled and sent over the RPC layer,
|
||||
# have to wait and behave just like ``forward_sync``.
|
||||
return remote_module.forward_async(*args).wait()
|
||||
|
||||
|
||||
# RPC handler for getting training mode on the destination worker.
|
||||
def get_remote_training_arg(module_rref):
|
||||
return module_rref.local_value().training
|
||||
|
||||
|
||||
class ModuleCreationMode(enum.Enum):
|
||||
MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
|
||||
MODULE_CTOR = "module_ctor"
|
||||
@ -147,7 +152,6 @@ class RemoteModuleTest(CommonRemoteModuleTest):
|
||||
):
|
||||
RemoteModule(remote_device, BadModule, args, kwargs).forward()
|
||||
|
||||
|
||||
@dist_utils.dist_init
|
||||
def test_forward_async(self):
|
||||
if self.rank != 0:
|
||||
@ -269,11 +273,19 @@ class RemoteModuleTest(CommonRemoteModuleTest):
|
||||
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
|
||||
):
|
||||
remote_module.train()
|
||||
ret1 = rpc.rpc_sync(dst_worker_name, get_remote_training_arg, args=(remote_module.get_module_rref(),))
|
||||
ret1 = rpc.rpc_sync(
|
||||
dst_worker_name,
|
||||
get_remote_training_arg,
|
||||
args=(remote_module.get_module_rref(),),
|
||||
)
|
||||
self.assertEqual(ret1, True)
|
||||
|
||||
remote_module.eval()
|
||||
ret2 = rpc.rpc_sync(dst_worker_name, get_remote_training_arg, args=(remote_module.get_module_rref(),))
|
||||
ret2 = rpc.rpc_sync(
|
||||
dst_worker_name,
|
||||
get_remote_training_arg,
|
||||
args=(remote_module.get_module_rref(),),
|
||||
)
|
||||
self.assertEqual(ret2, False)
|
||||
|
||||
@dist_utils.dist_init
|
||||
@ -466,7 +478,9 @@ class RemoteModuleTest(CommonRemoteModuleTest):
|
||||
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
|
||||
):
|
||||
with TemporaryFileName() as fname:
|
||||
with self.assertRaisesRegex(torch.jit.Error, "can only be pickled when using RPC"):
|
||||
with self.assertRaisesRegex(
|
||||
torch.jit.Error, "can only be pickled when using RPC"
|
||||
):
|
||||
torch.save(remote_module, fname)
|
||||
|
||||
|
||||
@ -556,9 +570,7 @@ class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest):
|
||||
)
|
||||
|
||||
args = (torch.ones(1), 2, "3")
|
||||
ret1 = rpc.rpc_sync(
|
||||
dst_worker1_name, remote_forward, (remote_module, args)
|
||||
)
|
||||
ret1 = rpc.rpc_sync(dst_worker1_name, remote_forward, (remote_module, args))
|
||||
ret2 = rpc.rpc_sync(
|
||||
dst_worker2_name, remote_forward, (remote_module2, args)
|
||||
)
|
||||
@ -613,15 +625,15 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
|
||||
]
|
||||
|
||||
if TEST_WITH_ROCM:
|
||||
errorString = (r"HIP error: invalid device ordinal\n"
|
||||
r"HIP kernel errors might be asynchronously reported at some other API call, "
|
||||
r"so the stacktrace below might be incorrect.\n"
|
||||
r"For debugging consider passing AMD_SERIALIZE_KERNEL=3")
|
||||
errorString = (
|
||||
r"HIP error: invalid device ordinal\n"
|
||||
r"HIP kernel errors might be asynchronously reported at some other API call, "
|
||||
r"so the stacktrace below might be incorrect.\n"
|
||||
r"For debugging consider passing AMD_SERIALIZE_KERNEL=3"
|
||||
)
|
||||
else:
|
||||
errorString = r"CUDA error: invalid device ordinal"
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, errorString
|
||||
):
|
||||
with self.assertRaisesRegex(RuntimeError, errorString):
|
||||
[
|
||||
m.forward()
|
||||
for m in self._create_remote_module_iter(
|
||||
|
||||
@ -1,21 +1,26 @@
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
import random
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.autograd as dist_autograd
|
||||
import torch.distributed.rpc as rpc
|
||||
import torch.nn as nn
|
||||
import torch.testing._internal.dist_utils
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.distributed.rpc import RRef
|
||||
from torch.testing._internal.common_utils import IS_MACOS, skip_but_pass_in_sandcastle_if
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_MACOS,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
)
|
||||
from torch.testing._internal.dist_utils import (
|
||||
dist_init,
|
||||
initialize_pg,
|
||||
@ -25,7 +30,6 @@ from torch.testing._internal.dist_utils import (
|
||||
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
|
||||
RpcAgentTestFixture,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
|
||||
|
||||
# Right now we test up to 3-layer nested rpc calls.
|
||||
@ -41,6 +45,7 @@ known_context_ids = set()
|
||||
|
||||
requires_grad_tensor = torch.ones(3, 3, requires_grad=True)
|
||||
|
||||
|
||||
# Send rpc done info and context_id to
|
||||
# dst_rank = (self.rank + rank_distance) % self.world_size
|
||||
# we don't need a lock here since the GIL is held while executing remote
|
||||
@ -62,6 +67,7 @@ def _check_rpc_done(rank_distance):
|
||||
def _torch_ones(sizes, requires_grad=False):
|
||||
return torch.ones(sizes, requires_grad=requires_grad)
|
||||
|
||||
|
||||
# This method must be called on the rref owner, and verifies that the grad of
|
||||
# rref tensor equals to the given grad.
|
||||
def _compare_owner_value(context_id, rref, grad):
|
||||
@ -175,6 +181,7 @@ def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
|
||||
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
|
||||
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
|
||||
|
||||
|
||||
# This function is the same as _run_trainer, except rpc calls torchscript
|
||||
# function "my_script_ref_add" instead of python function "my_rref_add"
|
||||
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
|
||||
@ -231,9 +238,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
raise ValueError(f"Unrecognized ExecMode {exec_mode}")
|
||||
|
||||
def _exec_func(self, exec_mode, method, *args):
|
||||
return self._exec_func_with_dst(
|
||||
self._next_rank(), exec_mode, method, *args
|
||||
)
|
||||
return self._exec_func_with_dst(self._next_rank(), exec_mode, method, *args)
|
||||
|
||||
def _next_rank(self):
|
||||
if hasattr(self, "dst_rank"):
|
||||
@ -286,15 +291,11 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
if ExecMode.RPC_SYNC == exec_mode:
|
||||
ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
|
||||
elif ExecMode.REMOTE == exec_mode:
|
||||
ret = rpc.remote(
|
||||
worker_name(dst_rank), fn, args=(t1, t2)
|
||||
).to_here()
|
||||
ret = rpc.remote(worker_name(dst_rank), fn, args=(t1, t2)).to_here()
|
||||
else:
|
||||
raise ValueError(f"Unrecognized ExecMode {exec_mode}")
|
||||
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
|
||||
|
||||
# Verify graph for current context id.
|
||||
ctx = dist_autograd._current_context()
|
||||
@ -498,19 +499,13 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
t1 = torch.ones(3, 3, requires_grad=False)
|
||||
t2 = torch.zeros(3, 3, requires_grad=False)
|
||||
if ExecMode.RPC_SYNC == exec_mode:
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), torch.add, args=(t1, t2)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
|
||||
elif ExecMode.REMOTE == exec_mode:
|
||||
rpc.remote(
|
||||
worker_name(dst_rank), torch.add, args=(t1, t2)
|
||||
).to_here()
|
||||
rpc.remote(worker_name(dst_rank), torch.add, args=(t1, t2)).to_here()
|
||||
else:
|
||||
raise ValueError(f"Unrecognized ExecMode {exec_mode}")
|
||||
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
|
||||
|
||||
ctx = dist_autograd._current_context()
|
||||
send_functions = ctx._send_functions()
|
||||
@ -541,9 +536,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
tensors.append(tensor)
|
||||
dst_rank = self._next_rank()
|
||||
if ExecMode.RPC_SYNC == exec_mode:
|
||||
ret = rpc.rpc_sync(
|
||||
worker_name(dst_rank), torch.stack, args=(tensors,)
|
||||
)
|
||||
ret = rpc.rpc_sync(worker_name(dst_rank), torch.stack, args=(tensors,))
|
||||
elif ExecMode.REMOTE == exec_mode:
|
||||
ret = rpc.remote(
|
||||
worker_name(dst_rank), torch.stack, args=(tensors,)
|
||||
@ -554,7 +547,9 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
self.assertEqual(torch.stack(tensors), ret)
|
||||
|
||||
# Verify appropriate tensors have been attached the autograd graph.
|
||||
next_funcs = next(iter(dist_autograd._current_context()._send_functions().values())).next_functions
|
||||
next_funcs = next(
|
||||
iter(dist_autograd._current_context()._send_functions().values())
|
||||
).next_functions
|
||||
for i in range(len(next_funcs)):
|
||||
self.assertEqual(
|
||||
"torch::autograd::AccumulateGrad", next_funcs[i][0].name()
|
||||
@ -585,9 +580,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
with dist_autograd.context() as context_id:
|
||||
for dst_rank in dst_ranks:
|
||||
rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
|
||||
if nested:
|
||||
rpc.rpc_sync(
|
||||
worker_name(nested_dst_rank),
|
||||
@ -607,9 +600,8 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
def _backward_no_grad_on_tensor(self, t1, t2, sparse):
|
||||
with dist_autograd.context() as context_id:
|
||||
loss = rpc.rpc_sync(
|
||||
worker_name(self._next_rank()),
|
||||
torch.add,
|
||||
args=(t1, t2))
|
||||
worker_name(self._next_rank()), torch.add, args=(t1, t2)
|
||||
)
|
||||
if sparse:
|
||||
loss = torch.sparse.sum(loss)
|
||||
else:
|
||||
@ -650,11 +642,19 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
with dist_autograd.context() as context_id:
|
||||
if sparse:
|
||||
rref_t1 = rpc.remote(
|
||||
rref_owner, build_sparse_tensor, args=(False, True,)
|
||||
rref_owner,
|
||||
build_sparse_tensor,
|
||||
args=(
|
||||
False,
|
||||
True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
rref_t1 = rpc.remote(
|
||||
rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
|
||||
rref_owner,
|
||||
_torch_ones,
|
||||
args=((3, 3),),
|
||||
kwargs={"requires_grad": True},
|
||||
)
|
||||
if callee == rref_owner:
|
||||
rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
|
||||
@ -707,10 +707,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
local_ret.sum().backward()
|
||||
|
||||
# create rref on self
|
||||
rref_t1 = rpc.remote(
|
||||
worker_name(self.rank),
|
||||
create_ref_fn,
|
||||
args=())
|
||||
rref_t1 = rpc.remote(worker_name(self.rank), create_ref_fn, args=())
|
||||
|
||||
# kick off forward and backward pass on three other workers (trainers)
|
||||
rank_diffs = [1, 2, 3]
|
||||
@ -719,7 +716,8 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
worker_name((self.rank + rank_diff) % self.world_size),
|
||||
trainer_fn,
|
||||
args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
|
||||
) for rank_diff in rank_diffs
|
||||
)
|
||||
for rank_diff in rank_diffs
|
||||
]
|
||||
|
||||
# check if the trainers have done with their backward pass
|
||||
@ -877,9 +875,8 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
def _multiple_backward(self, t1, t2, sparse):
|
||||
with dist_autograd.context() as context_id:
|
||||
loss = rpc.rpc_sync(
|
||||
worker_name(self._next_rank()),
|
||||
torch.add,
|
||||
args=(t1, t2))
|
||||
worker_name(self._next_rank()), torch.add, args=(t1, t2)
|
||||
)
|
||||
if sparse:
|
||||
loss = torch.sparse.sum(loss)
|
||||
else:
|
||||
@ -924,9 +921,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
def _backward_simple(self, dst, t1, t2, local_grads, sparse):
|
||||
for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
|
||||
with dist_autograd.context() as context_id:
|
||||
ret = self._exec_func_with_dst(
|
||||
dst, exec_mode, torch.add, t1, t2
|
||||
)
|
||||
ret = self._exec_func_with_dst(dst, exec_mode, torch.add, t1, t2)
|
||||
if sparse:
|
||||
loss = torch.sparse.sum(ret)
|
||||
else:
|
||||
@ -1005,7 +1000,6 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
|
||||
|
||||
|
||||
class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
|
||||
# Sparse tests only work with TensorPipeAgent.
|
||||
@dist_init
|
||||
def test_graph_for_builtin_call_sparse(self):
|
||||
@ -1081,7 +1075,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
self._backward_no_grad_on_tensor(
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1091,7 +1085,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
None,
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1101,7 +1095,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
None,
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1115,7 +1109,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
None,
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1128,7 +1122,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
None,
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1141,16 +1135,12 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
None,
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
def test_trainer_ps_sparse(self):
|
||||
self._test_trainer_ps(
|
||||
build_sparse_tensor,
|
||||
_run_trainer,
|
||||
True
|
||||
)
|
||||
self._test_trainer_ps(build_sparse_tensor, _run_trainer, True)
|
||||
|
||||
@dist_init
|
||||
def test_backward_multiple_round_trips_sparse(self):
|
||||
@ -1161,7 +1151,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
build_sparse_tensor(requires_grad=False),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
None,
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1169,7 +1159,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
self._backward_different_dtypes(
|
||||
build_sparse_tensor(requires_grad=True, dtype=torch.float32),
|
||||
build_sparse_tensor(requires_grad=True, dtype=torch.float64),
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1177,7 +1167,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
self._backward_simple_python_udf(
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1185,7 +1175,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
self._backward_simple_script_call(
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1193,7 +1183,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
self._nested_backward_accumulate_grads(
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1202,7 +1192,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
self._backwards_nested_python_udf(
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1210,7 +1200,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
self._mixed_requires_grad(
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=False),
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1218,7 +1208,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
|
||||
self._multiple_backward(
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
build_sparse_tensor(requires_grad=True),
|
||||
True
|
||||
True,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1348,17 +1338,13 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
if ExecMode.RPC_SYNC == exec_mode:
|
||||
ret = rpc.rpc_sync(worker_name(dst_rank), ret_requires_grad)
|
||||
elif ExecMode.REMOTE == exec_mode:
|
||||
ret = rpc.remote(
|
||||
worker_name(dst_rank), ret_requires_grad
|
||||
).to_here()
|
||||
ret = rpc.remote(worker_name(dst_rank), ret_requires_grad).to_here()
|
||||
else:
|
||||
raise ValueError(f"Unrecognized ExecMode {exec_mode}")
|
||||
|
||||
dist_autograd.backward(context_id, [ret.sum()])
|
||||
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
|
||||
|
||||
# Wait for the prev rank to be done with rpc.
|
||||
self._check_rpc_done(1)
|
||||
@ -1421,9 +1407,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
t2 = torch.zeros(3, 3, requires_grad=False)
|
||||
for dst_rank in dst_ranks:
|
||||
rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
|
||||
# all worker_ids in dst_ranks should be recorded.
|
||||
ctx = dist_autograd._current_context()
|
||||
worker_ids = ctx._known_worker_ids()
|
||||
@ -1433,12 +1417,8 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
t1.requires_grad = True
|
||||
t2.requires_grad = True
|
||||
for dst_rank in dst_ranks:
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), torch.add, args=(t1, t2)
|
||||
)
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
|
||||
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
|
||||
# all worker_ids in dst_ranks should be recorded.
|
||||
worker_ids = ctx._known_worker_ids()
|
||||
self.assertEqual(worker_ids, dst_ranks)
|
||||
@ -1448,7 +1428,9 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
with dist_autograd.context() as context_id:
|
||||
t1 = torch.rand(3, 3, requires_grad=True)
|
||||
t2 = torch.rand(3, 3, requires_grad=True)
|
||||
loss = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2)).sum()
|
||||
loss = rpc.rpc_sync(
|
||||
worker_name(self._next_rank()), torch.add, args=(t1, t2)
|
||||
).sum()
|
||||
with torch.autograd.profiler.profile() as p:
|
||||
dist_autograd.backward(context_id, [loss])
|
||||
|
||||
@ -1485,7 +1467,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
self._backward_no_grad_on_tensor(
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1495,7 +1477,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
None,
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1505,7 +1487,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
None,
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1518,7 +1500,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
None,
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1532,7 +1514,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
None,
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1545,16 +1527,12 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
None,
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
def test_trainer_ps(self):
|
||||
self._test_trainer_ps(
|
||||
create_tensor,
|
||||
_run_trainer,
|
||||
False
|
||||
)
|
||||
self._test_trainer_ps(create_tensor, _run_trainer, False)
|
||||
|
||||
@dist_init
|
||||
def test_trainer_ps_torchscript_functions(self):
|
||||
@ -1563,9 +1541,12 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
# ref as arg is passed to pybind boundary, and the ref is not garbage
|
||||
# collected by python when calling shutdown()
|
||||
import torch.distributed.rpc.api as api
|
||||
|
||||
api._ignore_rref_leak = True
|
||||
|
||||
self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False)
|
||||
self._test_trainer_ps(
|
||||
create_torchscript_tensor, _run_trainer_torchscript, False
|
||||
)
|
||||
|
||||
@dist_init
|
||||
def test_backward_multiple_round_trips(self):
|
||||
@ -1576,7 +1557,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
torch.rand((3, 3)),
|
||||
torch.rand((3, 3), requires_grad=True),
|
||||
None,
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1646,9 +1627,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
|
||||
# We don't use the result of an RPC function, as a result the
|
||||
# backward pass would hang in the "FAST" mode.
|
||||
rpc.rpc_sync(
|
||||
worker_name(self._next_rank()), torch.add, args=(t1, t2)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
|
||||
|
||||
val = torch.mul(t1, t2)
|
||||
|
||||
@ -1679,9 +1658,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
|
||||
# Run multiple round trips across different nodes and verify the
|
||||
# original node receives an error thrown on a node deep in the chain.
|
||||
val = rpc.rpc_sync(
|
||||
worker_name(self._next_rank()), torch.add, args=(t2, t3)
|
||||
)
|
||||
val = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t2, t3))
|
||||
val = rpc.rpc_sync(
|
||||
worker_name(self._next_rank()), torch.mul, args=(val, t2)
|
||||
)
|
||||
@ -1710,9 +1687,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
with dist_autograd.context() as context_id:
|
||||
t1 = torch.rand((3, 3), requires_grad=True)
|
||||
t2 = torch.rand((3, 3), requires_grad=True)
|
||||
res = rpc.rpc_sync(
|
||||
worker_name(self._next_rank()), torch.add, args=(t1, t2)
|
||||
)
|
||||
res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
|
||||
|
||||
# Wait for all RPCs to be done.
|
||||
dist.barrier()
|
||||
@ -1745,9 +1720,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
RuntimeError,
|
||||
f"Could not find autograd context with id: {context_id}",
|
||||
):
|
||||
res = rpc.rpc_sync(
|
||||
worker_name(self._next_rank()), torch.add, args=(t1, t2)
|
||||
)
|
||||
res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
|
||||
dist_autograd.backward(context_id, [res.sum()])
|
||||
|
||||
@dist_init
|
||||
@ -1768,7 +1741,6 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
@dist_init
|
||||
def test_backward_invalid_args(self):
|
||||
with dist_autograd.context() as context_id:
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
|
||||
dist_autograd.backward(context_id, None)
|
||||
|
||||
@ -1817,7 +1789,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
self._backward_different_dtypes(
|
||||
torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
|
||||
torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1825,7 +1797,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
self._backward_simple_python_udf(
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -1833,7 +1805,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
self._backward_simple_script_call(
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -1934,10 +1906,13 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
|
||||
# Mark rank 0 is done in the store, since the RPC framework on
|
||||
# some nodes might be broken at this point.
|
||||
store.set('test_backward_node_failure_python_udf_rank0_done', "True")
|
||||
store.set("test_backward_node_failure_python_udf_rank0_done", "True")
|
||||
else:
|
||||
# Wait for backward to finish on rank 0.
|
||||
store.wait(['test_backward_node_failure_python_udf_rank0_done'], timedelta(seconds=10))
|
||||
store.wait(
|
||||
["test_backward_node_failure_python_udf_rank0_done"],
|
||||
timedelta(seconds=10),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _nested_python_udf(t1, t2, dst):
|
||||
@ -1952,7 +1927,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
self._backwards_nested_python_udf(
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
_test_clean_context_backward_context_id = None
|
||||
@ -2063,7 +2038,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
self._mixed_requires_grad(
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
torch.rand(3, 3, requires_grad=False),
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
class TestDebugInfoFunc(Function):
|
||||
@ -2210,7 +2185,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
self._nested_backward_accumulate_grads(
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init
|
||||
@ -2218,7 +2193,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
self._multiple_backward(
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
torch.rand(3, 3, requires_grad=True),
|
||||
False
|
||||
False,
|
||||
)
|
||||
|
||||
@dist_init(clean_shutdown=False)
|
||||
@ -2228,16 +2203,21 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
t2 = torch.rand((3, 3), requires_grad=True)
|
||||
with dist_autograd.context() as context_id:
|
||||
loss = rpc.rpc_sync(
|
||||
f'worker{self._next_rank()}',
|
||||
f"worker{self._next_rank()}",
|
||||
DistAutogradTest._python_udf_with_backward_error,
|
||||
args=(t1, t2)).sum()
|
||||
args=(t1, t2),
|
||||
).sum()
|
||||
|
||||
try:
|
||||
# Run backward in a loop multiple times.
|
||||
for i in range(100):
|
||||
if i < 50:
|
||||
with self.assertRaisesRegex(RuntimeError, "Simulate error on backward pass"):
|
||||
dist_autograd.backward(context_id, [loss], retain_graph=True)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Simulate error on backward pass"
|
||||
):
|
||||
dist_autograd.backward(
|
||||
context_id, [loss], retain_graph=True
|
||||
)
|
||||
elif i > 50:
|
||||
# Recovered from error.
|
||||
dist_autograd.backward(context_id, [loss], retain_graph=True)
|
||||
@ -2270,9 +2250,10 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
|
||||
@dist_init
|
||||
def test_no_grad_copy(self):
|
||||
'''
|
||||
"""
|
||||
Similar to test in test_autograd.py.
|
||||
'''
|
||||
"""
|
||||
|
||||
# create autograd function that saves grad pointer as class static
|
||||
class MyFunc(Function):
|
||||
static_grad_ptr = None
|
||||
@ -2302,7 +2283,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
@staticmethod
|
||||
def forward(ctx, inp1):
|
||||
ctx.size = inp1.size()
|
||||
return torch.tensor([1.])
|
||||
return torch.tensor([1.0])
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
@ -2312,7 +2293,9 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
b = torch.randn(5, 6, requires_grad=True)
|
||||
# non-contiguous grad should be copied
|
||||
with dist_autograd.context() as context_id:
|
||||
dist_autograd.backward(context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))])
|
||||
dist_autograd.backward(
|
||||
context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))]
|
||||
)
|
||||
grads = dist_autograd.get_gradients(context_id)
|
||||
self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
|
||||
self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
|
||||
@ -2516,9 +2499,7 @@ class DistAutogradTest(CommonDistAutogradTest):
|
||||
dist_autograd.backward(context_id, [loss])
|
||||
self.assertTrue(
|
||||
rpc.rpc_sync(
|
||||
dst,
|
||||
_compare_owner_value,
|
||||
args=(context_id, rref, t3.grad)
|
||||
dst, _compare_owner_value, args=(context_id, rref, t3.grad)
|
||||
)
|
||||
)
|
||||
|
||||
@ -2602,9 +2583,7 @@ class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
|
||||
with dist_autograd.context() as context_id:
|
||||
for dst_rank in dst_ranks:
|
||||
rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
|
||||
rpc.rpc_sync(
|
||||
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
|
||||
)
|
||||
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
|
||||
# the thread's context id should be cleaned up
|
||||
with self.assertRaises(RuntimeError):
|
||||
dist_autograd._retrieve_context(context_id)
|
||||
@ -2625,7 +2604,9 @@ class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
|
||||
|
||||
@dist_init
|
||||
def test_verify_backend_options(self):
|
||||
self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE)
|
||||
self.assertEqual(
|
||||
self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
|
||||
)
|
||||
self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
|
||||
self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
|
||||
self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
|
||||
@ -2645,7 +2626,6 @@ class WrapperModule(nn.Module):
|
||||
|
||||
|
||||
class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_device_maps_backward_pass(self):
|
||||
options = self.rpc_backend_options
|
||||
@ -2690,7 +2670,6 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_dist_autograd_sync_streams(self):
|
||||
|
||||
options = self.rpc_backend_options
|
||||
dst = worker_name((self.rank + 1) % self.world_size)
|
||||
|
||||
@ -2747,10 +2726,9 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
|
||||
local_layers = [l.to(0) for l in layers]
|
||||
remote_layers = [
|
||||
rpc.remote(
|
||||
worker_name(rank),
|
||||
WrapperModule,
|
||||
args=(layers[rank - 1], rank)
|
||||
) for rank in range(1, self.world_size)
|
||||
worker_name(rank), WrapperModule, args=(layers[rank - 1], rank)
|
||||
)
|
||||
for rank in range(1, self.world_size)
|
||||
]
|
||||
|
||||
x = torch.randn(5000, 2000).to(0)
|
||||
|
||||
@ -204,7 +204,9 @@ class DistOptimizerTest(RpcAgentTestFixture):
self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True)
self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True)
self._test_dist_optim_base(optim.SGD, lr=0.05)
self._test_dist_optim_base(optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True)
self._test_dist_optim_base(
optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True
)
self._test_dist_optim_base(optim.Adadelta, rho=0.95)
self._test_dist_optim_base(optim.RMSprop, lr=0.05)
self._test_dist_optim_base(optim.Adamax, lr=0.05)

@ -12,12 +12,11 @@ import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch import optim

from torch.testing._internal.dist_utils import (
dist_init,
worker_name,
from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture


batch_size = 20
in_features = 100
@ -30,7 +29,6 @@ def timed_log(text):


class BatchUpdateParameterServer:

def __init__(self, batch_update_size):
self.model = nn.Linear(in_features, out_features)
self.lock = threading.Lock()
@ -54,7 +52,9 @@ class BatchUpdateParameterServer:
else:
p.grad += g
with self.lock:
timed_log(f"PS got {self.curr_update_size}/{self.batch_update_size} updates")
timed_log(
f"PS got {self.curr_update_size}/{self.batch_update_size} updates"
)
self.curr_update_size += 1
fut = self.future_model

@ -72,7 +72,6 @@ class BatchUpdateParameterServer:


class Trainer:

def __init__(self, ps_rref):
self.ps_rref = ps_rref
self.loss_fn = nn.L1Loss()
@ -107,18 +106,19 @@ def run_ps(trainers):
timed_log("Start training")
start = perf_counter()
ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
futs = [rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) for trainer in trainers]
futs = [
rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) for trainer in trainers
]

torch.futures.wait_all(futs)
stop = perf_counter()
timed_log("Finish training")
timed_log(f"Time spent training: {stop - start}s")

class ParameterServerTest(RpcAgentTestFixture):

class ParameterServerTest(RpcAgentTestFixture):
@dist_init(setup_rpc=False)
def test_batch_updating_parameter_server(self):

if self.rank != 0:
rpc.init_rpc(
name=worker_name(self.rank),

@ -11,16 +11,19 @@ import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_sync, rpc_async, remote
from torch.distributed.rpc import remote, rpc_async, rpc_sync, RRef
from torch.distributions import Categorical

from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)


TOTAL_EPISODE_STEP = 5000
GAMMA = 0.1
SEED = 543


def _call_method(method, rref, *args, **kwargs):
r"""
a helper function to call a method on the given RRef
@ -43,6 +46,7 @@ class Policy(nn.Module):
Copying the code to make these two examples independent.
See https://github.com/pytorch/examples/tree/master/reinforcement_learning
"""

def __init__(self) -> None:
super().__init__()
self.affine1 = nn.Linear(4, 128)
@ -67,6 +71,7 @@ class DummyEnv:
tests in this file. It is designed to run for a set max number of iterations,
returning random states and rewards at each step.
"""

def __init__(self, state_dim=4, num_iters=10, reward_threshold=475.0):
self.state_dim = state_dim
self.num_iters = num_iters
@ -96,6 +101,7 @@ class Observer:
select an action. Then, the observer applies the action to its environment
and reports the reward to the agent.
"""

def __init__(self) -> None:
self.id = rpc.get_worker_info().id
self.env = DummyEnv()
@ -171,8 +177,9 @@ class Agent:
rpc_async(
ob_rref.owner(),
_call_method,
args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
) for ob_rref in self.ob_rrefs
args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps),
)
for ob_rref in self.ob_rrefs
]

# wait until all observers have finished this episode

@ -1,32 +1,36 @@
# mypy: allow-untyped-defs

import torch
import time

import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs
from torch.testing._internal.dist_utils import (
dist_init,
wait_until_pending_futures_and_users_flushed,
wait_until_owners_and_forks_on_rank,
wait_until_pending_futures_and_users_flushed,
worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)


def my_sleep_func(seconds=1):
time.sleep(seconds)
return torch.mul(torch.tensor(1), torch.tensor(1))


@torch.jit.script
def my_script_func(tensor):
return torch.add(tensor, tensor)


def add_rref_to_value(rref, value):
return rref.to_here() + value

class FaultyAgentRpcTest(RpcAgentTestFixture):

class FaultyAgentRpcTest(RpcAgentTestFixture):
# no faulty_messages defined so this fails all retryable messages - see
# faulty_rpc_agent_test_fixture.py for the list of retryable messages.
@dist_init(messages_to_delay={})
@ -36,22 +40,30 @@ class FaultyAgentRpcTest(RpcAgentTestFixture):
dst_worker_c = worker_name((self.rank + 2) % self.world_size)

# Worker0 sends RPC to Worker1 and creates an RRef there
rref = rpc.remote(dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2)))
rref = rpc.remote(
dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))
)
# Worker0 sends an RPC to Worker2 with the RRef as an arg
rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2)))
# check if the output is as expected
self.assertEqual(rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2)))
self.assertEqual(
rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2))
)
# explicitly delete all User RRefs
_delete_all_user_and_unforked_owner_rrefs()

@dist_init
def test_verify_backend_options(self):
self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE)
self.assertEqual(
self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
)
self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2)
self.assertEqual(self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
self.assertEqual(
self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC
)

@dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"])
def test_custom_faulty_messages(self):
@ -66,7 +78,9 @@ class FaultyAgentRpcTest(RpcAgentTestFixture):

@dist_init(messages_to_delay={"SCRIPT_CALL": 1.5})
def test_custom_messages_to_delay(self):
self.assertEqual(self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5})
self.assertEqual(
self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5}
)

def _test_remote_message_dropped_pickle(self, dst=None):
if self.rank != 0:
@ -95,7 +109,6 @@ class FaultyAgentRpcTest(RpcAgentTestFixture):
def test_remote_message_dropped_pickle_to_self(self):
self._test_remote_message_dropped_pickle(self.rank)


def _test_remote_message_dropped_timeout(self, func, args, dst=None):
if self.rank != 0:
return
@ -297,22 +310,20 @@ class FaultyAgentRpcTest(RpcAgentTestFixture):
with self.assertRaisesRegex(RuntimeError, expected_error):
rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)

fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
fut = rpc.rpc_async(
dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1
)
with self.assertRaisesRegex(RuntimeError, expected_error):
fut.wait()

# Ensure that the currently set default timeout is large enough such
# that RPCs with delays still complete.
fut = rpc.rpc_async(
dst_worker, my_script_func, args=(torch.tensor(1),)
)
fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
fut.wait()

# Ensure timeout if we set a new default and don't override
rpc._set_rpc_timeout(0.001)
fut = rpc.rpc_async(
dst_worker, my_script_func, args=(torch.tensor(1),)
)
fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
with self.assertRaisesRegex(RuntimeError, expected_error):
fut.wait()


@ -6,13 +6,16 @@ from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)


# The following message types are currently retried in the RREF protocol and
# distributed autograd. Thus only these messages should be tested with the
# Faulty RPC Agent.
retryable_message_types = ["RREF_FORK_REQUEST",
"RREF_CHILD_ACCEPT",
"RREF_USER_DELETE",
"CLEANUP_AUTOGRAD_CONTEXT_REQ"]
retryable_message_types = [
"RREF_FORK_REQUEST",
"RREF_CHILD_ACCEPT",
"RREF_USER_DELETE",
"CLEANUP_AUTOGRAD_CONTEXT_REQ",
]

# The following messages incur the corresponding delay in seconds while being
# processed in FaultyTensorPipeAgent's enqueueSend() function.
@ -21,6 +24,7 @@ default_messages_to_delay = {
"SCRIPT_CALL": 1.5, # Script/Builtin
}


class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -29,9 +33,7 @@ class FaultyRpcAgentTestFixture(RpcAgentTestFixture):

@property
def rpc_backend(self):
return rpc.backend_registry.BackendType[
"FAULTY_TENSORPIPE"
]
return rpc.backend_registry.BackendType["FAULTY_TENSORPIPE"]

@property
def rpc_backend_options(self):
@ -54,7 +56,7 @@ class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
error_regexes = [
"Exception in thread pool task",
"Connection reset by peer",
"Connection closed by peer"
"Connection closed by peer",
]
return "|".join([f"({error_str})" for error_str in error_regexes])


@ -32,9 +32,8 @@ def fork_add(t1, t2, dst: str):
class JitDistAutogradTest(RpcAgentTestFixture):
@dist_init
def test_get_gradients(self):

@torch.jit.script
def dist_get_gradients(context_id: int) -> (dict[Tensor, Tensor]):
def dist_get_gradients(context_id: int) -> dict[Tensor, Tensor]:
return dist_autograd.get_gradients(context_id)

FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs

import time
import io
import time
from typing import Any

import torch
@ -9,8 +9,9 @@ import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.autograd.profiler import record_function
from torch.autograd.profiler_legacy import profile as _profile
from torch.distributed.rpc import RRef
from torch.distributed.rpc.internal import RPCExecMode, _build_rpc_profiling_key
from torch.distributed.rpc.internal import _build_rpc_profiling_key, RPCExecMode
from torch.futures import Future
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.dist_utils import (
@ -23,11 +24,11 @@ from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)

from torch.autograd.profiler_legacy import profile as _profile

def rref_isinstance(rref, cls_to_check):
return isinstance(rref.local_value(), cls_to_check)


def sleep(t):
time.sleep(t)

@ -140,10 +141,12 @@ def no_arg():
def one_arg(value):
return value + 1


@torch.jit.script
def script_add_ones(x):
return torch.add(x, torch.ones(1))


@torch.jit.script
def script_add_ones_with_record_function(x, block: str):
with record_function(block):
@ -154,16 +157,15 @@ def script_add_ones_with_record_function(x, block: str):
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
t: Tensor = torch.ones(1)
with record_function(block):
fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
# Extra operator call to avoid de-duplication of the next async call
# see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
zero = torch.zeros_like(t)
fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
res = fut1.wait() + fut2.wait() + zero
return res



@torch.jit.script
def script_fork_wait_udf(tensor):
fut = torch.jit._fork(script_add_ones, tensor)
@ -196,7 +198,9 @@ def script_fork_wait_throw(invalue):


@torch.jit.script
def call_rpc_with_profiling(record: torch.classes.profiler._RecordFunction, dst_worker_name: str) -> Tensor:
def call_rpc_with_profiling(
record: torch.classes.profiler._RecordFunction, dst_worker_name: str
) -> Tensor:
# Call rpc_async from within ScriptFunction and ensure that we can attach
# profiling callbacks. Note that handle here is a Tensor representation of
# RecordFunction.
@ -205,9 +209,14 @@ def call_rpc_with_profiling(record: torch.classes.profiler._RecordFunction, dst_
ret = fut.wait()
return ret


@torch.jit.script
def call_rpc_torchscript_with_record_function(dst_worker_name: str, block: str) -> Tensor:
fut = rpc.rpc_async(dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block))
def call_rpc_torchscript_with_record_function(
dst_worker_name: str, block: str
) -> Tensor:
fut = rpc.rpc_async(
dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block)
)
return fut.wait()


@ -311,9 +320,7 @@ class FutureTypingTest:
def future_return_to_python(
dst_rank: int, inputs: tuple[Tensor, Tensor]
) -> Future[Tensor]:
return rpc.rpc_async(
f"worker{dst_rank}", two_args_two_kwargs, inputs
)
return rpc.rpc_async(f"worker{dst_rank}", two_args_two_kwargs, inputs)

fut_res = future_return_to_python(dst_rank, inputs)
self.assertEqual(fut_res.wait(), expected_res)
@ -524,6 +531,7 @@ def script_rpc_async_call(
ret = fut.wait()
return ret


@torch.jit.script
def script_rpc_sync_call(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
@ -531,6 +539,7 @@ def script_rpc_sync_call(
res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
return res


@torch.jit.script
def script_rpc_remote_call(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
@ -538,6 +547,7 @@ def script_rpc_remote_call(
rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs)
return rref_res.to_here()


class JitRpcOpTest:
# Call functions remotely from Script.
@dist_init
@ -550,10 +560,12 @@ class JitRpcOpTest:
args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
kwargs = {}

for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
ret = script_op(
dst_worker_name, args, kwargs
)
for script_op in [
script_rpc_async_call,
script_rpc_sync_call,
script_rpc_remote_call,
]:
ret = script_op(dst_worker_name, args, kwargs)
self.assertEqual(ret, torch.tensor([10, 10]))

@dist_init
@ -566,10 +578,12 @@ class JitRpcOpTest:
args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
kwargs = {"first_kwarg": torch.tensor([2, 2])}

for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
ret = script_op(
dst_worker_name, args, kwargs
)
for script_op in [
script_rpc_async_call,
script_rpc_sync_call,
script_rpc_remote_call,
]:
ret = script_op(dst_worker_name, args, kwargs)
self.assertEqual(ret, torch.tensor([9, 9]))

@dist_init
@ -584,10 +598,12 @@ class JitRpcOpTest:
"first_kwarg": torch.tensor([2, 2]),
"second_kwarg": torch.tensor([3, 3]),
}
for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
ret = script_op(
dst_worker_name, args, kwargs
)
for script_op in [
script_rpc_async_call,
script_rpc_sync_call,
script_rpc_remote_call,
]:
ret = script_op(dst_worker_name, args, kwargs)
self.assertEqual(ret, torch.tensor([8, 8]))

@dist_init
@ -618,9 +634,7 @@ class JitRpcOpTest:
ret = fut.wait()
return ret

ret = script_rpc_async_call_with_assorted_types(
dst_worker_name
)
ret = script_rpc_async_call_with_assorted_types(dst_worker_name)
self.assertEqual(ret, (torch.tensor([4, 4]), "str_arg_str_kwarg", 4))

@dist_init
@ -639,9 +653,7 @@ class JitRpcOpTest:
ret = fut.wait()
return ret

ret = script_rpc_async_call_without_kwargs_passed(
dst_worker_name
)
ret = script_rpc_async_call_without_kwargs_passed(dst_worker_name)
self.assertEqual(ret, 0)

@dist_init
@ -659,9 +671,7 @@ class JitRpcOpTest:
ret = fut.wait()
return ret

ret = script_rpc_async_call_without_args_kwargs_passed(
dst_worker_name
)
ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name)
self.assertEqual(ret, 0)

@dist_init
@ -730,9 +740,7 @@ class JitRpcOpTest:
with self.assertRaisesRegex(
RuntimeError, "Unknown keyword argument 'third_kwarg'"
):
ret = script_rpc_async_call_with_unexpected_kwarg(
dst_worker_name
)
ret = script_rpc_async_call_with_unexpected_kwarg(dst_worker_name)
self.assertEqual(ret, 0)

@dist_init
@ -915,9 +923,7 @@ class JitRpcTest(
# Python 3.5 and Python 3.6 throw different error message, the only
# common word can be greped is "pickle".
with self.assertRaisesRegex(TypeError, "pickle"):
rpc.rpc_async(
dst_worker_name, my_local_script_module.forward, args=()
)
rpc.rpc_async(dst_worker_name, my_local_script_module.forward, args=())

@dist_init
def test_remote_script_module(self):
@ -1005,9 +1011,7 @@ class JitRpcTest(
rpc._disable_jit_rref_pickle()

out1 = rpc.rpc_sync(
dst_name,
load_script_module_with_pickled_rref,
args=(f.getvalue(),)
dst_name, load_script_module_with_pickled_rref, args=(f.getvalue(),)
)
out2 = m2()
self.assertEqual(out1, out2)
@ -1150,7 +1154,9 @@ class JitRpcTest(
# After that, this test should be modified to validate the function time.
events = prof.function_events
function_event = get_function_event(events, prof_key)
self.assertTrue(torch._jit_internal._qualified_name(one_arg) in function_event.name)
self.assertTrue(
torch._jit_internal._qualified_name(one_arg) in function_event.name
)

@dist_init
def test_rpc_async_jit_profiled(self):
@ -1162,9 +1168,7 @@ class JitRpcTest(
args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
kwargs = {}
with _profile() as prof:
script_rpc_async_call(
dst_worker_name, args, kwargs
)
script_rpc_async_call(dst_worker_name, args, kwargs)

# Ensure rpc_async call is profiled
function_events = prof.function_events
@ -1358,10 +1362,9 @@ class JitRpcTest(
num = 20
rrefs = [
rpc.remote(
dst1,
async_add,
args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i)
) for i in range(num)
dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i)
)
for i in range(num)
]

for i in range(num):

@ -7,8 +7,8 @@ from torch import Tensor
from torch.distributed.rpc import RRef
from torch.testing._internal.dist_utils import (
dist_init,
wait_until_pending_futures_and_users_flushed,
worker_name,
wait_until_pending_futures_and_users_flushed
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
@ -64,14 +64,17 @@ def rpc_async_call_future_ret(
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
return fut


@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
return rref_var.to_here()


@torch.jit.script
def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
return rref_var.to_here(timeout)


@torch.jit.script
def rpc_async_with_rref_arg(dst_worker_name: str, args: tuple[RRef[Tensor]]) -> Tensor:
fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
@ -84,6 +87,7 @@ class JitFaultyAgentRpcTest(RpcAgentTestFixture):
Run tests for rpc_async in JIT under the faulty agent test fixture to test
arbitrary timeouts.
"""

@dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
def test_timeout_in_torchscript_function(self):
# Call rpc_async + fut.wait() in torchscript function and ensure that
@ -108,9 +112,7 @@ class JitFaultyAgentRpcTest(RpcAgentTestFixture):
# is less than the RPC takes to execute.
rpc._set_rpc_timeout(0.001)
with self.assertRaisesRegex(RuntimeError, expected_error):
script_rpc_async_call(
dst_worker_name, args, kwargs
)
script_rpc_async_call(dst_worker_name, args, kwargs)

# Ensure that we run to completion if zero timeout is specified.
ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
@ -198,7 +200,7 @@ class JitFaultyAgentRpcTest(RpcAgentTestFixture):
# Call RPC with RRef arg in JIT, which will go through JIT pickling and
# ensure error is raised.
with self.assertRaisesRegex(RuntimeError, "RRef creation"):
rpc_async_with_rref_arg(dst_worker, (rref, ))
rpc_async_with_rref_arg(dst_worker, (rref,))

@dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
def test_rref_timeout_pickle_script_func(self):
@ -214,4 +216,4 @@ class JitFaultyAgentRpcTest(RpcAgentTestFixture):
wait_until_pending_futures_and_users_flushed()
# Call RPC with script function that takes RRef, ensure timeout during pickling
with self.assertRaisesRegex(RuntimeError, "RRef creation"):
rpc.rpc_sync(dst_worker, rref_to_here, args=(rref, ))
rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,))

File diff suppressed because it is too large
@ -1,27 +1,21 @@
# mypy: allow-untyped-defs

import torch.distributed.rpc as rpc
from torch.testing._internal.common_distributed import tp_transports
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import (
tp_transports,
)


class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture):
@property
def rpc_backend(self):
return rpc.backend_registry.BackendType[
"TENSORPIPE"
]
return rpc.backend_registry.BackendType["TENSORPIPE"]

@property
def rpc_backend_options(self):
return rpc.backend_registry.construct_rpc_backend_options(
self.rpc_backend,
init_method=self.init_method,
_transports=tp_transports()
self.rpc_backend, init_method=self.init_method, _transports=tp_transports()
)

def get_shutdown_error_regex(self):

@ -6,9 +6,9 @@ import unittest

from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
find_free_port,
IS_SANDCASTLE,
TEST_WITH_DEV_DBG_ASAN,
)
from torch.testing._internal.distributed.ddp_under_dist_autograd_test import (
CudaDdpComparisonTest,
@ -21,15 +21,24 @@ from torch.testing._internal.distributed.nn.api.remote_module_test import (
ThreeWorkersRemoteModuleTest,
)
from torch.testing._internal.distributed.rpc.dist_autograd_test import (
DistAutogradTest,
CudaDistAutogradTest,
DistAutogradTest,
FaultyAgentDistAutogradTest,
TensorPipeAgentDistAutogradTest,
TensorPipeCudaDistAutogradTest
TensorPipeCudaDistAutogradTest,
)
from torch.testing._internal.distributed.rpc.dist_optimizer_test import (
DistOptimizerTest,
)
from torch.testing._internal.distributed.rpc.examples.parameter_server_test import (
ParameterServerTest,
)
from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
ReinforcementLearningRpcTest,
)
from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
FaultyAgentRpcTest,
)
from torch.testing._internal.distributed.rpc.jit.dist_autograd_test import (
JitDistAutogradTest,
)
@ -40,18 +49,11 @@ from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
FaultyAgentRpcTest,
)
from torch.testing._internal.distributed.rpc.rpc_test import (
CudaRpcTest,
RpcTest,
TensorPipeAgentRpcTest,
TensorPipeAgentCudaRpcTest,
)
from torch.testing._internal.distributed.rpc.examples.parameter_server_test import ParameterServerTest
from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
ReinforcementLearningRpcTest,
TensorPipeAgentRpcTest,
)


@ -61,15 +63,17 @@ def _check_and_set_tcp_init():
# different ports.
use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
if use_tcp_init == "1":
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(find_free_port())


def _check_and_unset_tcp_init():
use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
if use_tcp_init == "1":
del os.environ["MASTER_ADDR"]
del os.environ["MASTER_PORT"]


# The tests for the RPC module need to cover multiple possible combinations:
# - different aspects of the API, each one having its own suite of tests;
# - different agents (ProcessGroup, TensorPipe, ...);
@ -80,8 +84,10 @@ def _check_and_unset_tcp_init():
# we call the generate_tests function of this file, passing to it a fixture for
# the agent, which then gets mixed-in with each test suite.


@unittest.skipIf(
TEST_WITH_DEV_DBG_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues"
TEST_WITH_DEV_DBG_ASAN,
"Skip ASAN as torch + multiprocessing spawn have known issues",
)
class SpawnHelper(MultiProcessTestCase):
def setUp(self):
@ -169,8 +175,10 @@ def generate_tests(
for test_class in tests:
if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN:
print(
f'Skipping test {test_class} on sandcastle for the following reason: '
'Skip dev-asan as torch + multiprocessing spawn have known issues', file=sys.stderr)
f"Skipping test {test_class} on sandcastle for the following reason: "
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
continue

name = f"{prefix}{test_class.__name__}"
