[BE][Easy] enable UFMT for torch/distributed/ (#128870)

Part of #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin, https://github.com/wconstab
Author: Xuehai Pan
Date: 2024-06-22 18:43:13 +08:00
Committed by: PyTorch MergeBot
Parent: e165a5971f
Commit: 94dc3253a0
37 changed files with 587 additions and 303 deletions
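
What UFMT enforces here is usort (import sorting) followed by Black (code style) over the files removed from the exclude list below, which is why the diffs are dominated by re-sorted imports, double-quoted strings, and re-wrapped call sites. In the PyTorch repo this is normally driven by lintrunner (something like `lintrunner -a --take UFMT`). As a minimal illustration of the Black half only (not part of this commit), using Black's `format_str` API:

```python
# Illustrative sketch, assuming the `black` package is installed.
# Black normalizes annotation spacing, string quotes, and wrapping;
# usort (the other half of UFMT) re-sorts the import blocks seen in the diffs below.
import black

src = "exception : Optional[Exception] = None\nsys.stdin = open('/dev/stdin')\n"
print(black.format_str(src, mode=black.Mode()), end="")
# exception: Optional[Exception] = None
# sys.stdin = open("/dev/stdin")
```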

View File

@ -1387,33 +1387,6 @@ exclude_patterns = [
'torch/contrib/_tensorboard_vis.py',
"torch/cuda/_gpu_trace.py",
'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable
'torch/distributed/__init__.py',
'torch/distributed/_composable_state.py',
'torch/distributed/_sharded_tensor/__init__.py',
'torch/distributed/_sharding_spec/__init__.py',
'torch/distributed/_tools/__init__.py',
'torch/distributed/_tools/memory_tracker.py',
'torch/distributed/argparse_util.py',
'torch/distributed/c10d_logger.py',
'torch/distributed/collective_utils.py',
'torch/distributed/constants.py',
'torch/distributed/distributed_c10d.py',
'torch/distributed/examples/memory_tracker_example.py',
'torch/distributed/launch.py',
'torch/distributed/launcher/__init__.py',
'torch/distributed/launcher/api.py',
'torch/distributed/logging_handlers.py',
'torch/distributed/nn/__init__.py',
'torch/distributed/nn/api/__init__.py',
'torch/distributed/nn/api/remote_module.py',
'torch/distributed/nn/functional.py',
'torch/distributed/nn/jit/__init__.py',
'torch/distributed/nn/jit/instantiator.py',
'torch/distributed/nn/jit/templates/__init__.py',
'torch/distributed/nn/jit/templates/remote_module_template.py',
'torch/distributed/remote_device.py',
'torch/distributed/rendezvous.py',
'torch/distributed/run.py',
'torch/fft/__init__.py',
'torch/func/__init__.py',
'torch/futures/__init__.py',
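
The removals above drop the torch/distributed entries from the UFMT linter's `exclude_patterns` (in PyTorch this list lives in `.lintrunner.toml`), so formatting is now enforced for those files. A hypothetical spot check of one newly covered file with the standalone `ufmt` tool, assuming it is installed locally:

```python
# Hypothetical spot check, not part of this commit. `ufmt check` exits
# non-zero when a file would be reformatted; `ufmt diff` shows the change.
import subprocess

proc = subprocess.run(
    ["ufmt", "check", "torch/distributed/run.py"],
    capture_output=True,
    text=True,
)
print(proc.returncode, proc.stdout or proc.stderr)
```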

View File

@ -1,12 +1,10 @@
# mypy: allow-untyped-defs
import os
import sys
from enum import Enum
import pdb
import io
import sys
import torch
def is_available() -> bool:
"""
Return ``True`` if the distributed package is available.
@ -32,31 +30,31 @@ DistStoreError = torch._C._DistStoreError
if is_available():
from torch._C._distributed_c10d import (
Store,
FileStore,
TCPStore,
ProcessGroup as ProcessGroup,
Backend as _Backend,
PrefixStore,
Reducer,
Logger,
BuiltinCommHookType,
GradBucket,
Work as _Work,
_DEFAULT_FIRST_BUCKET_BYTES,
_register_comm_hook,
_register_builtin_comm_hook,
_broadcast_coalesced,
_compute_bucket_assignment_by_size,
_verify_params_across_processes,
_ControlCollectives,
_DEFAULT_FIRST_BUCKET_BYTES,
_make_nccl_premul_sum,
_register_builtin_comm_hook,
_register_comm_hook,
_StoreCollectives,
_test_python_store,
_verify_params_across_processes,
Backend as _Backend,
BuiltinCommHookType,
DebugLevel,
FileStore,
get_debug_level,
GradBucket,
Logger,
PrefixStore,
ProcessGroup as ProcessGroup,
Reducer,
set_debug_level,
set_debug_level_from_env,
_make_nccl_premul_sum,
_ControlCollectives,
_StoreCollectives,
Store,
TCPStore,
Work as _Work,
)
class _DistributedPdb(pdb.Pdb):
@ -66,10 +64,11 @@ if is_available():
Usage:
_DistributedPdb().set_trace()
"""
def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open('/dev/stdin')
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin
@ -101,37 +100,31 @@ if is_available():
del guard
if sys.platform != "win32":
from torch._C._distributed_c10d import (
HashStore,
_round_robin_process_groups,
)
from torch._C._distributed_c10d import _round_robin_process_groups, HashStore
from .distributed_c10d import * # noqa: F403
from .device_mesh import DeviceMesh, init_device_mesh
# Variables prefixed with underscore are not auto imported
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
# this.
from .distributed_c10d import * # noqa: F403
from .distributed_c10d import (
_all_gather_base,
_reduce_scatter_base,
_create_process_group_wrapper,
_rank_not_in_group,
_coalescing_manager,
_CoalescingManager,
_create_process_group_wrapper,
_get_process_group_name,
_rank_not_in_group,
_reduce_scatter_base,
get_node_local_rank,
)
from .remote_device import _remote_device
from .rendezvous import (
rendezvous,
_create_store_from_options,
register_rendezvous_handler,
rendezvous,
)
from .remote_device import _remote_device
from .device_mesh import init_device_mesh, DeviceMesh
set_debug_level_from_env()
else:
@ -143,4 +136,5 @@ else:
class _ProcessGroupStub:
pass
sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined]

View File

@ -5,6 +5,7 @@ import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.distributed_c10d import ReduceOp
from ._fsdp_common import (
_get_dim0_padded_size,
_raise_assert_with_print,

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import math
import traceback
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, cast, List, Optional

View File

@ -4,10 +4,10 @@ from typing import List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
from torch.distributed.device_mesh import _get_device_handle
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state

View File

@ -7,12 +7,12 @@ from typing import Any, cast, List, Optional, Sequence, Tuple
import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.device_mesh import _mesh_resources
from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import (
_chunk_with_empty,
@ -24,6 +24,7 @@ from ._fsdp_common import (
HSDPMeshInfo,
)
"""
[Note: FSDP tensors]
FSDP considers the following tensors:

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
import logging
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple
import torch
@ -12,6 +11,7 @@ from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicat
from torch.profiler import record_function
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle
from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_collectives import (
AllGatherResult,
@ -22,6 +22,7 @@ from ._fsdp_collectives import (
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
logger = logging.getLogger("torch.distributed._composable.fsdp")
_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import torch
@ -14,13 +13,16 @@ from torch.distributed._composable_state import (
)
from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map
from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import _cast_fp_tensor, TrainingState
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup
if TYPE_CHECKING:
from ._fsdp_param import FSDPParam
logger = logging.getLogger("torch.distributed._composable.fsdp")

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import functools
from typing import Any, cast, Iterable, List, NoReturn, Optional, Union
import torch

View File

@ -8,7 +8,6 @@ from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
_init_buffer_state,
_init_core_state,

View File

@ -9,6 +9,7 @@ from torch.nn.parallel import DistributedDataParallel
from .contract import _get_registry, contract
_ROOT_MODULE_PREFIX = ""

View File

@ -1,15 +1,14 @@
# mypy: allow-untyped-defs
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
if TYPE_CHECKING:
from torch._C._distributed_c10d import _DistributedBackendOptions, Backend

View File

@ -11,6 +11,7 @@ from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode
from . import _functional_collectives_impl as fun_col_impl
try:
from torch.utils._cxx_pytree import tree_map_only
except ImportError:
@ -1134,6 +1135,7 @@ from torch.distributed.distributed_c10d import (
reduce_scatter_tensor as legacy_reducescatter,
)
# This dict should contain sets of functions that dynamo is allowed to remap.
# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {

View File

@ -4,6 +4,7 @@ from typing import List, Optional
import torch
import torch.distributed.distributed_c10d as c10d
"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.

View File

@ -1,11 +1,12 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import torch
import warnings
import torch
from torch.distributed._shard.sharded_tensor import * # noqa: F403
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
@ -15,4 +16,6 @@ with warnings.catch_warnings():
stacklevel=2,
)
sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor
sys.modules[
"torch.distributed._sharded_tensor"
] = torch.distributed._shard.sharded_tensor

View File

@ -1,11 +1,12 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import torch
import warnings
import torch
from torch.distributed._shard.sharding_spec import * # noqa: F403
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
@ -16,4 +17,6 @@ with warnings.catch_warnings():
)
import torch.distributed._shard.sharding_spec as _sharding_spec
sys.modules['torch.distributed._sharding_spec'] = _sharding_spec
sys.modules["torch.distributed._sharding_spec"] = _sharding_spec

View File

@ -22,6 +22,7 @@ import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
if dist.is_available() or TYPE_CHECKING:
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor

View File

@ -1,3 +1,3 @@
from .mem_tracker import MemTracker
from .memory_tracker import MemoryTracker
from .mod_tracker import ModTracker
from .mem_tracker import MemTracker

View File

@ -1,24 +1,14 @@
# mypy: allow-untyped-defs
from collections import defaultdict
from itertools import chain
import operator
import pickle
from typing import (
Any,
Callable,
Dict,
List,
no_type_check,
Sequence,
TYPE_CHECKING,
)
from collections import defaultdict
from itertools import chain
from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode
import operator
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
@ -234,6 +224,7 @@ class MemoryTracker:
def _create_pre_forward_hook(self, name: str) -> Callable:
"""Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start."""
def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
self._cur_module_name = f"{name}.forward"
if (

View File

@ -15,9 +15,9 @@ from typing_extensions import ParamSpec
import torch
import torch.distributed as dist
from torch.distributed.logging_handlers import _log_handlers
__all__: List[str] = []
_DEFAULT_DESTINATION = "default"
@ -36,7 +36,9 @@ def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Lo
return logger
def _get_logging_handler(destination: str = _DEFAULT_DESTINATION) -> Tuple[logging.Handler, str]:
def _get_logging_handler(
destination: str = _DEFAULT_DESTINATION,
) -> Tuple[logging.Handler, str]:
log_handler = _log_handlers[destination]
log_handler_name = type(log_handler).__name__
return (log_handler, log_handler_name)
@ -69,8 +71,10 @@ def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
}
return msg_dict
_T = TypeVar('_T')
_P = ParamSpec('_P')
_T = TypeVar("_T")
_P = ParamSpec("_P")
def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(func)

View File

@ -14,8 +14,10 @@ from typing import Any, Callable, cast, Generic, List, Optional, Tuple, TypeVar,
import torch.distributed as dist
T = TypeVar("T")
@dataclass
class SyncPayload(Generic[T]):
stage_name: Optional[str]
@ -23,6 +25,7 @@ class SyncPayload(Generic[T]):
payload: T
exception: Optional[Exception] = None
def broadcast(
data_or_fn: Union[T, Callable[[], T]],
*,
@ -55,10 +58,12 @@ def broadcast(
"""
if not success and data_or_fn is not None:
raise AssertionError("Data or Function is expected to be None if not successful")
raise AssertionError(
"Data or Function is expected to be None if not successful"
)
payload: Optional[T] = None
exception : Optional[Exception] = None
exception: Optional[Exception] = None
# if no pg is passed then execute if rank is 0
if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
# determine if it is an executable function or data payload only
@ -119,7 +124,7 @@ def all_gather(
>> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
"""
payload: Optional[T] = None
exception : Optional[Exception] = None
exception: Optional[Exception] = None
success = True
# determine if it is an executable function or data payload only
if callable(data_or_fn):
@ -161,7 +166,8 @@ def all_gather(
if len(exception_list) > 0:
raise RuntimeError( # type: ignore[misc]
error_msg, exception_list) from exception_list[0]
error_msg, exception_list
) from exception_list[0]
return ret_list
else:
if not sync_obj.success:

View File

@ -1,8 +1,10 @@
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
from datetime import timedelta
from typing import Optional
__all__ = ['default_pg_timeout', 'default_pg_nccl_timeout']
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"]
# Default process group wide timeout, if applicable.
# This only applies to the non-nccl backends
@ -16,6 +18,7 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
try:
from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
except ImportError:
# if C++ NCCL support is not compiled, we don't have access to the default nccl value.

View File

@ -6,10 +6,9 @@ import threading
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from torch.utils._typing_utils import not_none
from ..utils._typing_utils import not_none
__all__ = ["init_device_mesh", "DeviceMesh"]

File diff suppressed because it is too large

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import torch
import torchvision
import torch
from torch.distributed._tools import MemoryTracker

View File

@ -8,7 +8,7 @@
from torch.distributed.launcher.api import ( # noqa: F401
LaunchConfig,
elastic_launch,
launch_agent,
LaunchConfig,
)

View File

@ -15,13 +15,18 @@ import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException
from torch.distributed.elastic.multiprocessing import (
DefaultLogsSpecs,
LogsSpecs,
SignalException,
)
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']
__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]
logger = get_logger(__name__)
@ -212,8 +217,8 @@ def launch_agent(
"max_restarts": config.max_restarts,
"monitor_interval": config.monitor_interval,
"log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr]
"metrics_cfg": config.metrics_cfg
}
"metrics_cfg": config.metrics_cfg,
},
)
rdzv_parameters = RendezvousParameters(

View File

@ -9,6 +9,7 @@
import logging
from typing import Dict, List
__all__: List[str] = []
_log_handlers: Dict[str, logging.Handler] = {

View File

@ -1,4 +1,7 @@
import torch
from .functional import * # noqa: F403
if torch.distributed.rpc.is_available():
from .api.remote_module import RemoteModule
from .functional import * # noqa: F403

View File

@ -21,14 +21,15 @@ from typing import (
import torch
import torch.distributed.rpc as rpc
from torch import Tensor, device, dtype, nn
from torch.distributed.nn.jit import instantiator
from torch import device, dtype, nn, Tensor
from torch.distributed import _remote_device
from torch.distributed.nn.jit import instantiator
from torch.distributed.rpc.internal import _internal_rpc_pickler
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle
__all__ = ["RemoteModule"]
_grad_t = Union[Tuple[Tensor, ...], Tensor]
@ -120,7 +121,6 @@ def _raise_not_supported(name: str) -> None:
class _RemoteModule(nn.Module):
def __new__(cls, *args, **kwargs):
# Use __new__ for logging purposes.
torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module")
@ -370,7 +370,10 @@ class _RemoteModule(nn.Module):
self,
hook: Union[
Callable[[T, Tuple[Any, ...]], Optional[Any]],
Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
Callable[
[T, Tuple[Any, ...], Dict[str, Any]],
Optional[Tuple[Any, Dict[str, Any]]],
],
],
prepend: bool = False,
with_kwargs: bool = False,
@ -405,10 +408,7 @@ class _RemoteModule(nn.Module):
)
def named_parameters( # type: ignore[return]
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Parameter]]:
_raise_not_supported(self.named_parameters.__name__)
@ -416,10 +416,7 @@ class _RemoteModule(nn.Module):
_raise_not_supported(self.buffers.__name__)
def named_buffers( # type: ignore[return]
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Tensor]]:
_raise_not_supported(self.named_buffers.__name__)
@ -464,7 +461,11 @@ class _RemoteModule(nn.Module):
assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."
remote_device = _remote_device(remote_device_str)
self.on = remote_device.worker_name() if remote_device.worker_name() is not None else remote_device.rank()
self.on = (
remote_device.worker_name()
if remote_device.worker_name() is not None
else remote_device.rank()
)
self.device = str(remote_device.device())
agent = rpc._get_current_rpc_agent()
# If the device map of the remote worker is set,

View File

@ -2,11 +2,13 @@
import torch
import torch.distributed as dist
from torch.autograd import Function
# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise import error
# if we're trying to use them.
from torch.distributed import group, ReduceOp
def broadcast(tensor, src, group=group.WORLD):
"""
Broadcasts the tensor to the whole group.
@ -116,6 +118,7 @@ def all_gather(tensor, group=group.WORLD):
"""
return _AllGather.apply(group, tensor)
def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
"""
Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.
@ -340,6 +343,7 @@ class _AllGather(Function):
gx = torch.sum(torch.stack(gxs), dim=0)
return (None, gx)
class _AllGatherBase(Function):
@staticmethod
def forward(ctx, output_tensor, input_tensor, group):
@ -354,16 +358,19 @@ class _AllGatherBase(Function):
out_size = list(grad_output.size())
if out_size[0] % world_size != 0:
raise RuntimeError(
f'Tensor with dimensions: {out_size} does '
f'not have first dimension divisible by world_size: {world_size}'
f"Tensor with dimensions: {out_size} does "
f"not have first dimension divisible by world_size: {world_size}"
)
out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group)
gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype)
gx = torch.empty(
out_size, device=grad_output.device, dtype=grad_output.dtype
)
dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
else:
raise RuntimeError("Backend not supported!")
return (None, gx, None)
class _AlltoAll(Function):
@staticmethod
def forward(ctx, group, out_tensor_list, *tensors):
@ -391,7 +398,9 @@ class _AlltoAll(Function):
@staticmethod
def backward(ctx, *grad_outputs):
tensor_list = [
torch.empty(size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
torch.empty(
size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype
)
for size in ctx.input_tensor_size_list
]
return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
@ -415,7 +424,9 @@ class _AlltoAllSingle(Function):
@staticmethod
def backward(ctx, grad_output):
tensor = torch.empty(ctx.input_size, device=grad_output.device, dtype=grad_output.dtype)
tensor = torch.empty(
ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
)
return (None, None, None, None) + (
_AlltoAllSingle.apply(
ctx.group,

View File

@ -5,7 +5,7 @@ import logging
import operator
from collections import defaultdict
from enum import Enum
from inspect import Parameter, signature, Signature
from inspect import Parameter, Signature, signature
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@ -21,6 +21,7 @@ from torch.export.unflatten import (
)
from torch.fx.node import map_aggregate
from torch.fx.passes.split_module import split_module
from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
from ._utils import PipeInfo
@ -1176,7 +1177,8 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
predecessor_module = getattr(predecessor_module, atom)
except AttributeError as e:
raise AttributeError(
f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}'
f"Specified target {qualname} referenced "
f'nonexistent module {".".join(atoms[: i + 1])}'
) from e
mod_to_wrap = getattr(predecessor_module, atoms[-1])

View File

@ -8,6 +8,7 @@ from .schedules import (
)
from .stage import build_stage, PipelineStage
__all__ = [
"Pipe",
"pipe_split",

View File

@ -47,7 +47,7 @@ class _remote_device:
else:
raise ValueError(PARSE_ERROR)
else:
raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')
raise TypeError(f"Invalid type for remote_device: {type(remote_device)}")
# Do some basic sanity check (no empty string)
if self._worker_name is not None and not self._worker_name:
@ -96,18 +96,18 @@ class _remote_device:
def __repr__(self):
if self._device is not None:
if self._worker_name is not None:
return f'{self._worker_name}/{self._device}'
return f"{self._worker_name}/{self._device}"
elif self._rank is not None:
return f'rank:{self._rank}/{self._device}'
return f"rank:{self._rank}/{self._device}"
else:
return str(self._device)
else:
if self._worker_name is not None:
return f'{self._worker_name}'
return f"{self._worker_name}"
elif self._rank is not None:
return f'{self._rank}'
return f"{self._rank}"
else:
raise RuntimeError('Invalid state!')
raise RuntimeError("Invalid state!")
def __eq__(self, other):
if not isinstance(other, _remote_device):
@ -122,8 +122,5 @@ class _remote_device:
return False
def __hash__(self):
return hash(self._worker_name) ^ \
hash(self._device) ^ \
hash(self._rank)
return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank)

View File

@ -10,7 +10,7 @@ import numbers
import os
import sys
from datetime import timedelta
from typing import Dict, Optional, Callable, Iterator, Tuple
from typing import Callable, Dict, Iterator, Optional, Tuple
from torch.distributed import FileStore, PrefixStore, Store, TCPStore
@ -21,6 +21,7 @@ _rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]]
__all__ = ["register_rendezvous_handler", "rendezvous"]
def register_rendezvous_handler(scheme, handler):
"""
Register a new rendezvous handler.
@ -47,16 +48,17 @@ def register_rendezvous_handler(scheme, handler):
"""
global _rendezvous_handlers
if scheme in _rendezvous_handlers:
raise RuntimeError(
f"Rendezvous handler for {scheme}:// already registered"
)
raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered")
_rendezvous_handlers[scheme] = handler
# Query will have format "rank=0&world_size=1" and is
# converted into {"rank": 0, "world_size": 1}
def _query_to_dict(query: str) -> Dict[str, str]:
return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))}
return {
pair[0]: pair[1]
for pair in (pair.split("=") for pair in filter(None, query.split("&")))
}
def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
@ -152,7 +154,9 @@ def _torchelastic_use_agent_store() -> bool:
return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)
def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store:
def _create_c10d_store(
hostname, port, rank, world_size, timeout, use_libuv=True
) -> Store:
"""
Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store.
@ -183,7 +187,13 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True
else:
start_daemon = rank == 0
return TCPStore(
hostname, port, world_size, start_daemon, timeout, multi_tenant=True, use_libuv=use_libuv
hostname,
port,
world_size,
start_daemon,
timeout,
multi_tenant=True,
use_libuv=use_libuv,
)
@ -208,7 +218,9 @@ def _tcp_rendezvous_handler(
assert result.hostname is not None
store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv)
store = _create_c10d_store(
result.hostname, result.port, rank, world_size, timeout, use_libuv
)
yield (store, rank, world_size)
@ -250,12 +262,13 @@ def _env_rendezvous_handler(
else:
world_size = int(_get_env_or_raise("WORLD_SIZE"))
master_addr = _get_env_or_raise("MASTER_ADDR")
master_port = int(_get_env_or_raise("MASTER_PORT"))
use_libuv = _get_use_libuv_from_query_dict(query_dict)
store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
store = _create_c10d_store(
master_addr, master_port, rank, world_size, timeout, use_libuv
)
yield (store, rank, world_size)

View File

@ -397,9 +397,9 @@ import logging
import os
import sys
import uuid
import importlib.metadata as metadata
from argparse import REMAINDER, ArgumentParser
from typing import Callable, List, Tuple, Type, Union, Optional, Set
from argparse import ArgumentParser, REMAINDER
from importlib import metadata
from typing import Callable, List, Optional, Set, Tuple, Type, Union
import torch
from torch.distributed.argparse_util import check_env, env
@ -408,9 +408,9 @@ from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torch.utils.backend_registration import _get_custom_mod_func
import torch.multiprocessing
logger = get_logger(__name__)
@ -693,21 +693,26 @@ def determine_local_world_size(nproc_per_node: str):
if torch.cuda.is_available():
num_proc = torch.cuda.device_count()
device_type = "gpu"
elif hasattr(torch, torch._C._get_privateuse1_backend_name()) and \
_get_custom_mod_func("is_available")():
elif (
hasattr(torch, torch._C._get_privateuse1_backend_name())
and _get_custom_mod_func("is_available")()
):
num_proc = _get_custom_mod_func("device_count")()
device_type = torch._C._get_privateuse1_backend_name()
else:
num_proc = os.cpu_count()
device_type = "cpu"
else:
raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e
raise ValueError(
f"Unsupported nproc_per_node value: {nproc_per_node}"
) from e
logger.info(
"Using nproc_per_node=%s,"
" setting to %s since the instance "
"has %s %s",
nproc_per_node, num_proc, os.cpu_count(), device_type
"Using nproc_per_node=%s," " setting to %s since the instance " "has %s %s",
nproc_per_node,
num_proc,
os.cpu_count(),
device_type,
)
return num_proc
@ -753,9 +758,13 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
logs_specs_cls = entrypoint_list[0].load()
if logs_specs_cls is None:
raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
raise ValueError(
f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key"
)
logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
logging.info(
"Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)
)
else:
logs_specs_cls = DefaultLogsSpecs
@ -768,7 +777,11 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str
assert 0 < min_nodes <= max_nodes
assert args.max_restarts >= 0
if hasattr(args, "master_addr") and args.rdzv_backend != "static" and not args.rdzv_endpoint:
if (
hasattr(args, "master_addr")
and args.rdzv_backend != "static"
and not args.rdzv_endpoint
):
logger.warning(
"master_addr is only used for static rdzv_backend and when rdzv_endpoint "
"is not specified."
@ -784,7 +797,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str
"please further tune the variable for optimal performance in "
"your application as needed. \n"
"*****************************************",
omp_num_threads
omp_num_threads,
)
# This env variable will be passed down to the subprocesses
os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
@ -888,7 +901,9 @@ def run(args):
"--rdzv-endpoint=%s "
"--rdzv-id=%s\n"
"**************************************\n",
args.rdzv_backend, args.rdzv_endpoint, args.rdzv_id
args.rdzv_backend,
args.rdzv_endpoint,
args.rdzv_id,
)
config, cmd, cmd_args = config_from_args(args)

View File

@ -21,6 +21,7 @@ from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence
__all__ = [] # type: ignore[var-annotated]