[BE][lint] fix PYFMT for PT-D code under torch.testing._internal, add them to the lint list (#153114)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153114
Approved by: https://github.com/cyyever, https://github.com/fegin, https://github.com/H-Huang, https://github.com/Skylion007
Xilun Wu
2025-05-07 17:42:14 -07:00
committed by PyTorch MergeBot
parent 2926dd4d8e
commit 0f9821d0e3
25 changed files with 1233 additions and 1254 deletions

View File

@ -1327,7 +1327,6 @@ exclude_patterns = [
'torch/testing/_internal/codegen/__init__.py',
'torch/testing/_internal/codegen/random_topo_test.py',
'torch/testing/_internal/common_cuda.py',
'torch/testing/_internal/common_distributed.py',
'torch/testing/_internal/common_jit.py',
'torch/testing/_internal/common_methods_invocations.py',
'torch/testing/_internal/common_modules.py',
@ -1344,38 +1343,6 @@ exclude_patterns = [
'torch/testing/_internal/data/network1.py',
'torch/testing/_internal/data/network2.py',
'torch/testing/_internal/dist_utils.py',
'torch/testing/_internal/distributed/__init__.py',
'torch/testing/_internal/distributed/_shard/__init__.py',
'torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py',
'torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py',
'torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py',
'torch/testing/_internal/distributed/_shard/test_common.py',
'torch/testing/_internal/distributed/_tensor/__init__.py',
'torch/testing/_internal/distributed/_tensor/common_dtensor.py',
'torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py',
'torch/testing/_internal/distributed/distributed_test.py',
'torch/testing/_internal/distributed/distributed_utils.py',
'torch/testing/_internal/distributed/fake_pg.py',
'torch/testing/_internal/distributed/multi_threaded_pg.py',
'torch/testing/_internal/distributed/nn/__init__.py',
'torch/testing/_internal/distributed/nn/api/__init__.py',
'torch/testing/_internal/distributed/nn/api/remote_module_test.py',
'torch/testing/_internal/distributed/rpc/__init__.py',
'torch/testing/_internal/distributed/rpc/dist_autograd_test.py',
'torch/testing/_internal/distributed/rpc/dist_optimizer_test.py',
'torch/testing/_internal/distributed/rpc/examples/__init__.py',
'torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py',
'torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py',
'torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py',
'torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py',
'torch/testing/_internal/distributed/rpc/jit/__init__.py',
'torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py',
'torch/testing/_internal/distributed/rpc/jit/rpc_test.py',
'torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py',
'torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py',
'torch/testing/_internal/distributed/rpc/rpc_test.py',
'torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py',
'torch/testing/_internal/distributed/rpc_utils.py',
'torch/testing/_internal/generated/__init__.py',
'torch/testing/_internal/hypothesis_utils.py',
'torch/testing/_internal/inductor_utils.py',

View File

@ -5,6 +5,7 @@ import faulthandler
import itertools
import logging
import multiprocessing
import operator
import os
import queue
import subprocess
@ -21,37 +22,37 @@ from datetime import timedelta
from enum import Enum
from functools import partial, reduce, wraps
from io import StringIO
from typing import NamedTuple, Optional, Union, Any, Callable
from typing import Any, Callable, NamedTuple, Optional, Union
from unittest.mock import patch
from torch._logging._internal import trace_log
import torch
import torch._dynamo.test_case
import torch.cuda.nccl
import torch.distributed as c10d
import torch.nn as nn
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
import torch.nn as nn
from torch._logging._internal import trace_log
from torch.testing._internal.common_utils import (
FILE_SCHEMA,
find_free_port,
IS_SANDCASTLE,
retry_on_connect_failures,
run_tests,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_HPU,
TEST_WITH_ROCM,
TEST_WITH_TSAN,
TestCase,
run_tests,
TEST_HPU,
TEST_XPU,
TestCase,
)
from torch.testing._internal.distributed.multi_threaded_pg import (
_install_threaded_pg,
_uninstall_threaded_pg,
ProcessLocalGroup,
)
import operator
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@ -178,10 +179,7 @@ def import_transformers_or_skip():
@wraps(func)
def wrapper(*args, **kwargs):
try:
from transformers import ( # noqa: F401
AutoModelForMaskedLM,
BertConfig,
)
from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401
return func(*args, **kwargs)
except ImportError:
@ -344,12 +342,14 @@ def requires_nccl():
"c10d was not compiled with the NCCL backend",
)
def requires_ucc():
return skip_but_pass_in_sandcastle_if(
not c10d.is_ucc_available(),
"c10d was not compiled with the UCC backend",
)
def requires_mpi():
return skip_but_pass_in_sandcastle_if(
not c10d.is_mpi_available(),
@ -425,7 +425,12 @@ def create_tcp_store(
)
else:
return c10d.TCPStore(
addr, port, world_size, is_master, wait_for_workers=wait_for_workers, use_libuv=use_libuv
addr,
port,
world_size,
is_master,
wait_for_workers=wait_for_workers,
use_libuv=use_libuv,
)
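
For reference, a minimal standalone sketch (not part of this diff) of the c10d.TCPStore call wrapped above: the store is a TCP key/value rendezvous server, and find_free_port is the same helper imported at the top of this file. The single-process setup, key, and value here are illustrative assumptions.

import torch.distributed as c10d
from torch.testing._internal.common_utils import find_free_port

port = find_free_port()
# The master store owns the underlying TCP server; wait_for_workers=False
# returns immediately instead of blocking until world_size clients connect.
store = c10d.TCPStore("127.0.0.1", port, 1, True, wait_for_workers=False)
store.set("key", "value")
assert store.get("key") == b"value"  # values come back as bytes
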
@ -433,7 +438,7 @@ if TEST_WITH_TSAN:
# TSAN runs much slower.
TIMEOUT_DEFAULT = 500
else:
TIMEOUT_DEFAULT = int(os.getenv('DISTRIBUTED_TESTS_DEFAULT_TIMEOUT', '300'))
TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300"))
TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400}
@ -444,9 +449,13 @@ if TEST_WITH_ROCM:
def create_device(interface=None, lazy_init: bool = False):
if sys.platform == "win32" or interface is None:
return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1", lazy_init=lazy_init)
return c10d.ProcessGroupGloo.create_device(
hostname="127.0.0.1", lazy_init=lazy_init
)
else:
return c10d.ProcessGroupGloo.create_device(interface=interface, lazy_init=lazy_init)
return c10d.ProcessGroupGloo.create_device(
interface=interface, lazy_init=lazy_init
)
def get_timeout(test_id) -> int:
@ -612,7 +621,9 @@ class MultiProcessTestCase(TestCase):
# Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses,
# or run the underlying test function.
def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None:
def __init__(
self, method_name: str = "runTest", methodName: str = "runTest"
) -> None:
# methodName is the correct naming in unittest and testslide uses keyword arguments.
# So we need to use both to 1) not break BC and, 2) support testslide.
if methodName != "runTest":
@ -622,10 +633,12 @@ class MultiProcessTestCase(TestCase):
fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn))
except AttributeError as e:
if methodName != 'runTest':
if methodName != "runTest":
# we allow instantiation with no explicit method name
# but not an *incorrect* or missing method name
raise ValueError(f"no such test method in {self.__class__}: {methodName}") from e
raise ValueError(
f"no such test method in {self.__class__}: {methodName}"
) from e
def setUp(self) -> None:
super().setUp()
@ -660,7 +673,7 @@ class MultiProcessTestCase(TestCase):
args=(rank, self._current_test_name(), self.file_name, child_conn),
kwargs={
"fake_pg": getattr(self, "fake_pg", False),
}
},
)
process.start()
logger.info("Started process %s with pid %s", rank, process.pid)
@ -681,10 +694,10 @@ class MultiProcessTestCase(TestCase):
ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])
if parent_pipe in ready_pipes:
if parent_pipe.closed:
logger.info(
"Pipe closed for process %s, stopping event listener thread", rank
"Pipe closed for process %s, stopping event listener thread",
rank,
)
return
@ -706,7 +719,9 @@ class MultiProcessTestCase(TestCase):
return
@classmethod
def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
def _run(
cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
) -> None:
self = cls(test_name)
self.rank = rank
self.file_name = file_name
@ -734,14 +749,18 @@ class MultiProcessTestCase(TestCase):
getattr(self, test_name)()
except unittest.SkipTest as se:
logger.info(
"Process %s skipping test %s for following reason: %s", self.rank, test_name, str(se)
"Process %s skipping test %s for following reason: %s",
self.rank,
test_name,
str(se),
)
sys.exit(TEST_SKIPS["generic"].exit_code)
except Exception:
logger.error(
"Caught exception: \n%s exiting "
"process %s with exit code: %s",
traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE
"Caught exception: \n%s exiting " "process %s with exit code: %s",
traceback.format_exc(),
self.rank,
MultiProcessTestCase.TEST_ERROR_EXIT_CODE,
)
# Send error to parent process.
parent_pipe.send(traceback.format_exc())
@ -773,7 +792,9 @@ class MultiProcessTestCase(TestCase):
pipes.append((i, pipe))
except ConnectionError as e:
logger.error(
"Encountered error while trying to get traceback for process %s: %s", i, e
"Encountered error while trying to get traceback for process %s: %s",
i,
e,
)
# Wait for results.
@ -783,7 +804,8 @@ class MultiProcessTestCase(TestCase):
if pipe.poll(5):
if pipe.closed:
logger.info(
"Pipe closed for process %s, cannot retrieve traceback", rank
"Pipe closed for process %s, cannot retrieve traceback",
rank,
)
continue
@ -797,7 +819,9 @@ class MultiProcessTestCase(TestCase):
)
except ConnectionError as e:
logger.error(
"Encountered error while trying to get traceback for process %s: %s", rank, e
"Encountered error while trying to get traceback for process %s: %s",
rank,
e,
)
def _join_processes(self, fn) -> None:
@ -807,7 +831,7 @@ class MultiProcessTestCase(TestCase):
try:
while True:
# check to see if any subprocess exited with an error early.
for (i, p) in enumerate(self.processes):
for i, p in enumerate(self.processes):
# This is the exit code processes exit with if they
# encountered an exception.
if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
@ -866,7 +890,9 @@ class MultiProcessTestCase(TestCase):
"""
# If no processes are spawned, there is nothing to check.
if not self.processes:
logger.warning("Note: no subprocesses were spawned, test was likely skipped.")
logger.warning(
"Note: no subprocesses were spawned, test was likely skipped."
)
return
first_process = self.processes[0]
@ -912,7 +938,9 @@ class MultiProcessTestCase(TestCase):
# is some follow-up needed. Instead just "pass" the test
# with an appropriate message.
logger.info(
"Skipping %s on sandcastle for the following reason: %s", self.id(), skip.message
"Skipping %s on sandcastle for the following reason: %s",
self.id(),
skip.message,
)
return
else:
@ -927,13 +955,13 @@ class MultiProcessTestCase(TestCase):
def is_master(self) -> bool:
return self.rank == 0
# Utility base class for distributed multi-process test cases.
# This abstracts PG creation and deletion; the backend is selected based
# on device type. The test functions can be instantiated per device type using
# common_device_type.instantiate_device_type_tests.
# Other backends can add an entry in the backend() function.
class DistributedTestBase(MultiProcessTestCase):
def setUp(self):
super().setUp()
self._spawn_processes()
@ -947,11 +975,11 @@ class DistributedTestBase(MultiProcessTestCase):
def backend(self, device) -> str:
if "cuda" in device:
return "nccl"
elif "hpu" in device : # intel gaudi
elif "hpu" in device: # intel gaudi
return "hccl"
elif "xpu" in device:
return "xccl"
else :
else:
return "gloo"
def create_pg(self, device):
@ -961,7 +989,7 @@ class DistributedTestBase(MultiProcessTestCase):
backend=self.backend(device),
world_size=self.world_size,
rank=self.rank,
store=store
store=store,
)
if "nccl" in self.backend(device) or "xccl" in self.backend(device):
torch.accelerator.set_device_index(self.rank)
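
As an illustration (a sketch, not part of this commit), one way a test might subclass the DistributedTestBase class shown above. The class name, test body, and collective are assumptions; create_pg, backend, rank, and world_size come from the code above.

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_utils import run_tests

class AllReduceSmokeTest(DistributedTestBase):
    def test_all_reduce_sum(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.create_pg(device)  # picks nccl/hccl/xccl/gloo via backend()
        t = torch.ones(2, device=device) * (self.rank + 1)
        dist.all_reduce(t)
        # every rank contributes (rank + 1), so the result is their sum
        expected = float(sum(range(1, self.world_size + 1)))
        self.assertEqual(t, torch.full((2,), expected, device=device))
        dist.destroy_process_group()

if __name__ == "__main__":
    run_tests()
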
@ -971,6 +999,7 @@ class DistributedTestBase(MultiProcessTestCase):
num_visible_devices = torch.get_device_module(device).device_count()
return {i: [i % num_visible_devices] for i in range(self.world_size)}
def run_subtests(
cls_inst,
subtest_config: dict[str, list[Any]],
@ -1021,7 +1050,10 @@ def has_efa() -> bool:
try:
EFA_PROBE_RESULT = (
subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False).returncode == 0
subprocess.run(
["fi_info", "-p", "efa", "-t", "FI_EP_RDM"], check=False
).returncode
== 0
)
except FileNotFoundError:
EFA_PROBE_RESULT = False
@ -1049,7 +1081,6 @@ def spawn_threads_and_init_comms(
spawn_threads_and_init_comms, timeout=timeout, world_size=world_size
)
def _run_test_method_with_multi_threads(world_size, callback):
world = _install_threaded_pg()
global_store = c10d.HashStore()
@ -1066,7 +1097,9 @@ def spawn_threads_and_init_comms(
except BaseException as ex:
# Exceptions are handled in MultiThreadedTestCase
MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info()))
ProcessLocalGroup.exception_handle(ex) # trigger _terminate event and awaken worker threads
ProcessLocalGroup.exception_handle(
ex
) # trigger _terminate event and awaken worker threads
finally:
if world_is_valid():
c10d.destroy_process_group()
@ -1079,13 +1112,14 @@ def spawn_threads_and_init_comms(
return threads
@wraps(func)
def wrapper(self, *args, **kwargs):
# TODO: get test name from kwargs
torch._C._distributed_c10d._set_thread_isolation_mode(True)
try:
threads = _run_test_method_with_multi_threads(world_size, lambda: func(self, *args, **kwargs))
threads = _run_test_method_with_multi_threads(
world_size, lambda: func(self, *args, **kwargs)
)
# join and error handling
MultiThreadedTestCase._join_threads(threads, func)
finally:
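
For context, a sketch (not from this diff) of how spawn_threads_and_init_comms is typically applied: as a decorator on a TestCase method, it runs the body once per rank on threads over the in-process "threaded" backend. The decorator-with-arguments form relies on the partial(...) branch shown above; the class name, world_size, and tensor values are assumptions.

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms
from torch.testing._internal.common_utils import TestCase, run_tests

class ThreadedAllReduceTest(TestCase):
    @spawn_threads_and_init_comms(world_size=4)
    def test_all_reduce(self):
        t = torch.ones(2) * dist.get_rank()
        dist.all_reduce(t)
        # ranks 0..3 contribute 0 + 1 + 2 + 3 = 6
        self.assertEqual(t, torch.full((2,), 6.0))

if __name__ == "__main__":
    run_tests()
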
@ -1108,6 +1142,7 @@ class MultiThreadedTestCase(TestCase):
No global state possible
How bad of a limitation is this?
"""
exception_queue = queue.Queue()
MAIN_THREAD_RANK = -1
@ -1122,7 +1157,9 @@ class MultiThreadedTestCase(TestCase):
return types.MethodType(wrapper, self)
def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None:
def __init__(
self, method_name: str = "runTest", methodName: str = "runTest"
) -> None:
# methodName is the correct naming in unittest and testslide uses keyword arguments.
# So we need to use both to 1) not break BC and, 2) support testslide.
if methodName != "runTest":
@ -1132,10 +1169,12 @@ class MultiThreadedTestCase(TestCase):
fn = getattr(self, method_name)
setattr(self, method_name, self.join_or_run(fn))
except AttributeError as e:
if methodName != 'runTest':
if methodName != "runTest":
# we allow instantiation with no explicit method name
# but not an *incorrect* or missing method name
raise ValueError(f"no such test method in {self.__class__}: {methodName}") from e
raise ValueError(
f"no such test method in {self.__class__}: {methodName}"
) from e
def perThreadSetUp(self):
# super().setUp() # TestCase.setUp() calls torch.manual_seed()
@ -1180,7 +1219,9 @@ class MultiThreadedTestCase(TestCase):
raise RuntimeError("Invalid world")
for rank in range(self.world_size):
t = threading.Thread(target=self.__class__._run, args=(test_name, rank, self.world_size))
t = threading.Thread(
target=self.__class__._run, args=(test_name, rank, self.world_size)
)
t.start()
self.threads.append(t)
@ -1205,7 +1246,10 @@ class MultiThreadedTestCase(TestCase):
Run the current test associated with `test_name` using the threaded process group.
"""
c10d.init_process_group(
backend="threaded", rank=rank, world_size=world_size, store=self.__class__.global_store
backend="threaded",
rank=rank,
world_size=world_size,
store=self.__class__.global_store,
)
self.perThreadSetUp()
@ -1213,12 +1257,13 @@ class MultiThreadedTestCase(TestCase):
getattr(self, test_name)()
except BaseException as ex:
self.exception_queue.put((rank, sys.exc_info()))
ProcessLocalGroup.exception_handle(ex) # trigger _terminate event and awaken worker threads
ProcessLocalGroup.exception_handle(
ex
) # trigger _terminate event and awaken worker threads
finally:
c10d.destroy_process_group()
self.perThreadTearDown()
@classmethod
def _join_threads(cls, threads, fn):
timeout = TIMEOUT_DEFAULT
@ -1262,7 +1307,10 @@ class MultiThreadedTestCase(TestCase):
exc = exc_info[1]
if isinstance(exc, unittest.SkipTest):
logger.info(
"Thread %s skipping test %s for following reason: %s", rank, fn, str(exc)
"Thread %s skipping test %s for following reason: %s",
rank,
fn,
str(exc),
)
if skip_code < 0:
skip_code = TEST_SKIPS["generic"].exit_code
@ -1272,12 +1320,8 @@ class MultiThreadedTestCase(TestCase):
raise RuntimeError(msg)
elif isinstance(exc, Exception):
msg = "".join(traceback.format_exception(*exc_info))
logger.error(
"Caught exception: \n%s exiting thread %s", msg, rank
)
error_msg += (
f"Thread {rank} exited with exception:\n{msg}\n"
)
logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
elif isinstance(exc, SystemExit):
if type(exc.code) == int and skip_code < 0:
skip_code = exc.code
@ -1292,7 +1336,9 @@ class MultiThreadedTestCase(TestCase):
if IS_SANDCASTLE:
# "pass" the test with an appropriate message.
logger.info(
"Skipping %s on sandcastle for the following reason: %s", fn, skip.message
"Skipping %s on sandcastle for the following reason: %s",
fn,
skip.message,
)
return
else:
@ -1352,14 +1398,15 @@ class SaveForwardInputsModel(nn.Module):
self.forward_inputs[self] = x
return self.c2(self.c1(x))
@contextmanager
def _dynamo_dist_per_rank_init(rank, world_size, init_pg=True, fake_pg=False):
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
# Just manually implement the most important part of the dynamo behavior to reset/clear.
if not fake_pg:
torch.accelerator.set_device_index(rank)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '6789'
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "6789"
if init_pg:
if fake_pg:
store = torch.testing._internal.distributed.fake_pg.FakeStore()
@ -1424,6 +1471,7 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
Prefer MultiThreadedTestCase for most tests. Perhaps use this one
sparingly for integration tests.
"""
def setUp(self):
super().setUp()
self._spawn_processes()
@ -1440,7 +1488,9 @@ class DynamoDistributedMultiProcTestCase(MultiProcessTestCase):
return torch.cuda.device_count()
@classmethod
def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs) -> None:
def _run(
cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
) -> None:
trace_log.addHandler(logging.NullHandler())
# The rest is copypasta from MultiProcessTestCase._run
@ -1461,7 +1511,6 @@ class MultiProcContinousTest(TestCase):
# timeout configured per class
timeout: timedelta = timedelta(seconds=120)
@classmethod
@abc.abstractmethod
def backend_str(cls) -> str:

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import sys
from functools import wraps, partial
from functools import partial, wraps
import torch
import torch.distributed as dist
@ -12,8 +12,10 @@ from torch.testing._internal.common_distributed import (
tp_transports,
)
TEST_GPU_NUM = 4
class ShardedTensorTestBase(MultiProcessTestCase):
@property
def world_size(self):
@ -34,9 +36,10 @@ class ShardedTensorTestBase(MultiProcessTestCase):
if backend == "nccl":
torch.cuda.set_device(self.rank)
def init_rpc(self):
rpc_backend_options = rpc.TensorPipeRpcBackendOptions(_transports=tp_transports())
rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
_transports=tp_transports()
)
rpc_backend_options.init_method = f"file://{self.file_name}"
for rank in range(self.world_size):
rpc_backend_options.set_device_map(
@ -79,6 +82,7 @@ class ShardedTensorTestBase(MultiProcessTestCase):
self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))
# wrapper to initialize comms (processgroup + rpc)
def with_comms(func=None, init_rpc=True, backend="nccl"):
if func is None:
@ -95,4 +99,5 @@ def with_comms(func=None, init_rpc=True, backend="nccl"):
self.init_comms(init_rpc=init_rpc, backend=backend)
func(self, *args, **kwargs)
self.destroy_comms(destroy_rpc=init_rpc)
return wrapper

View File

@ -121,16 +121,17 @@ def clone_module_parameter(module, param_name):
tensor = getattr(module, param_name)
return torch.nn.Parameter(tensor.detach().clone())
def gen_binary_op_func(python_op, inplace=False):
src_lines = ['def f(lhs, rhs):']
if "torch" in python_op:
src_lines.append(f' return {python_op}(lhs, rhs)\n')
elif inplace:
src_lines.append(f' lhs {python_op}= rhs\n return lhs\n')
else:
src_lines.append(f' return lhs {python_op} rhs\n')
code_str = '\n'.join(src_lines)
g = {'torch': torch}
def gen_binary_op_func(python_op, inplace=False):
src_lines = ["def f(lhs, rhs):"]
if "torch" in python_op:
src_lines.append(f" return {python_op}(lhs, rhs)\n")
elif inplace:
src_lines.append(f" lhs {python_op}= rhs\n return lhs\n")
else:
src_lines.append(f" return lhs {python_op} rhs\n")
code_str = "\n".join(src_lines)
g = {"torch": torch}
builtins.exec(code_str, g)
return g["f"]
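
A brief usage sketch (not part of this diff) of the gen_binary_op_func helper reformatted above: it builds a two-argument function from an operator string via exec. The import path is assumed from the file list earlier in this commit (_test_ops_common.py).

import torch
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    gen_binary_op_func,
)

add = gen_binary_op_func("+")                 # returns lhs + rhs
mul = gen_binary_op_func("torch.mul")         # returns torch.mul(lhs, rhs)
iadd = gen_binary_op_func("+", inplace=True)  # lhs += rhs and returns lhs

x, y = torch.ones(2), torch.full((2,), 2.0)
assert torch.equal(add(x, y), torch.full((2,), 3.0))
assert torch.equal(mul(x, y), torch.full((2,), 2.0))
assert torch.equal(iadd(x, y), x)  # in-place: x is now [3., 3.]
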

View File

@ -2,12 +2,11 @@
import copy
import random
import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
)
PLACEMENTS = [
"rank:0/cuda:0",
@ -31,13 +30,9 @@ def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
)
return spec_list
class MyShardedModel2(torch.nn.Module):
def __init__(
self,
spec=None,
group=None,
init_rrefs=True
) -> None:
def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
super().__init__()
if spec is not None:
self.sharded_tensor2 = sharded_tensor.rand(
@ -49,12 +44,7 @@ class MyShardedModel2(torch.nn.Module):
class MyShardedModel1(torch.nn.Module):
def __init__(
self,
spec=None,
group=None,
init_rrefs=True
) -> None:
def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
super().__init__()
if spec is not None:
self.sharded_tensor1 = sharded_tensor.rand(

View File

@ -2,7 +2,6 @@
import torch
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor

View File

@ -4,22 +4,16 @@
import itertools
import sys
from collections.abc import Iterator, Sequence
from dataclasses import dataclass
from functools import partial, wraps
from typing import (
Any,
Callable,
cast,
TypeVar,
Union,
Optional,
)
from collections.abc import Iterator, Sequence
from typing import Any, Callable, cast, Optional, TypeVar, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch._utils import _get_device_module
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor.parallel import (
@ -29,21 +23,16 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_utils import (
TEST_HPU,
TEST_CUDA,
TEST_XPU
)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
MultiThreadedTestCase,
skip_if_lt_x_gpu,
run_subtests,
skip_if_lt_x_gpu,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torch._utils import _get_device_module
if TEST_CUDA:
DEVICE_TYPE = "cuda"
@ -232,20 +221,31 @@ class Transformer(nn.Module):
@staticmethod
def parallelize(
module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool, local_output_for_attn: bool = False
module: "Transformer",
device_mesh: DeviceMesh,
use_seq_parallel: bool,
local_output_for_attn: bool = False,
) -> nn.Module:
assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
# Parallelize the root submodules.
if use_seq_parallel:
root_plan = {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
"pos_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(0)),
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Shard(1)
),
"pos_embeddings": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Shard(0)
),
"norm": SequenceParallel(),
}
else:
root_plan = {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
"pos_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Replicate()
),
"pos_embeddings": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Replicate()
),
}
module_tp = parallelize_module(module, device_mesh, root_plan)
@ -260,9 +260,15 @@ class Transformer(nn.Module):
# shard the RMSNorms
layer_parallelize_plan["attention_norm"] = SequenceParallel()
layer_parallelize_plan["ffn_norm"] = SequenceParallel()
layer_parallelize_plan["attention.wq"] = ColwiseParallel(use_local_output=local_output_for_attn)
layer_parallelize_plan["attention.wk"] = ColwiseParallel(use_local_output=local_output_for_attn)
layer_parallelize_plan["attention.wv"] = ColwiseParallel(use_local_output=local_output_for_attn)
layer_parallelize_plan["attention.wq"] = ColwiseParallel(
use_local_output=local_output_for_attn
)
layer_parallelize_plan["attention.wk"] = ColwiseParallel(
use_local_output=local_output_for_attn
)
layer_parallelize_plan["attention.wv"] = ColwiseParallel(
use_local_output=local_output_for_attn
)
layer_parallelize_plan["attention.wo"] = (
RowwiseParallel(output_layouts=Shard(1))
if use_seq_parallel
@ -297,7 +303,9 @@ class Transformer(nn.Module):
if local_output_for_attn:
for layer in module_tp.layers:
layer.attention.n_heads = module_tp.model_args.n_heads // device_mesh.size()
layer.attention.n_heads = (
module_tp.model_args.n_heads // device_mesh.size()
)
# Manually set output.weight so that parameters and gradients are shared.
if module_tp.model_args.weight_tying:
@ -327,7 +335,10 @@ class DTensorTestBase(MultiProcessTestCase):
@property
def device_type(self) -> str:
# if enough GPU we can use GPU, otherwise we fallback to CPU
if not (TEST_CUDA or TEST_XPU) or torch.accelerator.device_count() < self.world_size:
if (
not (TEST_CUDA or TEST_XPU)
or torch.accelerator.device_count() < self.world_size
):
return "cpu"
else:
return DEVICE_TYPE
@ -344,7 +355,14 @@ class DTensorTestBase(MultiProcessTestCase):
if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl", "hccl", "xccl"]:
if self.backend not in [
"nccl",
"gloo",
"mpi",
"cpu:gloo,cuda:nccl",
"hccl",
"xccl",
]:
raise RuntimeError(f"Backend {self.backend} not supported!")
device_id = None
@ -352,7 +370,9 @@ class DTensorTestBase(MultiProcessTestCase):
# set device for nccl pg for collectives
torch.accelerator.set_device_index(self.rank)
# we only need to set device_id for nccl backend with eager init
device_id = torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
device_id = (
torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
)
# For nccl backend, bind the device to the process if device_id is not None
# so the nccl communicator is immediately formed and we can use `ncclCommSplit`
# to form subgroups and avoid unnecessary overhead.
@ -371,7 +391,9 @@ class DTensorTestBase(MultiProcessTestCase):
# FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
# test_dtensor.py -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
if device_id is None:
device_id = torch.cuda.current_device() if self.device_type == "cuda" else self.rank
device_id = (
torch.cuda.current_device() if self.device_type == "cuda" else self.rank
)
dist.barrier(device_ids=[device_id])
dist.destroy_process_group()
@ -398,14 +420,11 @@ TestFunc = Callable[[...], object]
# wrapper to initialize comms (processgroup)
def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc:
def decorator(func, eager_init: bool = False):
@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
) -> None:
self.init_pg(eager_init)
try:
@ -418,7 +437,11 @@ def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc:
return wrapper
return decorator(func=eager_init) if callable(eager_init) else partial(decorator, eager_init=eager_init)
return (
decorator(func=eager_init)
if callable(eager_init)
else partial(decorator, eager_init=eager_init)
)
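
By way of illustration (a sketch, not part of this commit), a DTensorTestBase subclass using the with_comms decorator defined above; per the code above it supports both the bare @with_comms form and @with_comms(eager_init=True). The class name, test name, and tensor shapes are assumptions.

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)

class DTensorSmokeTest(DTensorTestBase):
    @with_comms
    def test_shard_dim0(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        local = torch.randn(self.world_size * 2, 4, device=self.device_type)
        dt = distribute_tensor(local, mesh, [Shard(0)])
        # each rank holds a 2 x 4 shard; the logical shape is unchanged
        self.assertEqual(dt.to_local().shape, (2, 4))
        self.assertEqual(dt.shape, local.shape)

if __name__ == "__main__":
    run_tests()
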
class DTensorOpTestBase(MultiThreadedTestCase):
@ -459,10 +482,16 @@ class DTensorConverter:
self.flatten_kwargs: list[object] = flatten_kwargs
self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
choices_for_args = [self.gen_sharding_choices_for_arg(arg) for arg in self.flatten_args if isinstance(arg, torch.Tensor)]
choices_for_args = [
self.gen_sharding_choices_for_arg(arg)
for arg in self.flatten_args
if isinstance(arg, torch.Tensor)
]
choices_for_args.extend(
self.gen_sharding_choices_for_arg(arg) for arg in self.flatten_kwargs if isinstance(arg, torch.Tensor)
self.gen_sharding_choices_for_arg(arg)
for arg in self.flatten_kwargs
if isinstance(arg, torch.Tensor)
)
self.sharding_combs: Iterator[Sequence[Placement]] = iter(

View File

@ -20,7 +20,7 @@ from torch.testing._internal.common_distributed import (
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
)
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.dist_utils import dist_init, INIT_METHOD_TEMPLATE
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
@ -68,7 +68,7 @@ gLogger = init_logger()
class FeatureSet(NamedTuple):
""" A feature set has 2 types of features"""
"""A feature set has 2 types of features"""
dense_features: torch.Tensor
sparse_features: torch.LongTensor
@ -210,7 +210,8 @@ class Trainer:
gLogger.info(
"Succeeded in creating a HybridModel instance with "
"%s ddp params and %s other local params.",
len(self.ddp_params), len(self.non_ddp_params)
len(self.ddp_params),
len(self.non_ddp_params),
)
def destroy_pg(self):
@ -246,7 +247,8 @@ class Trainer:
gLogger.info(
"Trainer reduced input patches from %s "
"to %s to simulate uneven inputs.",
len(batches), len(input_batches)
len(batches),
len(input_batches),
)
else:
input_batches = batches
@ -260,7 +262,11 @@ class Trainer:
grads_dict = dist_autograd.get_gradients(context_id)
gLogger.info(
"Loss is %s for mini batch: %s. "
"Grads dict has %s entries: %s", loss, mini_batch, len(grads_dict), grads_dict
"Grads dict has %s entries: %s",
loss,
mini_batch,
len(grads_dict),
grads_dict,
)
return (
tuple(grads_dict[param] for param in self.ddp_params),
@ -348,7 +354,9 @@ class DdpUnderDistAutogradTest(RpcAgentTestFixture):
def _trainer_process(self, rank: int):
gLogger.info("Running the trainer #%s...", rank)
gLogger.info(
"Initing trainer process group by trainer #%s with ranks %s", rank, TRAINER_RANKS
"Initing trainer process group by trainer #%s with ranks %s",
rank,
TRAINER_RANKS,
)
dist.init_process_group(
backend="gloo",
@ -534,7 +542,9 @@ class DdpComparisonTest(CommonDdpComparisonTest):
inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]
if simulate_uneven_inputs:
gLogger.info("Rank %s training with %s inputs.", self.rank, len(inputs_list))
gLogger.info(
"Rank %s training with %s inputs.", self.rank, len(inputs_list)
)
# Use distributed autograd. The gradients will be in RPC context map.
grads_dict = {}

View File

@ -1,96 +1,93 @@
# mypy: allow-untyped-defs
import copy
import json
import itertools
import json
import math
import operator
import os
import random
import sys
import tempfile
import time
from collections import namedtuple, OrderedDict, defaultdict
import unittest
from collections import defaultdict, namedtuple, OrderedDict
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from functools import reduce
from typing import Union, NamedTuple, Callable, Any
import unittest
from typing import Any, Callable, NamedTuple, Union
import numpy as np
import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
import torch.nn as nn
import torch.nn.functional as F
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
from torch.utils._python_dispatch import TorchDispatchMode
from torch._utils_internal import (
TEST_MASTER_ADDR as MASTER_ADDR,
TEST_MASTER_PORT as MASTER_PORT,
)
from torch.autograd import DeviceType
from torch.cuda.amp import GradScaler, autocast
from torch.cuda.amp import autocast, GradScaler
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
post_localSGD_hook as post_localSGD,
powerSGD_hook as powerSGD,
default_hooks as default,
quantization as quantization_hooks,
)
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.distributed.distributed_c10d import (
get_world_size,
_get_default_group,
_get_pg_config,
get_world_size,
)
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.distributed.utils import (
_verify_param_shape_across_processes,
_sync_module_states,
_verify_param_shape_across_processes,
)
from torch.profiler import (
ExecutionTraceObserver,
ProfilerActivity,
)
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision
from torch.profiler import ExecutionTraceObserver, ProfilerActivity
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
TEST_SKIPS,
captured_output,
cleanup_temp_dir,
DistTestCases,
init_multigpu_helper,
initialize_temp_directories,
cleanup_temp_dir,
simple_sparse_reduce_tests,
skip_if_rocm_multiprocess,
skip_if_small_worldsize,
skip_if_odd_worldsize,
skip_if_lt_x_gpu,
MultiProcessTestCase,
nccl_skip_if_lt_x_gpu,
skip_if_no_gpu,
require_n_gpus_for_nccl_backend,
requires_nccl_version,
captured_output,
with_nccl_blocking_wait,
with_dist_debug_levels,
simple_sparse_reduce_tests,
skip_if_lt_x_gpu,
skip_if_no_gpu,
skip_if_odd_worldsize,
skip_if_rocm_multiprocess,
skip_if_small_worldsize,
TEST_SKIPS,
verify_ddp_error_logged,
DistTestCases,
with_dist_debug_levels,
with_nccl_blocking_wait,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_MACOS,
IS_WINDOWS,
FILE_SCHEMA,
instantiate_parametrized_tests,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
)
import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.data.distributed import DistributedSampler
import operator
try:
import torchvision
@ -201,19 +198,19 @@ def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
else profiler.function_events
)
return [
event for event in event_list
event
for event in event_list
if (
(event.name.endswith(event_name) or event.name.startswith(event_name))
and (not dedup_gpu_user_annotation or event.device_type != DeviceType.CUDA)
)
]
def get_profiler_nccl_meta(prof):
"""Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
We will need to test metadata obtained from profiler here"""
tf = tempfile.NamedTemporaryFile(
mode="w+t", suffix=".json", delete=False
)
tf = tempfile.NamedTemporaryFile(mode="w+t", suffix=".json", delete=False)
tf.close()
trace_file = tf.name
@ -227,6 +224,7 @@ def get_profiler_nccl_meta(prof):
return [e for e in events if e.get("name") == "record_param_comms"]
# Base error message substring on unfinished reductions.
ddp_prev_reduction_unfinished_str = (
"Expected to have finished reduction in the prior iteration"
@ -424,6 +422,7 @@ CUSTOM_PG_TIMEOUT = {
"test_ddp_has_finalized": 5,
}
def require_backend_is_available(backends):
def check(backend):
if backend == dist.Backend.GLOO:
@ -672,7 +671,7 @@ class DistributedTest:
# Verify buffers across ranks.
m1_buffers = list(m1.buffers())
m2_buffers = list(m2.buffers())
for (buf1, buf2) in zip(m1_buffers, m2_buffers):
for buf1, buf2 in zip(m1_buffers, m2_buffers):
gathered_bufs = [
torch.empty_like(buf1) for _ in range(dist.get_world_size())
]
@ -1218,8 +1217,8 @@ class DistributedTest:
subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
real_group_ranks_res1 = _get_pg_config(subgroup1)['ranks']
real_group_ranks_res2 = _get_pg_config(subgroup2)['ranks']
real_group_ranks_res1 = _get_pg_config(subgroup1)["ranks"]
real_group_ranks_res2 = _get_pg_config(subgroup2)["ranks"]
expect_group_ranks_res1 = (
rank // subgroup_size1 * subgroup_size1
@ -1269,7 +1268,7 @@ class DistributedTest:
@skip_if_no_gpu
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
"Coalescing manager currently tests with NCCL only; internal test flaky"
"Coalescing manager currently tests with NCCL only; internal test flaky",
)
def test_coalescing_manager(self):
self._barrier()
@ -1294,7 +1293,7 @@ class DistributedTest:
for i in range(num_colls):
self.assertEqual(
small_tensors[i],
big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
)
self._barrier()
@ -1303,7 +1302,7 @@ class DistributedTest:
@skip_if_no_gpu
@skip_but_pass_in_sandcastle_if(
BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
"Coalescing manager currently tests with NCCL only; internal test flaky"
"Coalescing manager currently tests with NCCL only; internal test flaky",
)
def test_coalescing_manager_async(self):
self._barrier()
@ -1329,7 +1328,7 @@ class DistributedTest:
for i in range(num_colls):
self.assertEqual(
small_tensors[i],
big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
big_tensor[i * size_per_coll : (i + 1) * size_per_coll],
)
self._barrier()
@ -1585,7 +1584,9 @@ class DistributedTest:
backend = dist.get_backend()
if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
for event_name in [f"{backend}:send", f"{backend}:recv"]:
events = get_profiling_event(event_name, prof, dedup_gpu_user_annotation=True)
events = get_profiling_event(
event_name, prof, dedup_gpu_user_annotation=True
)
self.assertTrue(events)
# Event order is not deterministic, so simply assert their shape
# is found in the following list.
@ -1595,7 +1596,6 @@ class DistributedTest:
for event in events:
self.assertTrue(event.input_shapes in expected_shapes)
@skip_if_no_gpu
@skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
@requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
@ -2745,9 +2745,7 @@ class DistributedTest:
]
_group, group_id, _rank = self._init_global_test()
for unsupported_op in unsupported_ops:
with self.assertRaisesRegex(
ValueError, "all_reduce does not support"
):
with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
dist.all_reduce(
_build_tensor(1, dtype=torch.cfloat), unsupported_op, group_id
)
@ -4373,9 +4371,7 @@ class DistributedTest:
self.net2 = nn.Linear(10, 0)
model = ToyModel().to(self.rank)
nn.parallel.DistributedDataParallel(
model, device_ids=[self.rank]
)
nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
@skip_but_pass_in_sandcastle_if(BACKEND == "nccl", "Gloo-only test")
def test_ddp_create_graph(self):
@ -4521,7 +4517,7 @@ class DistributedTest:
models_to_test.append(
(torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda())
)
for (model, inp) in models_to_test:
for model, inp in models_to_test:
# Enable determinism in cudnn operators
with torch.backends.cudnn.flags(
enabled=True, deterministic=True, benchmark=False
@ -4721,11 +4717,11 @@ class DistributedTest:
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
model, params_to_ignore
)
torch.nn.parallel.DistributedDataParallel(
model, device_ids=[self.rank]
)
dp_params = torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
model, named_params=True
torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
dp_params = (
torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
model, named_params=True
)
)
for name, _ in dp_params:
self.assertNotEqual(f"module.{params_to_ignore[0]}", name)
@ -4734,7 +4730,11 @@ class DistributedTest:
# no of parameters.
num_ddp_params = len(list(model.parameters())) - 1
count = 0
dp_params = torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(model, named_params=False)
dp_params = (
torch.nn.parallel.DistributedDataParallel._get_data_parallel_params(
model, named_params=False
)
)
for _ in dp_params:
count += 1
self.assertEqual(count, num_ddp_params)
@ -4902,7 +4902,8 @@ class DistributedTest:
# Parameters to ignore are in the format {module_name}.{param_name}
to_ignore = ["a.weight", "buffer"]
torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
model, to_ignore,
model,
to_ignore,
)
mp_config = self._get_fp16_config()
net = torch.nn.parallel.DistributedDataParallel(
@ -4915,11 +4916,11 @@ class DistributedTest:
expected_ignored = len(to_ignore)
n_ignored = 0
# ignored params should not have _mp_param or _fp_param fields.
for (n, p) in itertools.chain(net.named_parameters(), net.named_buffers()):
for n, p in itertools.chain(net.named_parameters(), net.named_buffers()):
if n in to_ignore:
n_ignored += 1
self.assertFalse(hasattr(p, '_mp_param'))
self.assertFalse(hasattr(p, '_fp_param'))
self.assertFalse(hasattr(p, "_mp_param"))
self.assertFalse(hasattr(p, "_fp_param"))
else:
self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
self.assertEqual(torch.float32, p._fp_param.dtype)
@ -4940,10 +4941,8 @@ class DistributedTest:
def __init__(self) -> None:
super().__init__()
self.m = torch.nn.Linear(1, 5)
self.register_buffer('buffer', torch.randn(1, 2))
self.p = torch.nn.Parameter(
torch.randn(10, 5), requires_grad=False
)
self.register_buffer("buffer", torch.randn(1, 2))
self.p = torch.nn.Parameter(torch.randn(10, 5), requires_grad=False)
def forward(self_, x): # noqa: B902
params = self_.m.parameters()
@ -4978,7 +4977,7 @@ class DistributedTest:
for n, param in net.named_parameters():
self.assertEqual(param.dtype, torch.float32)
if param.grad is None:
assert n == 'module.p' # Only param that doesn't require grad
assert n == "module.p" # Only param that doesn't require grad
else:
self.assertEqual(param.grad.dtype, torch.float32)
tensor_list = [
@ -4994,7 +4993,9 @@ class DistributedTest:
net.zero_grad(set_to_none=set_grad_to_none)
@skip_if_lt_x_gpu(2)
def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(self):
def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(
self,
):
self._test_ddp_native_mixed_precision(
gradient_as_bucket_view=False,
set_grad_to_none=False,
@ -5014,7 +5015,9 @@ class DistributedTest:
)
@skip_if_lt_x_gpu(2)
def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(self):
def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(
self,
):
self._test_ddp_native_mixed_precision(
gradient_as_bucket_view=True, set_grad_to_none=True
)
@ -6079,7 +6082,9 @@ class DistributedTest:
model = copy.deepcopy(BN_NET)
model = model.half()
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model.cuda(rank), device_ids=[rank])
model = nn.parallel.DistributedDataParallel(
model.cuda(rank), device_ids=[rank]
)
inp = torch.randn(2, 2, dtype=torch.float16, device=torch.device(rank))
# Check that forward/backward do not error with dtype mismatch
out = model(inp)
@ -6858,7 +6863,9 @@ class DistributedTest:
loss.backward()
all_reduce_event_name = f"{dist.get_backend()}:all_reduce"
events = get_profiling_event(all_reduce_event_name, prof, dedup_gpu_user_annotation=True)
events = get_profiling_event(
all_reduce_event_name, prof, dedup_gpu_user_annotation=True
)
event_count = sum(e.count for e in events)
self.assertEqual(event_count, num_iters)
for event in events:
@ -6866,7 +6873,9 @@ class DistributedTest:
self.assertEqual(event.name, all_reduce_event_name)
broadcast_event_name = f"{dist.get_backend()}:broadcast"
broadcast_events = get_profiling_event(broadcast_event_name, prof, dedup_gpu_user_annotation=True)
broadcast_events = get_profiling_event(
broadcast_event_name, prof, dedup_gpu_user_annotation=True
)
event_count = sum(e.count for e in broadcast_events)
# Broadcast is called during rebuild_buckets
self.assertGreaterEqual(event_count, 1)
@ -6889,7 +6898,9 @@ class DistributedTest:
loss = net(inp).sum()
loss.backward()
events = get_profiling_event(all_reduce_event_name, prof, dedup_gpu_user_annotation=True)
events = get_profiling_event(
all_reduce_event_name, prof, dedup_gpu_user_annotation=True
)
self.assertGreaterEqual(len(events), 1)
self.assertGreaterEqual(events[0].count, 1)
self.assertEqual(events[0].name, all_reduce_event_name)
@ -6949,9 +6960,13 @@ class DistributedTest:
"""
with open(et_file) as f:
et = json.load(f)
pg_cfg_node = [n for n in et["nodes"] if n["name"] == "## process_group:init ##"]
pg_cfg_node = [
n for n in et["nodes"] if n["name"] == "## process_group:init ##"
]
self.assertGreaterEqual(len(pg_cfg_node), 1)
nccl_meta_nodes = [n for n in et["nodes"] if n["name"] == "record_param_comms"]
nccl_meta_nodes = [
n for n in et["nodes"] if n["name"] == "record_param_comms"
]
self.assertEqual(len(nccl_meta_nodes), 3)
per_coll_meta = defaultdict(list)
@ -6969,7 +6984,7 @@ class DistributedTest:
if collname in {"wait"}:
continue
self.assertEqual(attrs["pg_name"], "0") # yes this is a string
self.assertEqual(attrs["pg_name"], "0") # yes this is a string
self.assertEqual(attrs["pg_desc"], "default_pg")
self.assertEqual(attrs["pg_size"], 2)
@ -6992,7 +7007,6 @@ class DistributedTest:
self.assertEqual(a1["out_msg_nelems"], 1, msg=f"{a1}")
self.assertEqual(a1["dtype"], "Int", msg=f"{a1}")
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_if_lt_x_gpu(2)
@skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
@ -7016,7 +7030,7 @@ class DistributedTest:
# collect ET in second profiler pass
torch_profiler_ctx2 = torch.profiler.profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
execution_trace_observer=et
execution_trace_observer=et,
)
self._test_ddp_profiling(
profiler_ctx=torch_profiler_ctx1,
@ -7026,7 +7040,6 @@ class DistributedTest:
print(f"Execution trace saved at {fp.name}")
self._validate_execution_trace_nccl(et_file)
@skip_if_lt_x_gpu(2)
@skip_but_pass_in_sandcastle_if(
BACKEND not in DistTestCases.backend_feature["ddp"],
@ -7404,7 +7417,9 @@ class DistributedTest:
for num_early_join_ranks in num_uneven_ranks:
for baseline_iter in baseline_num_iters:
for offset in iteration_offsets:
mapping = dict.fromkeys(range(0, num_early_join_ranks), baseline_iter)
mapping = dict.fromkeys(
range(0, num_early_join_ranks), baseline_iter
)
# if num_early_join_ranks > 1, ranks > 0 that will join early
# iterate offset//2 more times than rank 0, to test nodes
# depleting inputs at different times.
@ -7413,11 +7428,14 @@ class DistributedTest:
if rank > 0:
mapping[rank] += offset // 2
mapping.update(
dict.fromkeys(range(num_early_join_ranks, dist.get_world_size()), baseline_iter + offset)
dict.fromkeys(
range(num_early_join_ranks, dist.get_world_size()),
baseline_iter + offset,
)
)
iteration_mappings.append(mapping)
for (test_case, iteration_mapping) in itertools.product(
for test_case, iteration_mapping in itertools.product(
models_to_test, iteration_mappings
):
if self.rank == 0:
@ -7570,7 +7588,9 @@ class DistributedTest:
int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
)
@with_dist_debug_levels(levels=["DETAIL"])
@unittest.skip("Test is failing, see https://github.com/pytorch/pytorch/pull/113620")
@unittest.skip(
"Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
)
def test_broadcast_object_list(self):
return self._test_broadcast_object_list()
@ -7605,7 +7625,7 @@ class DistributedTest:
device_id = self.rank
# Ensure the test works for both find_unused_parameter and broadcast_buffer settings.
for (find_unused, broadcast_buffers) in itertools.product(
for find_unused, broadcast_buffers in itertools.product(
[False, True], [False, True]
):
model = TestModel(self.rank).float().to(device_id)
@ -7954,16 +7974,16 @@ class DistributedTest:
context = nullcontext
with context():
input = torch.rand((1, ))
input = torch.rand((1,))
output = model.forward(input)
target = torch.rand((1, ))
target = torch.rand((1,))
loss = mse_loss(output, target)
loss.backward()
self.assertTrue(
not any(p.grad is None for p in model.parameters()),
"Gradients can't be None for any model parameter."
"Gradients can't be None for any model parameter.",
)
grads = torch.cat([p.grad.view(-1) for p in model.parameters()])
@ -7978,7 +7998,7 @@ class DistributedTest:
for g in gathered_grads[1:]:
self.assertTrue(
torch.allclose(gathered_grads[0], g),
"Gradients are not the same for all ranks."
"Gradients are not the same for all ranks.",
)
@with_dist_debug_levels(levels=["OFF", "INFO", "DETAIL"])
@ -8290,9 +8310,7 @@ class DistributedTest:
tensors_sparse, [400], logger=net.logger
)
else:
dist._compute_bucket_assignment_by_size(
tensors_sparse, [400]
)
dist._compute_bucket_assignment_by_size(tensors_sparse, [400])
if use_logger:
verify_ddp_error_logged(net, expected_err)
@ -8399,9 +8417,7 @@ class DistributedTest:
# Creates network with different sized embedding table on different
# ranks. This should throw an error during DDP init.
net = EmbeddingNetDifferentParams(self.rank)
self._run_test_ddp_model_with_diff_params(
net, group_to_use, group_gloo
)
self._run_test_ddp_model_with_diff_params(net, group_to_use, group_gloo)
@require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if(
@ -8516,7 +8532,7 @@ class DistributedTest:
self.assertEqual(local_net.a.weight.grad, saved_a_local_grad)
# Verify grads are the same
for (local_param, dist_param) in zip(
for local_param, dist_param in zip(
local_net.parameters(), net.parameters()
):
local_grad = local_param.grad
@ -8935,7 +8951,9 @@ class DistributedTest:
if ignore_sparse:
for module_name, module in model.named_modules():
if module == model.sub_module.embedding_net.embedding:
for parameter_name, _param in module.named_parameters(recurse=False):
for parameter_name, _param in module.named_parameters(
recurse=False
):
fqn = f"{module_name}.{parameter_name}"
sparse_embedding_fqns.append(fqn)
@ -9079,7 +9097,9 @@ class DistributedTest:
f"The {BACKEND} backend does not support DistributedDataParallel",
)
@skip_if_lt_x_gpu(2)
@unittest.skip("Test is failing, see https://github.com/pytorch/pytorch/pull/113620")
@unittest.skip(
"Test is failing, see https://github.com/pytorch/pytorch/pull/113620"
)
def test_ddp_sync_bn_training_vs_eval(self):
rank = self.rank
torch.cuda.set_device(rank)
@ -9228,7 +9248,7 @@ class DistributedTest:
loss_static = get_loss(out_static)
loss_static.backward()
self._model_step(model_static_graph)
for (p, p_static) in zip(
for p, p_static in zip(
model.parameters(), model_static_graph.parameters()
):
self.assertEqual(p, p_static)
@ -9257,7 +9277,7 @@ class DistributedTest:
model = MyModel().to(self.rank)
inp = torch.randn(1, 10, device=self.rank)
for (find_unused, static_graph) in itertools.product(
for find_unused, static_graph in itertools.product(
[True, False], [True, False]
):
ddp = DistributedDataParallel(
@ -9526,7 +9546,6 @@ class DistributedTest:
f"The {BACKEND} backend does not support DistributedDataParallel",
)
def test_ddp_remove_autograd_hooks(self):
class SimulateError(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
@ -9548,7 +9567,6 @@ class DistributedTest:
else:
return self.fc1(inp)
# Run with error to trigger backward pass that marks fc1 as being marked
# ready. If we don't remove autograd hooks before running below it would
# fail on the old autograd hook.
@ -9578,9 +9596,10 @@ class DistributedTest:
BACKEND not in DistTestCases.backend_feature["ddp"],
f"The {BACKEND} backend does not support DistributedDataParallel",
)
@unittest.skip("Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751")
@unittest.skip(
"Test is failing, tracking issue at https://github.com/pytorch/pytorch/issues/102751"
)
def test_ddp_has_finalized(self):
@dataclass
class MyClass:
obj: torch.Tensor
@ -9615,10 +9634,16 @@ class DistributedTest:
(out1.sum() + out2.sum()).backward()
if self.rank == 0:
with self.assertRaisesRegex(RuntimeError, "Expected to have finished reduction in the prior iteration"):
with self.assertRaisesRegex(
RuntimeError,
"Expected to have finished reduction in the prior iteration",
):
ddp._check_reducer_finalized()
with self.assertRaisesRegex(RuntimeError, "Expected to have finished reduction in the prior iteration"):
with self.assertRaisesRegex(
RuntimeError,
"Expected to have finished reduction in the prior iteration",
):
ddp(input)
else:
ddp._check_reducer_finalized()
@ -9901,14 +9926,11 @@ class DistributedTest:
p.grad.data = p.grad / iters
for p_ddp, p_local in zip(
model.parameters(),
local_model.parameters()
model.parameters(), local_model.parameters()
):
self.assertTrue(
torch.allclose(
p_ddp.grad, p_local.grad
),
f"{p_ddp.grad} vs {p_local.grad}"
torch.allclose(p_ddp.grad, p_local.grad),
f"{p_ddp.grad} vs {p_local.grad}",
)
dist.barrier()
@ -10155,11 +10177,8 @@ class DistributedTest:
f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
)
@skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
@skip_but_pass_in_sandcastle_if(
True, "Skipped due to flakiness"
)
@skip_but_pass_in_sandcastle_if(True, "Skipped due to flakiness")
def test_ddp_hook_pickling_powerSGD(self):
hook = powerSGD.powerSGD_hook
powersgd_state = powerSGD.PowerSGDState(
process_group=None,
@ -10177,17 +10196,21 @@ class DistributedTest:
world_size = int(os.environ["WORLD_SIZE"])
from torch.distributed.device_mesh import init_device_mesh
device_mesh = init_device_mesh("cuda", (world_size,))
pg = _get_default_group()
torch.cuda.set_device(self.rank)
model = TwoLinLayerNet().cuda()
ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_mesh=device_mesh)
ddp_model = torch.nn.parallel.DistributedDataParallel(
model, device_mesh=device_mesh
)
self.assertEqual(ddp_model.device_mesh, device_mesh)
with self.assertRaisesRegex(
RuntimeError, "Cannot specify both process_group and device_mesh arguments."
RuntimeError,
"Cannot specify both process_group and device_mesh arguments.",
):
ddp_model = torch.nn.parallel.DistributedDataParallel(
model, process_group=pg, device_mesh=device_mesh
@ -10201,7 +10224,6 @@ class DistributedTest:
model, device_mesh=device_mesh
)
@skip_if_lt_x_gpu(2)
@require_world_size(2)
@skip_but_pass_in_sandcastle_if(
@ -10217,9 +10239,7 @@ class DistributedTest:
device_ids=[self.rank],
)
ddp_static = torch.nn.parallel.DistributedDataParallel(
model_clone,
device_ids=[self.rank],
static_graph=True
model_clone, device_ids=[self.rank], static_graph=True
)
ddp = torch.compile(ddp)
ddp_static = torch.compile(ddp_static)
@ -10271,9 +10291,11 @@ class DistributedTest:
with OpPatcher():
ddp(input).sum().backward()
def _test_skip_all_reduce_unused_parameters(
self, find_unused_parameters=False, static_graph=False, skip_all_reduce_unused_params=False,
self,
find_unused_parameters=False,
static_graph=False,
skip_all_reduce_unused_params=False,
):
class LargeNet(nn.Module):
def __init__(self) -> None:
@ -10312,11 +10334,15 @@ class DistributedTest:
test_model_1 = self._test_skip_all_reduce_unused_parameters(
find_unused_parameters=True,
static_graph=False,
skip_all_reduce_unused_params=True
skip_all_reduce_unused_params=True,
)
self.assertEqual(base_model._get_ddp_logging_data().get("num_buckets_reduced"), 2)
self.assertEqual(test_model_1._get_ddp_logging_data().get("num_buckets_reduced"), 1)
self.assertEqual(
base_model._get_ddp_logging_data().get("num_buckets_reduced"), 2
)
self.assertEqual(
test_model_1._get_ddp_logging_data().get("num_buckets_reduced"), 1
)
for i, j in zip(base_model.parameters(), test_model_1.parameters()):
self.assertEqual(i, j)

View File

@ -2,26 +2,26 @@
from contextlib import contextmanager
from datetime import timedelta
from functools import (
partial,
wraps,
)
from functools import partial, wraps
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
class MockProcessGroup(dist.ProcessGroup):
class MockProcessGroup(dist.ProcessGroup):
def __init__(self, rank, world):
super().__init__(rank, world)
def getBackendName(self):
return "mock_process_group"
def create_mock_pg(prefix_store, rank, world_size, timeout):
return MockProcessGroup(rank, world_size)
dist.Backend.register_backend('mock_process_group', create_mock_pg)
dist.Backend.register_backend("mock_process_group", create_mock_pg)
def mock_init_dist(rank, world_size):
# !!! WARNING !!!
@ -38,7 +38,9 @@ def mock_init_dist(rank, world_size):
world_size=world_size,
store=store,
group_name="fake",
timeout=timedelta(seconds=1))
timeout=timedelta(seconds=1),
)
@contextmanager
def with_dist(rank=0, world_size=2):
@ -51,6 +53,7 @@ def with_dist(rank=0, world_size=2):
finally:
dist.destroy_process_group()
def with_fake_comms(func=None, rank=0, world_size=2):
"""
Function wrapper that inits a fake process group designed for testing.
@ -63,4 +66,5 @@ def with_fake_comms(func=None, rank=0, world_size=2):
def wrapper(self, *args, **kwargs):
with with_dist(rank, world_size):
func(self, *args, **kwargs)
return wrapper
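
As a usage note for the with_fake_comms wrapper above: a test method is decorated so that a mock process group exists for the duration of the test body. The minimal sketch below is illustrative; the import path and the example class and method names are assumptions, and the decorator-with-arguments form relies on the functools.partial branch elided from this hunk.

import unittest

import torch.distributed as dist
from torch.testing._internal.distributed.distributed_utils import with_fake_comms


class FakeCommsExample(unittest.TestCase):
    # with_fake_comms wraps the body in with_dist(), so a mock default
    # process group is initialized before the test and destroyed afterwards.
    @with_fake_comms(rank=0, world_size=4)
    def test_runs_under_fake_pg(self):
        self.assertEqual(dist.get_rank(), 0)
        self.assertEqual(dist.get_world_size(), 4)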


@ -1,10 +1,7 @@
# mypy: allow-untyped-defs
import torch.distributed as dist
from torch._C._distributed_c10d import (
FakeProcessGroup,
)
from torch._C._distributed_c10d import FakeProcessGroup
class FakeStore(dist.Store):
@ -28,4 +25,4 @@ def _create_fake_pg(prefix_store, rank, world_size, timeout):
return FakeProcessGroup(rank, world_size)
dist.Backend.register_backend("fake", _create_fake_pg, devices=['cpu', 'cuda'])
dist.Backend.register_backend("fake", _create_fake_pg, devices=["cpu", "cuda"])
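
For context, the "fake" backend registered above is typically used to stand up a default process group in a single process, with collectives that complete immediately and no real communication. A minimal illustrative sketch, assuming the usual import path for this module:

import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# One process pretends to be rank 0 of a 2-rank job; nothing is actually
# communicated, which is enough for tests that only need a default group.
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
assert dist.get_world_size() == 2
dist.destroy_process_group()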


@ -2,13 +2,13 @@
import sys
import threading
import weakref
from dataclasses import dataclass
from typing import Optional, Union
from functools import partial, reduce
from typing import Optional, Union
import torch
import torch.distributed as dist
import weakref
from torch._C._distributed_c10d import (
_create_work_from_future,
AllgatherOptions,
@ -16,15 +16,16 @@ from torch._C._distributed_c10d import (
AllToAllOptions,
BarrierOptions,
BroadcastOptions,
ReduceOp,
ReduceScatterOptions,
ScatterOptions,
Store,
ReduceOp,
)
from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp
from torch.futures import Future
from torch.utils import _pytree as pytree
"""
TODO:
Lots of missing collectives.
@ -45,6 +46,7 @@ def ret_work(ret):
fut.set_result(ret)
return _create_work_from_future(fut)
def binop_reduce(tensors, op):
res = op(torch.stack(tensors), dim=0)
if isinstance(res, torch.Tensor):
@ -52,9 +54,11 @@ def binop_reduce(tensors, op):
# min/max return a namedtuple
return res.values
def bitwise_reduce(tensors, op):
return reduce(op, tensors)
_reduce_ops = {
ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
@ -66,6 +70,7 @@ _reduce_ops = {
ReduceOp.BXOR: partial(bitwise_reduce, op=torch.bitwise_xor),
}
class AllToAll:
@torch.no_grad()
def work(self, data):
@ -76,6 +81,7 @@ class AllToAll:
_, input_tensor_list = data[src_rank]
output_tensor_list[src_rank].copy_(input_tensor_list[dest_rank])
class AllToAllBase:
@torch.no_grad()
def work(self, data):
@ -83,35 +89,45 @@ class AllToAllBase:
for dest_rank in range(world_size):
output_buffer, _, output_split_sizes, _ = data[dest_rank]
output_indexes = self._size_cumsum(output_buffer.size(0), output_split_sizes, world_size)
output_indexes = self._size_cumsum(
output_buffer.size(0), output_split_sizes, world_size
)
for src_rank in range(world_size):
_, input_buffer, _, input_split_sizes = data[src_rank]
input_indexes = self._size_cumsum(input_buffer.size(0), input_split_sizes, world_size)
output_buffer[output_indexes[src_rank]:output_indexes[src_rank + 1]].copy_(
input_buffer[input_indexes[dest_rank]:input_indexes[dest_rank + 1]]
input_indexes = self._size_cumsum(
input_buffer.size(0), input_split_sizes, world_size
)
def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, list[int], None], world_size: int) -> torch.Tensor:
output_buffer[
output_indexes[src_rank] : output_indexes[src_rank + 1]
].copy_(
input_buffer[
input_indexes[dest_rank] : input_indexes[dest_rank + 1]
]
)
def _size_cumsum(
self,
buf_size: int,
sizes: Union[torch.Tensor, list[int], None],
world_size: int,
) -> torch.Tensor:
if sizes is None or len(sizes) == 0:
sizes = torch.full(
(world_size,), buf_size // world_size, dtype=torch.int64
)
sizes = torch.full((world_size,), buf_size // world_size, dtype=torch.int64)
if not isinstance(sizes, torch.Tensor):
sizes = torch.tensor(sizes, dtype=torch.int64)
assert sizes.dtype == torch.int64
sizes = torch.cumsum(
torch.cat(
(
torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes
),
dim=0
(torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes),
dim=0,
),
dim=0
dim=0,
)
return sizes
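
To make the slicing in the reformatted all-to-all code concrete: _size_cumsum prepends a zero to the per-rank split sizes and takes a cumulative sum, so consecutive offsets bound each rank's slice of the flat buffer. A standalone illustration of that arithmetic (the values are made up):

import torch

sizes = torch.tensor([2, 3, 1], dtype=torch.int64)  # per-rank split sizes
offsets = torch.cumsum(
    torch.cat((torch.tensor([0], dtype=torch.int64), sizes), dim=0), dim=0
)
# offsets == tensor([0, 2, 5, 6]); rank r owns buffer[offsets[r] : offsets[r + 1]]
buffer = torch.arange(6)
chunks = [buffer[offsets[r] : offsets[r + 1]] for r in range(3)]
# chunks == [tensor([0, 1]), tensor([2, 3, 4]), tensor([5])]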
class AllReduce:
def __init__(self, op):
if op.op not in _reduce_ops:
@ -127,7 +143,9 @@ class AllReduce:
rank_0_device = data[0][i].device
# collect all data to the list and make them
# all on rank 0 device
tensors = [data[src_rank][i].to(rank_0_device) for src_rank in range(0, len(data))]
tensors = [
data[src_rank][i].to(rank_0_device) for src_rank in range(0, len(data))
]
# now mimic reduce across all ranks
res = _reduce_ops[self.op](tensors)
@ -186,6 +204,7 @@ class Gather:
dest_tensor = out_tensor_list[rank]
dest_tensor.copy_(src_in_tensor_list[0])
class ReduceScatter:
def __init__(self, op):
if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG:
@ -254,7 +273,8 @@ class Collective:
if rank == 0:
self._start_cond.wait_for(
lambda: self._count == self._world_size or self._pg._terminate.is_set()
lambda: self._count == self._world_size
or self._pg._terminate.is_set()
)
# SystemExit is not a subclass of Exception but BaseException
# and can be distinguished from normal exception raised from program errors
@ -265,7 +285,9 @@ class Collective:
with self._done_cond:
# wait for rank 0 to finish
if rank > 0:
self._done_cond.wait_for(lambda: self._done or self._pg._terminate.is_set())
self._done_cond.wait_for(
lambda: self._done or self._pg._terminate.is_set()
)
if self._pg._terminate.is_set():
sys.exit("Test termination event occurs.")
else:
@ -287,14 +309,19 @@ class ProcessLocalGroup(dist.ProcessGroup):
with cls._coll_lock:
# pg_name is unique, we use that to record the mapping between pg and collective
if pg.pg_name not in cls._cur_coll_on_pgs:
cls._cur_coll_on_pgs[pg.pg_name] = Collective(pg.size(), collective, cls)
cls._cur_coll_on_pgs[pg.pg_name] = Collective(
pg.size(), collective, cls
)
return cls._cur_coll_on_pgs[pg.pg_name]
@classmethod
def _end_coll(cls, collective, pg):
# This is racily called by all ranks, so only one will work
with cls._coll_lock:
if pg.pg_name in cls._cur_coll_on_pgs and cls._cur_coll_on_pgs[pg.pg_name] == collective:
if (
pg.pg_name in cls._cur_coll_on_pgs
and cls._cur_coll_on_pgs[pg.pg_name] == collective
):
cls._cur_coll_on_pgs.pop(pg.pg_name)
@classmethod
@ -318,10 +345,13 @@ class ProcessLocalGroup(dist.ProcessGroup):
input_buffer: torch.Tensor,
output_split_sizes: Optional[list[int]],
input_split_sizes: Optional[list[int]],
opts=AllToAllOptions()
opts=AllToAllOptions(),
) -> torch.Tensor:
coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
res = coll.join(self._rank, (output_buffer, input_buffer, output_split_sizes, input_split_sizes))
res = coll.join(
self._rank,
(output_buffer, input_buffer, output_split_sizes, input_split_sizes),
)
ProcessLocalGroup._end_coll(coll, self)
return res
@ -380,21 +410,26 @@ class ProcessLocalGroup(dist.ProcessGroup):
ProcessLocalGroup._end_coll(coll, self)
return res
def _reduce_scatter_base(self, output_tensor, input_tensor, opts=ReduceScatterOptions()):
def _reduce_scatter_base(
self, output_tensor, input_tensor, opts=ReduceScatterOptions()
):
tensor_list = list(torch.chunk(input_tensor, self._world_size))
return self.reduce_scatter([output_tensor], [tensor_list], opts)
def reduce_scatter_tensor_coalesced(self, output_tensors, input_tensors, opts=ReduceScatterOptions()):
def reduce_scatter_tensor_coalesced(
self, output_tensors, input_tensors, opts=ReduceScatterOptions()
):
works = [
self._reduce_scatter_base(output_tensor, input_tensor, opts)
for output_tensor, input_tensor
in zip(output_tensors, input_tensors)
for output_tensor, input_tensor in zip(output_tensors, input_tensors)
]
for work in works[:-1]:
work.wait()
return works[-1]
def allgather_into_tensor_coalesced(self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()):
def allgather_into_tensor_coalesced(
self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()
):
res = None
for o_t, i_t in zip(output_tensor_list, input_tensor_list):
res = self._allgather_base(o_t, i_t)
@ -470,7 +505,9 @@ class ThreadLocalWorld:
def _get_world(self) -> WorldData:
if not hasattr(ThreadLocalWorld._world, "world"):
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {})
ThreadLocalWorld._world.world = WorldData(
None, {}, {}, {}, {}, 0, {}, {}, {}
)
return ThreadLocalWorld._world.world
@property


@ -5,11 +5,13 @@ import enum
import torch
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils as dist_utils
from torch import Tensor, nn
from torch import nn, Tensor
from torch._jit_internal import Future
from torch.distributed.nn import RemoteModule
from torch.distributed.nn.api.remote_module import _REMOTE_MODULE_PICKLED_ATTRIBUTES
from torch.distributed.nn.api.remote_module import _RemoteModule
from torch.distributed.nn.api.remote_module import (
_REMOTE_MODULE_PICKLED_ATTRIBUTES,
_RemoteModule,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
@ -35,16 +37,19 @@ def remote_module_attributes(remote_module):
def remote_forward(remote_module, args):
return remote_module.forward(*args)
# RPC handler for running forward_async on the destination worker.
def remote_forward_async(remote_module, args):
# Since future cannot be pickled and sent over the RPC layer,
# have to wait and behave just like ``forward_sync``.
return remote_module.forward_async(*args).wait()
# RPC handler for getting training mode on the destination worker.
def get_remote_training_arg(module_rref):
return module_rref.local_value().training
class ModuleCreationMode(enum.Enum):
MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
MODULE_CTOR = "module_ctor"
@ -147,7 +152,6 @@ class RemoteModuleTest(CommonRemoteModuleTest):
):
RemoteModule(remote_device, BadModule, args, kwargs).forward()
@dist_utils.dist_init
def test_forward_async(self):
if self.rank != 0:
@ -269,11 +273,19 @@ class RemoteModuleTest(CommonRemoteModuleTest):
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
):
remote_module.train()
ret1 = rpc.rpc_sync(dst_worker_name, get_remote_training_arg, args=(remote_module.get_module_rref(),))
ret1 = rpc.rpc_sync(
dst_worker_name,
get_remote_training_arg,
args=(remote_module.get_module_rref(),),
)
self.assertEqual(ret1, True)
remote_module.eval()
ret2 = rpc.rpc_sync(dst_worker_name, get_remote_training_arg, args=(remote_module.get_module_rref(),))
ret2 = rpc.rpc_sync(
dst_worker_name,
get_remote_training_arg,
args=(remote_module.get_module_rref(),),
)
self.assertEqual(ret2, False)
@dist_utils.dist_init
@ -466,7 +478,9 @@ class RemoteModuleTest(CommonRemoteModuleTest):
dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
):
with TemporaryFileName() as fname:
with self.assertRaisesRegex(torch.jit.Error, "can only be pickled when using RPC"):
with self.assertRaisesRegex(
torch.jit.Error, "can only be pickled when using RPC"
):
torch.save(remote_module, fname)
@ -556,9 +570,7 @@ class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest):
)
args = (torch.ones(1), 2, "3")
ret1 = rpc.rpc_sync(
dst_worker1_name, remote_forward, (remote_module, args)
)
ret1 = rpc.rpc_sync(dst_worker1_name, remote_forward, (remote_module, args))
ret2 = rpc.rpc_sync(
dst_worker2_name, remote_forward, (remote_module2, args)
)
@ -613,15 +625,15 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
]
if TEST_WITH_ROCM:
errorString = (r"HIP error: invalid device ordinal\n"
r"HIP kernel errors might be asynchronously reported at some other API call, "
r"so the stacktrace below might be incorrect.\n"
r"For debugging consider passing AMD_SERIALIZE_KERNEL=3")
errorString = (
r"HIP error: invalid device ordinal\n"
r"HIP kernel errors might be asynchronously reported at some other API call, "
r"so the stacktrace below might be incorrect.\n"
r"For debugging consider passing AMD_SERIALIZE_KERNEL=3"
)
else:
errorString = r"CUDA error: invalid device ordinal"
with self.assertRaisesRegex(
RuntimeError, errorString
):
with self.assertRaisesRegex(RuntimeError, errorString):
[
m.forward()
for m in self._create_remote_module_iter(


@ -1,21 +1,26 @@
# mypy: allow-untyped-defs
import random
import sys
import threading
import time
from enum import Enum
import random
import torch
import torch.nn as nn
from datetime import timedelta
from enum import Enum
import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.testing._internal.dist_utils
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributed.rpc import RRef
from torch.testing._internal.common_utils import IS_MACOS, skip_but_pass_in_sandcastle_if
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
IS_MACOS,
skip_but_pass_in_sandcastle_if,
)
from torch.testing._internal.dist_utils import (
dist_init,
initialize_pg,
@ -25,7 +30,6 @@ from torch.testing._internal.dist_utils import (
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
# Right now we test up to 3-layer nested rpc calls.
@ -41,6 +45,7 @@ known_context_ids = set()
requires_grad_tensor = torch.ones(3, 3, requires_grad=True)
# Send rpc done info and context_id to
# dst_rank = (self.rank + rank_distance) % self.world_size
# we don't need a lock here since the GIL is held while executing remote
@ -62,6 +67,7 @@ def _check_rpc_done(rank_distance):
def _torch_ones(sizes, requires_grad=False):
return torch.ones(sizes, requires_grad=requires_grad)
# This method must be called on the rref owner, and verifies that the grad of
# rref tensor equals to the given grad.
def _compare_owner_value(context_id, rref, grad):
@ -175,6 +181,7 @@ def _run_trainer(rref_t1, t2, ps, rank_diff, sparse):
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
# This function is the same as _run_trainer, except rpc calls torchscript
# function "my_script_ref_add" instead of python function "my_rref_add"
def _run_trainer_torchscript(rref_t1, t2, ps, rank_diff, sparse):
@ -231,9 +238,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
raise ValueError(f"Unrecognized ExecMode {exec_mode}")
def _exec_func(self, exec_mode, method, *args):
return self._exec_func_with_dst(
self._next_rank(), exec_mode, method, *args
)
return self._exec_func_with_dst(self._next_rank(), exec_mode, method, *args)
def _next_rank(self):
if hasattr(self, "dst_rank"):
@ -286,15 +291,11 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(worker_name(dst_rank), fn, args=(t1, t2))
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
worker_name(dst_rank), fn, args=(t1, t2)
).to_here()
ret = rpc.remote(worker_name(dst_rank), fn, args=(t1, t2)).to_here()
else:
raise ValueError(f"Unrecognized ExecMode {exec_mode}")
rpc.rpc_sync(
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
)
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
# Verify graph for current context id.
ctx = dist_autograd._current_context()
@ -498,19 +499,13 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
t1 = torch.ones(3, 3, requires_grad=False)
t2 = torch.zeros(3, 3, requires_grad=False)
if ExecMode.RPC_SYNC == exec_mode:
rpc.rpc_sync(
worker_name(dst_rank), torch.add, args=(t1, t2)
)
rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
elif ExecMode.REMOTE == exec_mode:
rpc.remote(
worker_name(dst_rank), torch.add, args=(t1, t2)
).to_here()
rpc.remote(worker_name(dst_rank), torch.add, args=(t1, t2)).to_here()
else:
raise ValueError(f"Unrecognized ExecMode {exec_mode}")
rpc.rpc_sync(
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
)
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
ctx = dist_autograd._current_context()
send_functions = ctx._send_functions()
@ -541,9 +536,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
tensors.append(tensor)
dst_rank = self._next_rank()
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(
worker_name(dst_rank), torch.stack, args=(tensors,)
)
ret = rpc.rpc_sync(worker_name(dst_rank), torch.stack, args=(tensors,))
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
worker_name(dst_rank), torch.stack, args=(tensors,)
@ -554,7 +547,9 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
self.assertEqual(torch.stack(tensors), ret)
# Verify appropriate tensors have been attached the autograd graph.
next_funcs = next(iter(dist_autograd._current_context()._send_functions().values())).next_functions
next_funcs = next(
iter(dist_autograd._current_context()._send_functions().values())
).next_functions
for i in range(len(next_funcs)):
self.assertEqual(
"torch::autograd::AccumulateGrad", next_funcs[i][0].name()
@ -585,9 +580,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
with dist_autograd.context() as context_id:
for dst_rank in dst_ranks:
rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
rpc.rpc_sync(
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
)
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
if nested:
rpc.rpc_sync(
worker_name(nested_dst_rank),
@ -607,9 +600,8 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
def _backward_no_grad_on_tensor(self, t1, t2, sparse):
with dist_autograd.context() as context_id:
loss = rpc.rpc_sync(
worker_name(self._next_rank()),
torch.add,
args=(t1, t2))
worker_name(self._next_rank()), torch.add, args=(t1, t2)
)
if sparse:
loss = torch.sparse.sum(loss)
else:
@ -650,11 +642,19 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
with dist_autograd.context() as context_id:
if sparse:
rref_t1 = rpc.remote(
rref_owner, build_sparse_tensor, args=(False, True,)
rref_owner,
build_sparse_tensor,
args=(
False,
True,
),
)
else:
rref_t1 = rpc.remote(
rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
rref_owner,
_torch_ones,
args=((3, 3),),
kwargs={"requires_grad": True},
)
if callee == rref_owner:
rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
@ -707,10 +707,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
local_ret.sum().backward()
# create rref on self
rref_t1 = rpc.remote(
worker_name(self.rank),
create_ref_fn,
args=())
rref_t1 = rpc.remote(worker_name(self.rank), create_ref_fn, args=())
# kick off forward and backward pass on three other workers (trainers)
rank_diffs = [1, 2, 3]
@ -719,7 +716,8 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
worker_name((self.rank + rank_diff) % self.world_size),
trainer_fn,
args=(rref_t1, t2, worker_name(self.rank), rank_diff, sparse),
) for rank_diff in rank_diffs
)
for rank_diff in rank_diffs
]
# check if the trainers have done with their backward pass
@ -877,9 +875,8 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
def _multiple_backward(self, t1, t2, sparse):
with dist_autograd.context() as context_id:
loss = rpc.rpc_sync(
worker_name(self._next_rank()),
torch.add,
args=(t1, t2))
worker_name(self._next_rank()), torch.add, args=(t1, t2)
)
if sparse:
loss = torch.sparse.sum(loss)
else:
@ -924,9 +921,7 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
def _backward_simple(self, dst, t1, t2, local_grads, sparse):
for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
with dist_autograd.context() as context_id:
ret = self._exec_func_with_dst(
dst, exec_mode, torch.add, t1, t2
)
ret = self._exec_func_with_dst(dst, exec_mode, torch.add, t1, t2)
if sparse:
loss = torch.sparse.sum(ret)
else:
@ -1005,7 +1000,6 @@ class CommonDistAutogradTest(RpcAgentTestFixture):
class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
# Sparse tests only work with TensorPipeAgent.
@dist_init
def test_graph_for_builtin_call_sparse(self):
@ -1081,7 +1075,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
self._backward_no_grad_on_tensor(
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
True
True,
)
@dist_init
@ -1091,7 +1085,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
None,
True
True,
)
@dist_init
@ -1101,7 +1095,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
None,
True
True,
)
@dist_init
@ -1115,7 +1109,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
None,
True
True,
)
@dist_init
@ -1128,7 +1122,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
None,
True
True,
)
@dist_init
@ -1141,16 +1135,12 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
None,
True
True,
)
@dist_init
def test_trainer_ps_sparse(self):
self._test_trainer_ps(
build_sparse_tensor,
_run_trainer,
True
)
self._test_trainer_ps(build_sparse_tensor, _run_trainer, True)
@dist_init
def test_backward_multiple_round_trips_sparse(self):
@ -1161,7 +1151,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
build_sparse_tensor(requires_grad=False),
build_sparse_tensor(requires_grad=True),
None,
True
True,
)
@dist_init
@ -1169,7 +1159,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
self._backward_different_dtypes(
build_sparse_tensor(requires_grad=True, dtype=torch.float32),
build_sparse_tensor(requires_grad=True, dtype=torch.float64),
True
True,
)
@dist_init
@ -1177,7 +1167,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
self._backward_simple_python_udf(
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
True
True,
)
@dist_init
@ -1185,7 +1175,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
self._backward_simple_script_call(
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
True
True,
)
@dist_init
@ -1193,7 +1183,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
self._nested_backward_accumulate_grads(
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
True
True,
)
@dist_init
@ -1202,7 +1192,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
self._backwards_nested_python_udf(
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
True
True,
)
@dist_init
@ -1210,7 +1200,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
self._mixed_requires_grad(
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=False),
True
True,
)
@dist_init
@ -1218,7 +1208,7 @@ class TensorPipeAgentDistAutogradTest(CommonDistAutogradTest):
self._multiple_backward(
build_sparse_tensor(requires_grad=True),
build_sparse_tensor(requires_grad=True),
True
True,
)
@dist_init
@ -1348,17 +1338,13 @@ class DistAutogradTest(CommonDistAutogradTest):
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(worker_name(dst_rank), ret_requires_grad)
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
worker_name(dst_rank), ret_requires_grad
).to_here()
ret = rpc.remote(worker_name(dst_rank), ret_requires_grad).to_here()
else:
raise ValueError(f"Unrecognized ExecMode {exec_mode}")
dist_autograd.backward(context_id, [ret.sum()])
rpc.rpc_sync(
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
)
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
# Wait for the prev rank to be done with rpc.
self._check_rpc_done(1)
@ -1421,9 +1407,7 @@ class DistAutogradTest(CommonDistAutogradTest):
t2 = torch.zeros(3, 3, requires_grad=False)
for dst_rank in dst_ranks:
rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
rpc.rpc_sync(
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
)
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
# all worker_ids in dst_ranks should be recorded.
ctx = dist_autograd._current_context()
worker_ids = ctx._known_worker_ids()
@ -1433,12 +1417,8 @@ class DistAutogradTest(CommonDistAutogradTest):
t1.requires_grad = True
t2.requires_grad = True
for dst_rank in dst_ranks:
rpc.rpc_sync(
worker_name(dst_rank), torch.add, args=(t1, t2)
)
rpc.rpc_sync(
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
)
rpc.rpc_sync(worker_name(dst_rank), torch.add, args=(t1, t2))
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
# all worker_ids in dst_ranks should be recorded.
worker_ids = ctx._known_worker_ids()
self.assertEqual(worker_ids, dst_ranks)
@ -1448,7 +1428,9 @@ class DistAutogradTest(CommonDistAutogradTest):
with dist_autograd.context() as context_id:
t1 = torch.rand(3, 3, requires_grad=True)
t2 = torch.rand(3, 3, requires_grad=True)
loss = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2)).sum()
loss = rpc.rpc_sync(
worker_name(self._next_rank()), torch.add, args=(t1, t2)
).sum()
with torch.autograd.profiler.profile() as p:
dist_autograd.backward(context_id, [loss])
@ -1485,7 +1467,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self._backward_no_grad_on_tensor(
torch.rand((3, 3), requires_grad=True),
torch.rand((3, 3), requires_grad=True),
False
False,
)
@dist_init
@ -1495,7 +1477,7 @@ class DistAutogradTest(CommonDistAutogradTest):
torch.rand((3, 3), requires_grad=True),
torch.rand((3, 3), requires_grad=True),
None,
False
False,
)
@dist_init
@ -1505,7 +1487,7 @@ class DistAutogradTest(CommonDistAutogradTest):
torch.rand((3, 3), requires_grad=True),
torch.rand((3, 3), requires_grad=True),
None,
False
False,
)
@dist_init
@ -1518,7 +1500,7 @@ class DistAutogradTest(CommonDistAutogradTest):
torch.rand((3, 3), requires_grad=True),
torch.rand((3, 3), requires_grad=True),
None,
False
False,
)
@dist_init
@ -1532,7 +1514,7 @@ class DistAutogradTest(CommonDistAutogradTest):
torch.rand((3, 3), requires_grad=True),
torch.rand((3, 3), requires_grad=True),
None,
False
False,
)
@dist_init
@ -1545,16 +1527,12 @@ class DistAutogradTest(CommonDistAutogradTest):
torch.rand((3, 3), requires_grad=True),
torch.rand((3, 3), requires_grad=True),
None,
False
False,
)
@dist_init
def test_trainer_ps(self):
self._test_trainer_ps(
create_tensor,
_run_trainer,
False
)
self._test_trainer_ps(create_tensor, _run_trainer, False)
@dist_init
def test_trainer_ps_torchscript_functions(self):
@ -1563,9 +1541,12 @@ class DistAutogradTest(CommonDistAutogradTest):
# ref as arg is passed to pybind boundary, and the ref is not garbage
# collected by python when calling shutdown()
import torch.distributed.rpc.api as api
api._ignore_rref_leak = True
self._test_trainer_ps(create_torchscript_tensor, _run_trainer_torchscript, False)
self._test_trainer_ps(
create_torchscript_tensor, _run_trainer_torchscript, False
)
@dist_init
def test_backward_multiple_round_trips(self):
@ -1576,7 +1557,7 @@ class DistAutogradTest(CommonDistAutogradTest):
torch.rand((3, 3)),
torch.rand((3, 3), requires_grad=True),
None,
False
False,
)
@dist_init
@ -1646,9 +1627,7 @@ class DistAutogradTest(CommonDistAutogradTest):
# We don't use the result of an RPC function, as a result the
# backward pass would hang in the "FAST" mode.
rpc.rpc_sync(
worker_name(self._next_rank()), torch.add, args=(t1, t2)
)
rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
val = torch.mul(t1, t2)
@ -1679,9 +1658,7 @@ class DistAutogradTest(CommonDistAutogradTest):
# Run multiple round trips across different nodes and verify the
# original node receives an error thrown on a node deep in the chain.
val = rpc.rpc_sync(
worker_name(self._next_rank()), torch.add, args=(t2, t3)
)
val = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t2, t3))
val = rpc.rpc_sync(
worker_name(self._next_rank()), torch.mul, args=(val, t2)
)
@ -1710,9 +1687,7 @@ class DistAutogradTest(CommonDistAutogradTest):
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
res = rpc.rpc_sync(
worker_name(self._next_rank()), torch.add, args=(t1, t2)
)
res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
# Wait for all RPCs to be done.
dist.barrier()
@ -1745,9 +1720,7 @@ class DistAutogradTest(CommonDistAutogradTest):
RuntimeError,
f"Could not find autograd context with id: {context_id}",
):
res = rpc.rpc_sync(
worker_name(self._next_rank()), torch.add, args=(t1, t2)
)
res = rpc.rpc_sync(worker_name(self._next_rank()), torch.add, args=(t1, t2))
dist_autograd.backward(context_id, [res.sum()])
@dist_init
@ -1768,7 +1741,6 @@ class DistAutogradTest(CommonDistAutogradTest):
@dist_init
def test_backward_invalid_args(self):
with dist_autograd.context() as context_id:
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
dist_autograd.backward(context_id, None)
@ -1817,7 +1789,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self._backward_different_dtypes(
torch.rand((3, 3), requires_grad=True, dtype=torch.float32),
torch.rand((3, 3), requires_grad=True, dtype=torch.float64),
False
False,
)
@dist_init
@ -1825,7 +1797,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self._backward_simple_python_udf(
torch.rand(3, 3, requires_grad=True),
torch.rand(3, 3, requires_grad=True),
False
False,
)
@dist_init
@ -1833,7 +1805,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self._backward_simple_script_call(
torch.rand(3, 3, requires_grad=True),
torch.rand(3, 3, requires_grad=True),
False
False,
)
@staticmethod
@ -1934,10 +1906,13 @@ class DistAutogradTest(CommonDistAutogradTest):
# Mark rank 0 is done in the store, since the RPC framework on
# some nodes might be broken at this point.
store.set('test_backward_node_failure_python_udf_rank0_done', "True")
store.set("test_backward_node_failure_python_udf_rank0_done", "True")
else:
# Wait for backward to finish on rank 0.
store.wait(['test_backward_node_failure_python_udf_rank0_done'], timedelta(seconds=10))
store.wait(
["test_backward_node_failure_python_udf_rank0_done"],
timedelta(seconds=10),
)
@staticmethod
def _nested_python_udf(t1, t2, dst):
@ -1952,7 +1927,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self._backwards_nested_python_udf(
torch.rand(3, 3, requires_grad=True),
torch.rand(3, 3, requires_grad=True),
False
False,
)
_test_clean_context_backward_context_id = None
@ -2063,7 +2038,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self._mixed_requires_grad(
torch.rand(3, 3, requires_grad=True),
torch.rand(3, 3, requires_grad=False),
False
False,
)
class TestDebugInfoFunc(Function):
@ -2210,7 +2185,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self._nested_backward_accumulate_grads(
torch.rand(3, 3, requires_grad=True),
torch.rand(3, 3, requires_grad=True),
False
False,
)
@dist_init
@ -2218,7 +2193,7 @@ class DistAutogradTest(CommonDistAutogradTest):
self._multiple_backward(
torch.rand(3, 3, requires_grad=True),
torch.rand(3, 3, requires_grad=True),
False
False,
)
@dist_init(clean_shutdown=False)
@ -2228,16 +2203,21 @@ class DistAutogradTest(CommonDistAutogradTest):
t2 = torch.rand((3, 3), requires_grad=True)
with dist_autograd.context() as context_id:
loss = rpc.rpc_sync(
f'worker{self._next_rank()}',
f"worker{self._next_rank()}",
DistAutogradTest._python_udf_with_backward_error,
args=(t1, t2)).sum()
args=(t1, t2),
).sum()
try:
# Run backward in a loop multiple times.
for i in range(100):
if i < 50:
with self.assertRaisesRegex(RuntimeError, "Simulate error on backward pass"):
dist_autograd.backward(context_id, [loss], retain_graph=True)
with self.assertRaisesRegex(
RuntimeError, "Simulate error on backward pass"
):
dist_autograd.backward(
context_id, [loss], retain_graph=True
)
elif i > 50:
# Recovered from error.
dist_autograd.backward(context_id, [loss], retain_graph=True)
@ -2270,9 +2250,10 @@ class DistAutogradTest(CommonDistAutogradTest):
@dist_init
def test_no_grad_copy(self):
'''
"""
Similar to test in test_autograd.py.
'''
"""
# create autograd function that saves grad pointer as class static
class MyFunc(Function):
static_grad_ptr = None
@ -2302,7 +2283,7 @@ class DistAutogradTest(CommonDistAutogradTest):
@staticmethod
def forward(ctx, inp1):
ctx.size = inp1.size()
return torch.tensor([1.])
return torch.tensor([1.0])
@staticmethod
def backward(ctx, grad):
@ -2312,7 +2293,9 @@ class DistAutogradTest(CommonDistAutogradTest):
b = torch.randn(5, 6, requires_grad=True)
# non-contiguous grad should be copied
with dist_autograd.context() as context_id:
dist_autograd.backward(context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))])
dist_autograd.backward(
context_id, [NonContGradFunc.apply(MyFunc.apply(a, b))]
)
grads = dist_autograd.get_gradients(context_id)
self.assertFalse(grads[a].data_ptr() == MyFunc.static_grad_ptr)
self.assertFalse(grads[b].data_ptr() == MyFunc.static_grad_ptr)
@ -2516,9 +2499,7 @@ class DistAutogradTest(CommonDistAutogradTest):
dist_autograd.backward(context_id, [loss])
self.assertTrue(
rpc.rpc_sync(
dst,
_compare_owner_value,
args=(context_id, rref, t3.grad)
dst, _compare_owner_value, args=(context_id, rref, t3.grad)
)
)
@ -2602,9 +2583,7 @@ class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
with dist_autograd.context() as context_id:
for dst_rank in dst_ranks:
rpc.rpc_sync(worker_name(dst_rank), func, args=rpc_args)
rpc.rpc_sync(
worker_name(dst_rank), _set_rpc_done, args=(context_id, 1)
)
rpc.rpc_sync(worker_name(dst_rank), _set_rpc_done, args=(context_id, 1))
# the thread's context id should be cleaned up
with self.assertRaises(RuntimeError):
dist_autograd._retrieve_context(context_id)
@ -2625,7 +2604,9 @@ class FaultyAgentDistAutogradTest(RpcAgentTestFixture):
@dist_init
def test_verify_backend_options(self):
self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE)
self.assertEqual(
self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
)
self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
@ -2645,7 +2626,6 @@ class WrapperModule(nn.Module):
class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
@skip_if_lt_x_gpu(4)
def test_device_maps_backward_pass(self):
options = self.rpc_backend_options
@ -2690,7 +2670,6 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
@skip_if_lt_x_gpu(4)
def test_dist_autograd_sync_streams(self):
options = self.rpc_backend_options
dst = worker_name((self.rank + 1) % self.world_size)
@ -2747,10 +2726,9 @@ class TensorPipeCudaDistAutogradTest(RpcAgentTestFixture):
local_layers = [l.to(0) for l in layers]
remote_layers = [
rpc.remote(
worker_name(rank),
WrapperModule,
args=(layers[rank - 1], rank)
) for rank in range(1, self.world_size)
worker_name(rank), WrapperModule, args=(layers[rank - 1], rank)
)
for rank in range(1, self.world_size)
]
x = torch.randn(5000, 2000).to(0)
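
The distributed autograd tests reformatted above all revolve around one flow: run a forward pass under a dist_autograd context, call backward with the context id, then read gradients back with get_gradients. A minimal sketch of that flow, assuming RPC is already initialized and dst_worker names a peer worker; the function and variable names here are illustrative:

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc


def run_one_step(dst_worker: str):
    t1 = torch.rand(3, 3, requires_grad=True)
    t2 = torch.rand(3, 3, requires_grad=True)
    with dist_autograd.context() as context_id:
        # The remote add is recorded in the distributed autograd graph.
        loss = rpc.rpc_sync(dst_worker, torch.add, args=(t1, t2)).sum()
        # Gradients accumulate per-context rather than on .grad fields.
        dist_autograd.backward(context_id, [loss])
        grads = dist_autograd.get_gradients(context_id)
        return grads[t1], grads[t2]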


@ -204,7 +204,9 @@ class DistOptimizerTest(RpcAgentTestFixture):
self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True)
self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True)
self._test_dist_optim_base(optim.SGD, lr=0.05)
self._test_dist_optim_base(optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True)
self._test_dist_optim_base(
optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True
)
self._test_dist_optim_base(optim.Adadelta, rho=0.95)
self._test_dist_optim_base(optim.RMSprop, lr=0.05)
self._test_dist_optim_base(optim.Adamax, lr=0.05)


@ -12,12 +12,11 @@ import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch import optim
from torch.testing._internal.dist_utils import (
dist_init,
worker_name,
from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture
batch_size = 20
in_features = 100
@ -30,7 +29,6 @@ def timed_log(text):
class BatchUpdateParameterServer:
def __init__(self, batch_update_size):
self.model = nn.Linear(in_features, out_features)
self.lock = threading.Lock()
@ -54,7 +52,9 @@ class BatchUpdateParameterServer:
else:
p.grad += g
with self.lock:
timed_log(f"PS got {self.curr_update_size}/{self.batch_update_size} updates")
timed_log(
f"PS got {self.curr_update_size}/{self.batch_update_size} updates"
)
self.curr_update_size += 1
fut = self.future_model
@ -72,7 +72,6 @@ class BatchUpdateParameterServer:
class Trainer:
def __init__(self, ps_rref):
self.ps_rref = ps_rref
self.loss_fn = nn.L1Loss()
@ -107,18 +106,19 @@ def run_ps(trainers):
timed_log("Start training")
start = perf_counter()
ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
futs = [rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) for trainer in trainers]
futs = [
rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) for trainer in trainers
]
torch.futures.wait_all(futs)
stop = perf_counter()
timed_log("Finish training")
timed_log(f"Time spent training: {stop - start}s")
class ParameterServerTest(RpcAgentTestFixture):
class ParameterServerTest(RpcAgentTestFixture):
@dist_init(setup_rpc=False)
def test_batch_updating_parameter_server(self):
if self.rank != 0:
rpc.init_rpc(
name=worker_name(self.rank),


@ -11,16 +11,19 @@ import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_sync, rpc_async, remote
from torch.distributed.rpc import remote, rpc_async, rpc_sync, RRef
from torch.distributions import Categorical
from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
TOTAL_EPISODE_STEP = 5000
GAMMA = 0.1
SEED = 543
def _call_method(method, rref, *args, **kwargs):
r"""
a helper function to call a method on the given RRef
@ -43,6 +46,7 @@ class Policy(nn.Module):
Copying the code to make these two examples independent.
See https://github.com/pytorch/examples/tree/master/reinforcement_learning
"""
def __init__(self) -> None:
super().__init__()
self.affine1 = nn.Linear(4, 128)
@ -67,6 +71,7 @@ class DummyEnv:
tests in this file. It is designed to run for a set max number of iterations,
returning random states and rewards at each step.
"""
def __init__(self, state_dim=4, num_iters=10, reward_threshold=475.0):
self.state_dim = state_dim
self.num_iters = num_iters
@ -96,6 +101,7 @@ class Observer:
select an action. Then, the observer applies the action to its environment
and reports the reward to the agent.
"""
def __init__(self) -> None:
self.id = rpc.get_worker_info().id
self.env = DummyEnv()
@ -171,8 +177,9 @@ class Agent:
rpc_async(
ob_rref.owner(),
_call_method,
args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
) for ob_rref in self.ob_rrefs
args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps),
)
for ob_rref in self.ob_rrefs
]
# wait until all observers have finished this episode


@ -1,32 +1,36 @@
# mypy: allow-untyped-defs
import torch
import time
import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs
from torch.testing._internal.dist_utils import (
dist_init,
wait_until_pending_futures_and_users_flushed,
wait_until_owners_and_forks_on_rank,
wait_until_pending_futures_and_users_flushed,
worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
def my_sleep_func(seconds=1):
time.sleep(seconds)
return torch.mul(torch.tensor(1), torch.tensor(1))
@torch.jit.script
def my_script_func(tensor):
return torch.add(tensor, tensor)
def add_rref_to_value(rref, value):
return rref.to_here() + value
class FaultyAgentRpcTest(RpcAgentTestFixture):
class FaultyAgentRpcTest(RpcAgentTestFixture):
# no faulty_messages defined so this fails all retryable messages - see
# faulty_rpc_agent_test_fixture.py for the list of retryable messages.
@dist_init(messages_to_delay={})
@ -36,22 +40,30 @@ class FaultyAgentRpcTest(RpcAgentTestFixture):
dst_worker_c = worker_name((self.rank + 2) % self.world_size)
# Worker0 sends RPC to Worker1 and creates an RRef there
rref = rpc.remote(dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2)))
rref = rpc.remote(
dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))
)
# Worker0 sends an RPC to Worker2 with the RRef as an arg
rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2)))
# check if the output is as expected
self.assertEqual(rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2)))
self.assertEqual(
rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2))
)
# explicitly delete all User RRefs
_delete_all_user_and_unforked_owner_rrefs()
@dist_init
def test_verify_backend_options(self):
self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE)
self.assertEqual(
self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
)
self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2)
self.assertEqual(self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
self.assertEqual(
self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC
)
@dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"])
def test_custom_faulty_messages(self):
@ -66,7 +78,9 @@ class FaultyAgentRpcTest(RpcAgentTestFixture):
@dist_init(messages_to_delay={"SCRIPT_CALL": 1.5})
def test_custom_messages_to_delay(self):
self.assertEqual(self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5})
self.assertEqual(
self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5}
)
def _test_remote_message_dropped_pickle(self, dst=None):
if self.rank != 0:
@ -95,7 +109,6 @@ class FaultyAgentRpcTest(RpcAgentTestFixture):
def test_remote_message_dropped_pickle_to_self(self):
self._test_remote_message_dropped_pickle(self.rank)
def _test_remote_message_dropped_timeout(self, func, args, dst=None):
if self.rank != 0:
return
@ -297,22 +310,20 @@ class FaultyAgentRpcTest(RpcAgentTestFixture):
with self.assertRaisesRegex(RuntimeError, expected_error):
rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
fut = rpc.rpc_async(
dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1
)
with self.assertRaisesRegex(RuntimeError, expected_error):
fut.wait()
# Ensure that the currently set default timeout is large enough such
# that RPCs with delays still complete.
fut = rpc.rpc_async(
dst_worker, my_script_func, args=(torch.tensor(1),)
)
fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
fut.wait()
# Ensure timeout if we set a new default and don't override
rpc._set_rpc_timeout(0.001)
fut = rpc.rpc_async(
dst_worker, my_script_func, args=(torch.tensor(1),)
)
fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
with self.assertRaisesRegex(RuntimeError, expected_error):
fut.wait()


@ -6,13 +6,16 @@ from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
# The following message types are currently retried in the RREF protocol and
# distributed autograd. Thus only these messages should be tested with the
# Faulty RPC Agent.
retryable_message_types = ["RREF_FORK_REQUEST",
"RREF_CHILD_ACCEPT",
"RREF_USER_DELETE",
"CLEANUP_AUTOGRAD_CONTEXT_REQ"]
retryable_message_types = [
"RREF_FORK_REQUEST",
"RREF_CHILD_ACCEPT",
"RREF_USER_DELETE",
"CLEANUP_AUTOGRAD_CONTEXT_REQ",
]
# The following messages incur the corresponding delay in seconds while being
# processed in FaultyTensorPipeAgent's enqueueSend() function.
@ -21,6 +24,7 @@ default_messages_to_delay = {
"SCRIPT_CALL": 1.5, # Script/Builtin
}
class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -29,9 +33,7 @@ class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
@property
def rpc_backend(self):
return rpc.backend_registry.BackendType[
"FAULTY_TENSORPIPE"
]
return rpc.backend_registry.BackendType["FAULTY_TENSORPIPE"]
@property
def rpc_backend_options(self):
@ -54,7 +56,7 @@ class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
error_regexes = [
"Exception in thread pool task",
"Connection reset by peer",
"Connection closed by peer"
"Connection closed by peer",
]
return "|".join([f"({error_str})" for error_str in error_regexes])


@ -32,9 +32,8 @@ def fork_add(t1, t2, dst: str):
class JitDistAutogradTest(RpcAgentTestFixture):
@dist_init
def test_get_gradients(self):
@torch.jit.script
def dist_get_gradients(context_id: int) -> (dict[Tensor, Tensor]):
def dist_get_gradients(context_id: int) -> dict[Tensor, Tensor]:
return dist_autograd.get_gradients(context_id)
FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))


@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import time
import io
import time
from typing import Any
import torch
@ -9,8 +9,9 @@ import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.autograd.profiler import record_function
from torch.autograd.profiler_legacy import profile as _profile
from torch.distributed.rpc import RRef
from torch.distributed.rpc.internal import RPCExecMode, _build_rpc_profiling_key
from torch.distributed.rpc.internal import _build_rpc_profiling_key, RPCExecMode
from torch.futures import Future
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.dist_utils import (
@ -23,11 +24,11 @@ from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.autograd.profiler_legacy import profile as _profile
def rref_isinstance(rref, cls_to_check):
return isinstance(rref.local_value(), cls_to_check)
def sleep(t):
time.sleep(t)
@ -140,10 +141,12 @@ def no_arg():
def one_arg(value):
return value + 1
@torch.jit.script
def script_add_ones(x):
return torch.add(x, torch.ones(1))
@torch.jit.script
def script_add_ones_with_record_function(x, block: str):
with record_function(block):
@ -154,16 +157,15 @@ def script_add_ones_with_record_function(x, block: str):
def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor:
t: Tensor = torch.ones(1)
with record_function(block):
fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
# Extra operator call to avoid de-duplication of the next async call
# see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279
zero = torch.zeros_like(t)
fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, ))
fut2 = rpc.rpc_async(dst_worker_name, script_add_ones, (t,))
res = fut1.wait() + fut2.wait() + zero
return res
@torch.jit.script
def script_fork_wait_udf(tensor):
fut = torch.jit._fork(script_add_ones, tensor)
@ -196,7 +198,9 @@ def script_fork_wait_throw(invalue):
@torch.jit.script
def call_rpc_with_profiling(record: torch.classes.profiler._RecordFunction, dst_worker_name: str) -> Tensor:
def call_rpc_with_profiling(
record: torch.classes.profiler._RecordFunction, dst_worker_name: str
) -> Tensor:
# Call rpc_async from within ScriptFunction and ensure that we can attach
# profiling callbacks. Note that handle here is a Tensor representation of
# RecordFunction.
@ -205,9 +209,14 @@ def call_rpc_with_profiling(record: torch.classes.profiler._RecordFunction, dst_
ret = fut.wait()
return ret
@torch.jit.script
def call_rpc_torchscript_with_record_function(dst_worker_name: str, block: str) -> Tensor:
fut = rpc.rpc_async(dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block))
def call_rpc_torchscript_with_record_function(
dst_worker_name: str, block: str
) -> Tensor:
fut = rpc.rpc_async(
dst_worker_name, script_add_ones_with_record_function, (torch.tensor(1), block)
)
return fut.wait()
@ -311,9 +320,7 @@ class FutureTypingTest:
def future_return_to_python(
dst_rank: int, inputs: tuple[Tensor, Tensor]
) -> Future[Tensor]:
return rpc.rpc_async(
f"worker{dst_rank}", two_args_two_kwargs, inputs
)
return rpc.rpc_async(f"worker{dst_rank}", two_args_two_kwargs, inputs)
fut_res = future_return_to_python(dst_rank, inputs)
self.assertEqual(fut_res.wait(), expected_res)
@ -524,6 +531,7 @@ def script_rpc_async_call(
ret = fut.wait()
return ret
@torch.jit.script
def script_rpc_sync_call(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
@ -531,6 +539,7 @@ def script_rpc_sync_call(
res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
return res
@torch.jit.script
def script_rpc_remote_call(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
@ -538,6 +547,7 @@ def script_rpc_remote_call(
rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs)
return rref_res.to_here()
class JitRpcOpTest:
# Call functions remotely from Script.
@dist_init
@ -550,10 +560,12 @@ class JitRpcOpTest:
args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
kwargs = {}
for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
ret = script_op(
dst_worker_name, args, kwargs
)
for script_op in [
script_rpc_async_call,
script_rpc_sync_call,
script_rpc_remote_call,
]:
ret = script_op(dst_worker_name, args, kwargs)
self.assertEqual(ret, torch.tensor([10, 10]))
@dist_init
@ -566,10 +578,12 @@ class JitRpcOpTest:
args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
kwargs = {"first_kwarg": torch.tensor([2, 2])}
for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
ret = script_op(
dst_worker_name, args, kwargs
)
for script_op in [
script_rpc_async_call,
script_rpc_sync_call,
script_rpc_remote_call,
]:
ret = script_op(dst_worker_name, args, kwargs)
self.assertEqual(ret, torch.tensor([9, 9]))
@dist_init
@ -584,10 +598,12 @@ class JitRpcOpTest:
"first_kwarg": torch.tensor([2, 2]),
"second_kwarg": torch.tensor([3, 3]),
}
for script_op in [script_rpc_async_call, script_rpc_sync_call, script_rpc_remote_call]:
ret = script_op(
dst_worker_name, args, kwargs
)
for script_op in [
script_rpc_async_call,
script_rpc_sync_call,
script_rpc_remote_call,
]:
ret = script_op(dst_worker_name, args, kwargs)
self.assertEqual(ret, torch.tensor([8, 8]))
@dist_init
@ -618,9 +634,7 @@ class JitRpcOpTest:
ret = fut.wait()
return ret
ret = script_rpc_async_call_with_assorted_types(
dst_worker_name
)
ret = script_rpc_async_call_with_assorted_types(dst_worker_name)
self.assertEqual(ret, (torch.tensor([4, 4]), "str_arg_str_kwarg", 4))
@dist_init
@ -639,9 +653,7 @@ class JitRpcOpTest:
ret = fut.wait()
return ret
ret = script_rpc_async_call_without_kwargs_passed(
dst_worker_name
)
ret = script_rpc_async_call_without_kwargs_passed(dst_worker_name)
self.assertEqual(ret, 0)
@dist_init
@ -659,9 +671,7 @@ class JitRpcOpTest:
ret = fut.wait()
return ret
ret = script_rpc_async_call_without_args_kwargs_passed(
dst_worker_name
)
ret = script_rpc_async_call_without_args_kwargs_passed(dst_worker_name)
self.assertEqual(ret, 0)
@dist_init
@ -730,9 +740,7 @@ class JitRpcOpTest:
with self.assertRaisesRegex(
RuntimeError, "Unknown keyword argument 'third_kwarg'"
):
ret = script_rpc_async_call_with_unexpected_kwarg(
dst_worker_name
)
ret = script_rpc_async_call_with_unexpected_kwarg(dst_worker_name)
self.assertEqual(ret, 0)
@dist_init
@ -915,9 +923,7 @@ class JitRpcTest(
# Python 3.5 and Python 3.6 throw different error message, the only
# common word can be greped is "pickle".
with self.assertRaisesRegex(TypeError, "pickle"):
rpc.rpc_async(
dst_worker_name, my_local_script_module.forward, args=()
)
rpc.rpc_async(dst_worker_name, my_local_script_module.forward, args=())
@dist_init
def test_remote_script_module(self):
@ -1005,9 +1011,7 @@ class JitRpcTest(
rpc._disable_jit_rref_pickle()
out1 = rpc.rpc_sync(
dst_name,
load_script_module_with_pickled_rref,
args=(f.getvalue(),)
dst_name, load_script_module_with_pickled_rref, args=(f.getvalue(),)
)
out2 = m2()
self.assertEqual(out1, out2)
@ -1150,7 +1154,9 @@ class JitRpcTest(
# After that, this test should be modified to validate the function time.
events = prof.function_events
function_event = get_function_event(events, prof_key)
self.assertTrue(torch._jit_internal._qualified_name(one_arg) in function_event.name)
self.assertTrue(
torch._jit_internal._qualified_name(one_arg) in function_event.name
)
@dist_init
def test_rpc_async_jit_profiled(self):
@ -1162,9 +1168,7 @@ class JitRpcTest(
args = (torch.tensor([1, 1]), torch.tensor([2, 2]))
kwargs = {}
with _profile() as prof:
script_rpc_async_call(
dst_worker_name, args, kwargs
)
script_rpc_async_call(dst_worker_name, args, kwargs)
# Ensure rpc_async call is profiled
function_events = prof.function_events
@ -1358,10 +1362,9 @@ class JitRpcTest(
num = 20
rrefs = [
rpc.remote(
dst1,
async_add,
args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i)
) for i in range(num)
dst1, async_add, args=(dst2, torch.ones(2, 2), torch.ones(2, 2) * i)
)
for i in range(num)
]
for i in range(num):


@ -7,8 +7,8 @@ from torch import Tensor
from torch.distributed.rpc import RRef
from torch.testing._internal.dist_utils import (
dist_init,
wait_until_pending_futures_and_users_flushed,
worker_name,
wait_until_pending_futures_and_users_flushed
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
@ -64,14 +64,17 @@ def rpc_async_call_future_ret(
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
return fut
@torch.jit.script
def rref_to_here(rref_var: RRef[Tensor]) -> Tensor:
return rref_var.to_here()
@torch.jit.script
def rref_to_here_with_timeout(rref_var: RRef[Tensor], timeout: float) -> Tensor:
return rref_var.to_here(timeout)
@torch.jit.script
def rpc_async_with_rref_arg(dst_worker_name: str, args: tuple[RRef[Tensor]]) -> Tensor:
fut = rpc.rpc_async(dst_worker_name, rref_to_here, args)
@ -84,6 +87,7 @@ class JitFaultyAgentRpcTest(RpcAgentTestFixture):
Run tests for rpc_async in JIT under the faulty agent test fixture to test
arbitrary timeouts.
"""
@dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
def test_timeout_in_torchscript_function(self):
# Call rpc_async + fut.wait() in torchscript function and ensure that
@ -108,9 +112,7 @@ class JitFaultyAgentRpcTest(RpcAgentTestFixture):
# is less than the RPC takes to execute.
rpc._set_rpc_timeout(0.001)
with self.assertRaisesRegex(RuntimeError, expected_error):
script_rpc_async_call(
dst_worker_name, args, kwargs
)
script_rpc_async_call(dst_worker_name, args, kwargs)
# Ensure that we run to completion if zero timeout is specified.
ret = rpc_async_call_with_timeout(dst_worker_name, args, kwargs, 0)
@ -198,7 +200,7 @@ class JitFaultyAgentRpcTest(RpcAgentTestFixture):
# Call RPC with RRef arg in JIT, which will go through JIT pickling and
# ensure error is raised.
with self.assertRaisesRegex(RuntimeError, "RRef creation"):
rpc_async_with_rref_arg(dst_worker, (rref, ))
rpc_async_with_rref_arg(dst_worker, (rref,))
@dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
def test_rref_timeout_pickle_script_func(self):
@ -214,4 +216,4 @@ class JitFaultyAgentRpcTest(RpcAgentTestFixture):
wait_until_pending_futures_and_users_flushed()
# Call RPC with script function that takes RRef, ensure timeout during pickling
with self.assertRaisesRegex(RuntimeError, "RRef creation"):
rpc.rpc_sync(dst_worker, rref_to_here, args=(rref, ))
rpc.rpc_sync(dst_worker, rref_to_here, args=(rref,))


File diff suppressed because it is too large

View File

@ -1,27 +1,21 @@
# mypy: allow-untyped-defs
import torch.distributed.rpc as rpc
from torch.testing._internal.common_distributed import tp_transports
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import (
tp_transports,
)
class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture):
@property
def rpc_backend(self):
return rpc.backend_registry.BackendType[
"TENSORPIPE"
]
return rpc.backend_registry.BackendType["TENSORPIPE"]
@property
def rpc_backend_options(self):
return rpc.backend_registry.construct_rpc_backend_options(
self.rpc_backend,
init_method=self.init_method,
_transports=tp_transports()
self.rpc_backend, init_method=self.init_method, _transports=tp_transports()
)
def get_shutdown_error_regex(self):
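The fixture above selects the TENSORPIPE backend through the backend registry; the equivalent explicit configuration through the public API looks roughly like this (addresses, ports, and thread counts are illustrative):

import torch.distributed.rpc as rpc

options = rpc.TensorPipeRpcBackendOptions(
    init_method="tcp://127.0.0.1:29502",
    num_worker_threads=8,
    rpc_timeout=20,  # seconds
)
rpc.init_rpc(
    "worker0",
    rank=0,
    world_size=1,
    backend=rpc.BackendType.TENSORPIPE,
    rpc_backend_options=options,
)
rpc.shutdown()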

View File

@ -6,9 +6,9 @@ import unittest
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
find_free_port,
IS_SANDCASTLE,
TEST_WITH_DEV_DBG_ASAN,
)
from torch.testing._internal.distributed.ddp_under_dist_autograd_test import (
CudaDdpComparisonTest,
@ -21,15 +21,24 @@ from torch.testing._internal.distributed.nn.api.remote_module_test import (
ThreeWorkersRemoteModuleTest,
)
from torch.testing._internal.distributed.rpc.dist_autograd_test import (
DistAutogradTest,
CudaDistAutogradTest,
DistAutogradTest,
FaultyAgentDistAutogradTest,
TensorPipeAgentDistAutogradTest,
TensorPipeCudaDistAutogradTest
TensorPipeCudaDistAutogradTest,
)
from torch.testing._internal.distributed.rpc.dist_optimizer_test import (
DistOptimizerTest,
)
from torch.testing._internal.distributed.rpc.examples.parameter_server_test import (
ParameterServerTest,
)
from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
ReinforcementLearningRpcTest,
)
from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
FaultyAgentRpcTest,
)
from torch.testing._internal.distributed.rpc.jit.dist_autograd_test import (
JitDistAutogradTest,
)
@ -40,18 +49,11 @@ from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
FaultyAgentRpcTest,
)
from torch.testing._internal.distributed.rpc.rpc_test import (
CudaRpcTest,
RpcTest,
TensorPipeAgentRpcTest,
TensorPipeAgentCudaRpcTest,
)
from torch.testing._internal.distributed.rpc.examples.parameter_server_test import ParameterServerTest
from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
ReinforcementLearningRpcTest,
TensorPipeAgentRpcTest,
)
@ -61,15 +63,17 @@ def _check_and_set_tcp_init():
# different ports.
use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
if use_tcp_init == "1":
os.environ["MASTER_ADDR"] = '127.0.0.1'
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(find_free_port())
def _check_and_unset_tcp_init():
use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None)
if use_tcp_init == "1":
del os.environ["MASTER_ADDR"]
del os.environ["MASTER_PORT"]
# The tests for the RPC module need to cover multiple possible combinations:
# - different aspects of the API, each one having its own suite of tests;
# - different agents (ProcessGroup, TensorPipe, ...);
@ -80,8 +84,10 @@ def _check_and_unset_tcp_init():
# we call the generate_tests function of this file, passing to it a fixture for
# the agent, which then gets mixed-in with each test suite.
@unittest.skipIf(
TEST_WITH_DEV_DBG_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues"
TEST_WITH_DEV_DBG_ASAN,
"Skip ASAN as torch + multiprocessing spawn have known issues",
)
class SpawnHelper(MultiProcessTestCase):
def setUp(self):
@ -169,8 +175,10 @@ def generate_tests(
for test_class in tests:
if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN:
print(
f'Skipping test {test_class} on sandcastle for the following reason: '
'Skip dev-asan as torch + multiprocessing spawn have known issues', file=sys.stderr)
f"Skipping test {test_class} on sandcastle for the following reason: "
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
continue
name = f"{prefix}{test_class.__name__}"
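The generate_tests helper edited above mixes an agent fixture into every API test suite and returns the resulting classes keyed by name; a rough sketch of that mix-in pattern with stand-in classes (not the actual rpc_utils.py implementation):

import unittest

class FakeAgentFixture(unittest.TestCase):
    # Stand-in for an RpcAgentTestFixture subclass; supplies agent-specific config.
    backend_name = "fake"

class EchoApiTest:
    # Stand-in for one API test suite; relies on attributes mixed in from the fixture.
    def test_backend_name_is_set(self):
        self.assertIsInstance(self.backend_name, str)

def generate_tests_sketch(prefix, fixture, tests, module_name):
    # Build "<prefix><Suite>" classes inheriting from both the suite and the fixture.
    ret = {}
    for test_class in tests:
        name = f"{prefix}{test_class.__name__}"
        class_ = type(name, (test_class, fixture), {})
        class_.__module__ = module_name
        ret[name] = class_
    return ret

# Inject the generated classes so unittest discovery picks them up, e.g.:
#   globals().update(generate_tests_sketch("Fake", FakeAgentFixture, [EchoApiTest], __name__))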