Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[BE][Easy] enable UFMT for torch/distributed/ (#128870)

Part of #123062. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870. Approved by: https://github.com/fegin, https://github.com/wconstab
Committed by: PyTorch MergeBot
Parent: e165a5971f
Commit: 94dc3253a0
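UFMT is usort (import sorting) followed by black (code formatting), so every hunk in this commit is a mechanical rewrite: string quotes are normalized to double quotes, long calls and signatures are re-wrapped at black's default 88-character line length, and import blocks are regrouped and re-sorted. The first hunk below removes the newly formatted files from the formatter's exclude list in `.lintrunner.toml`. What follows is a minimal sketch of the black half, assuming the `black` package is installed; the two input lines mirror changes that appear later in this diff.

```python
# Reproduce black's quote and whitespace normalization on two lines that
# appear in hunks below. Requires the `black` package (pip install black).
import black

src = "exception : Optional[Exception] = None\nsys.stdin = open('/dev/stdin')\n"
print(black.format_str(src, mode=black.Mode()), end="")
# Output:
#   exception: Optional[Exception] = None
#   sys.stdin = open("/dev/stdin")
```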
@@ -1387,33 +1387,6 @@ exclude_patterns = [
'torch/contrib/_tensorboard_vis.py',
"torch/cuda/_gpu_trace.py",
'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable
'torch/distributed/__init__.py',
'torch/distributed/_composable_state.py',
'torch/distributed/_sharded_tensor/__init__.py',
'torch/distributed/_sharding_spec/__init__.py',
'torch/distributed/_tools/__init__.py',
'torch/distributed/_tools/memory_tracker.py',
'torch/distributed/argparse_util.py',
'torch/distributed/c10d_logger.py',
'torch/distributed/collective_utils.py',
'torch/distributed/constants.py',
'torch/distributed/distributed_c10d.py',
'torch/distributed/examples/memory_tracker_example.py',
'torch/distributed/launch.py',
'torch/distributed/launcher/__init__.py',
'torch/distributed/launcher/api.py',
'torch/distributed/logging_handlers.py',
'torch/distributed/nn/__init__.py',
'torch/distributed/nn/api/__init__.py',
'torch/distributed/nn/api/remote_module.py',
'torch/distributed/nn/functional.py',
'torch/distributed/nn/jit/__init__.py',
'torch/distributed/nn/jit/instantiator.py',
'torch/distributed/nn/jit/templates/__init__.py',
'torch/distributed/nn/jit/templates/remote_module_template.py',
'torch/distributed/remote_device.py',
'torch/distributed/rendezvous.py',
'torch/distributed/run.py',
'torch/fft/__init__.py',
'torch/func/__init__.py',
'torch/futures/__init__.py',
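Most of the hunks that follow only re-order `from ... import (...)` blocks. usort's ordering is roughly a case-insensitive sort of the imported names, which is why underscore-prefixed private names land ahead of public ones and why `get_debug_level` ends up between `FileStore` and `GradBucket` in the `torch/distributed/__init__.py` hunks just below. A small standard-library sketch of that ordering, using names copied from that import block:

```python
# Case-insensitive ordering, matching the re-sorted import lists below.
names = [
    "Store",
    "FileStore",
    "TCPStore",
    "Backend",
    "GradBucket",
    "get_debug_level",
    "_register_comm_hook",
    "_make_nccl_premul_sum",
]
print(sorted(names, key=str.casefold))
# ['_make_nccl_premul_sum', '_register_comm_hook', 'Backend',
#  'FileStore', 'get_debug_level', 'GradBucket', 'Store', 'TCPStore']
```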
@@ -1,12 +1,10 @@
# mypy: allow-untyped-defs
import os
import sys
from enum import Enum
import pdb
import io
import sys
import torch
def is_available() -> bool:
"""
Return ``True`` if the distributed package is available.

@@ -32,31 +30,31 @@ DistStoreError = torch._C._DistStoreError
if is_available():
from torch._C._distributed_c10d import (
Store,
FileStore,
TCPStore,
ProcessGroup as ProcessGroup,
Backend as _Backend,
PrefixStore,
Reducer,
Logger,
BuiltinCommHookType,
GradBucket,
Work as _Work,
_DEFAULT_FIRST_BUCKET_BYTES,
_register_comm_hook,
_register_builtin_comm_hook,
_broadcast_coalesced,
_compute_bucket_assignment_by_size,
_verify_params_across_processes,
_ControlCollectives,
_DEFAULT_FIRST_BUCKET_BYTES,
_make_nccl_premul_sum,
_register_builtin_comm_hook,
_register_comm_hook,
_StoreCollectives,
_test_python_store,
_verify_params_across_processes,
Backend as _Backend,
BuiltinCommHookType,
DebugLevel,
FileStore,
get_debug_level,
GradBucket,
Logger,
PrefixStore,
ProcessGroup as ProcessGroup,
Reducer,
set_debug_level,
set_debug_level_from_env,
_make_nccl_premul_sum,
_ControlCollectives,
_StoreCollectives,
Store,
TCPStore,
Work as _Work,
)
class _DistributedPdb(pdb.Pdb):

@@ -66,10 +64,11 @@ if is_available():
Usage:
_DistributedPdb().set_trace()
"""
def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open('/dev/stdin')
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin

@@ -101,37 +100,31 @@ if is_available():
del guard
if sys.platform != "win32":
from torch._C._distributed_c10d import (
HashStore,
_round_robin_process_groups,
)
from torch._C._distributed_c10d import _round_robin_process_groups, HashStore
from .distributed_c10d import * # noqa: F403
from .device_mesh import DeviceMesh, init_device_mesh
# Variables prefixed with underscore are not auto imported
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
# this.
from .distributed_c10d import * # noqa: F403
from .distributed_c10d import (
_all_gather_base,
_reduce_scatter_base,
_create_process_group_wrapper,
_rank_not_in_group,
_coalescing_manager,
_CoalescingManager,
_create_process_group_wrapper,
_get_process_group_name,
_rank_not_in_group,
_reduce_scatter_base,
get_node_local_rank,
)
from .remote_device import _remote_device
from .rendezvous import (
rendezvous,
_create_store_from_options,
register_rendezvous_handler,
rendezvous,
)
from .remote_device import _remote_device
from .device_mesh import init_device_mesh, DeviceMesh
set_debug_level_from_env()
else:

@@ -143,4 +136,5 @@ else:
class _ProcessGroupStub:
pass
sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined]
@@ -5,6 +5,7 @@ import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.distributed_c10d import ReduceOp
from ._fsdp_common import (
_get_dim0_padded_size,
_raise_assert_with_print,

@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import math
import traceback
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, cast, List, Optional

@@ -4,10 +4,10 @@ from typing import List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
from torch.distributed.device_mesh import _get_device_handle
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_state import _get_module_fsdp_state

@@ -7,12 +7,12 @@ from typing import Any, cast, List, Optional, Sequence, Tuple
import torch
import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.device_mesh import _mesh_resources
from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import (
_chunk_with_empty,

@@ -24,6 +24,7 @@ from ._fsdp_common import (
HSDPMeshInfo,
)
"""
[Note: FSDP tensors]
FSDP considers the following tensors:
@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
import logging
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple
import torch

@@ -12,6 +11,7 @@ from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicat
from torch.profiler import record_function
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle
from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_collectives import (
AllGatherResult,

@@ -22,6 +22,7 @@ from ._fsdp_collectives import (
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState
from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState
logger = logging.getLogger("torch.distributed._composable.fsdp")
_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict

@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import torch

@@ -14,13 +13,16 @@ from torch.distributed._composable_state import (
)
from torch.distributed.utils import _to_kwargs
from torch.utils._pytree import tree_flatten, tree_map
from ._fsdp_api import MixedPrecisionPolicy
from ._fsdp_common import _cast_fp_tensor, TrainingState
from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup
if TYPE_CHECKING:
from ._fsdp_param import FSDPParam
logger = logging.getLogger("torch.distributed._composable.fsdp")
@@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import functools
from typing import Any, cast, Iterable, List, NoReturn, Optional, Union
import torch

@@ -8,7 +8,6 @@ from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
_init_buffer_state,
_init_core_state,

@@ -9,6 +9,7 @@ from torch.nn.parallel import DistributedDataParallel
from .contract import _get_registry, contract
_ROOT_MODULE_PREFIX = ""

@@ -1,15 +1,14 @@
# mypy: allow-untyped-defs
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
if TYPE_CHECKING:
from torch._C._distributed_c10d import _DistributedBackendOptions, Backend
@@ -11,6 +11,7 @@ from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode
from . import _functional_collectives_impl as fun_col_impl
try:
from torch.utils._cxx_pytree import tree_map_only
except ImportError:

@@ -1134,6 +1135,7 @@ from torch.distributed.distributed_c10d import (
reduce_scatter_tensor as legacy_reducescatter,
)
# This dict should contain sets of functions that dynamo is allowed to remap.
# Functions in this set should accept the same args/kwargs 1:1 as their mapping.
traceable_collective_remaps = {

@@ -4,6 +4,7 @@ from typing import List, Optional
import torch
import torch.distributed.distributed_c10d as c10d
"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.
@@ -1,11 +1,12 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import torch
import warnings
import torch
from torch.distributed._shard.sharded_tensor import * # noqa: F403
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(

@@ -15,4 +16,6 @@ with warnings.catch_warnings():
stacklevel=2,
)
sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor
sys.modules[
"torch.distributed._sharded_tensor"
] = torch.distributed._shard.sharded_tensor

@@ -1,11 +1,12 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import torch
import warnings
import torch
from torch.distributed._shard.sharding_spec import * # noqa: F403
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(

@@ -16,4 +17,6 @@ with warnings.catch_warnings():
)
import torch.distributed._shard.sharding_spec as _sharding_spec
sys.modules['torch.distributed._sharding_spec'] = _sharding_spec
sys.modules["torch.distributed._sharding_spec"] = _sharding_spec

@@ -22,6 +22,7 @@ import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
if dist.is_available() or TYPE_CHECKING:
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor

@@ -1,3 +1,3 @@
from .mem_tracker import MemTracker
from .memory_tracker import MemoryTracker
from .mod_tracker import ModTracker
from .mem_tracker import MemTracker
@@ -1,24 +1,14 @@
# mypy: allow-untyped-defs
from collections import defaultdict
from itertools import chain
import operator
import pickle
from typing import (
Any,
Callable,
Dict,
List,
no_type_check,
Sequence,
TYPE_CHECKING,
)
from collections import defaultdict
from itertools import chain
from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode
import operator
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle

@@ -234,6 +224,7 @@ class MemoryTracker:
def _create_pre_forward_hook(self, name: str) -> Callable:
"""Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start."""
def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
self._cur_module_name = f"{name}.forward"
if (
@@ -15,9 +15,9 @@ from typing_extensions import ParamSpec
import torch
import torch.distributed as dist
from torch.distributed.logging_handlers import _log_handlers
__all__: List[str] = []
_DEFAULT_DESTINATION = "default"

@@ -36,7 +36,9 @@ def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Lo
return logger
def _get_logging_handler(destination: str = _DEFAULT_DESTINATION) -> Tuple[logging.Handler, str]:
def _get_logging_handler(
destination: str = _DEFAULT_DESTINATION,
) -> Tuple[logging.Handler, str]:
log_handler = _log_handlers[destination]
log_handler_name = type(log_handler).__name__
return (log_handler, log_handler_name)

@@ -69,8 +71,10 @@ def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
}
return msg_dict
_T = TypeVar('_T')
_P = ParamSpec('_P')
_T = TypeVar("_T")
_P = ParamSpec("_P")
def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(func)

@@ -14,8 +14,10 @@ from typing import Any, Callable, cast, Generic, List, Optional, Tuple, TypeVar,
import torch.distributed as dist
T = TypeVar("T")
@dataclass
class SyncPayload(Generic[T]):
stage_name: Optional[str]

@@ -23,6 +25,7 @@ class SyncPayload(Generic[T]):
payload: T
exception: Optional[Exception] = None
def broadcast(
data_or_fn: Union[T, Callable[[], T]],
*,

@@ -55,10 +58,12 @@ def broadcast(
"""
if not success and data_or_fn is not None:
raise AssertionError("Data or Function is expected to be None if not successful")
raise AssertionError(
"Data or Function is expected to be None if not successful"
)
payload: Optional[T] = None
exception : Optional[Exception] = None
exception: Optional[Exception] = None
# if no pg is passed then execute if rank is 0
if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
# determine if it is an executable function or data payload only

@@ -119,7 +124,7 @@ def all_gather(
>> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
"""
payload: Optional[T] = None
exception : Optional[Exception] = None
exception: Optional[Exception] = None
success = True
# determine if it is an executable function or data payload only
if callable(data_or_fn):

@@ -161,7 +166,8 @@ def all_gather(
if len(exception_list) > 0:
raise RuntimeError( # type: ignore[misc]
error_msg, exception_list) from exception_list[0]
error_msg, exception_list
) from exception_list[0]
return ret_list
else:
if not sync_obj.success:

@@ -1,8 +1,10 @@
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
from datetime import timedelta
from typing import Optional
__all__ = ['default_pg_timeout', 'default_pg_nccl_timeout']
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"]
# Default process group wide timeout, if applicable.
# This only applies to the non-nccl backends

@@ -16,6 +18,7 @@ default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT
try:
from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT
default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT
except ImportError:
# if C++ NCCL support is not compiled, we don't have access to the default nccl value.

@@ -6,10 +6,9 @@ import threading
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
from torch.utils._typing_utils import not_none
from ..utils._typing_utils import not_none
__all__ = ["init_device_mesh", "DeviceMesh"]
File diff suppressed because it is too large
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import torch
import torchvision
import torch
from torch.distributed._tools import MemoryTracker

@@ -8,7 +8,7 @@
from torch.distributed.launcher.api import ( # noqa: F401
LaunchConfig,
elastic_launch,
launch_agent,
LaunchConfig,
)

@@ -15,13 +15,18 @@ import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException
from torch.distributed.elastic.multiprocessing import (
DefaultLogsSpecs,
LogsSpecs,
SignalException,
)
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']
__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]
logger = get_logger(__name__)
@@ -212,8 +217,8 @@ def launch_agent(
"max_restarts": config.max_restarts,
"monitor_interval": config.monitor_interval,
"log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr]
"metrics_cfg": config.metrics_cfg
}
"metrics_cfg": config.metrics_cfg,
},
)
rdzv_parameters = RendezvousParameters(

@@ -9,6 +9,7 @@
import logging
from typing import Dict, List
__all__: List[str] = []
_log_handlers: Dict[str, logging.Handler] = {

@@ -1,4 +1,7 @@
import torch
from .functional import * # noqa: F403
if torch.distributed.rpc.is_available():
from .api.remote_module import RemoteModule
from .functional import * # noqa: F403
@@ -21,14 +21,15 @@ from typing import (
import torch
import torch.distributed.rpc as rpc
from torch import Tensor, device, dtype, nn
from torch.distributed.nn.jit import instantiator
from torch import device, dtype, nn, Tensor
from torch.distributed import _remote_device
from torch.distributed.nn.jit import instantiator
from torch.distributed.rpc.internal import _internal_rpc_pickler
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle
__all__ = ["RemoteModule"]
_grad_t = Union[Tuple[Tensor, ...], Tensor]

@@ -120,7 +121,6 @@ def _raise_not_supported(name: str) -> None:
class _RemoteModule(nn.Module):
def __new__(cls, *args, **kwargs):
# Use __new__ for logging purposes.
torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module")

@@ -370,7 +370,10 @@ class _RemoteModule(nn.Module):
self,
hook: Union[
Callable[[T, Tuple[Any, ...]], Optional[Any]],
Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]],
Callable[
[T, Tuple[Any, ...], Dict[str, Any]],
Optional[Tuple[Any, Dict[str, Any]]],
],
],
prepend: bool = False,
with_kwargs: bool = False,

@@ -405,10 +408,7 @@ class _RemoteModule(nn.Module):
)
def named_parameters( # type: ignore[return]
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Parameter]]:
_raise_not_supported(self.named_parameters.__name__)

@@ -416,10 +416,7 @@ class _RemoteModule(nn.Module):
_raise_not_supported(self.buffers.__name__)
def named_buffers( # type: ignore[return]
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Tensor]]:
_raise_not_supported(self.named_buffers.__name__)

@@ -464,7 +461,11 @@ class _RemoteModule(nn.Module):
assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."
remote_device = _remote_device(remote_device_str)
self.on = remote_device.worker_name() if remote_device.worker_name() is not None else remote_device.rank()
self.on = (
remote_device.worker_name()
if remote_device.worker_name() is not None
else remote_device.rank()
)
self.device = str(remote_device.device())
agent = rpc._get_current_rpc_agent()
# If the device map of the remote worker is set,
@@ -2,11 +2,13 @@
import torch
import torch.distributed as dist
from torch.autograd import Function
# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise import error
# if we're trying to use them.
from torch.distributed import group, ReduceOp
def broadcast(tensor, src, group=group.WORLD):
"""
Broadcasts the tensor to the whole group.

@@ -116,6 +118,7 @@ def all_gather(tensor, group=group.WORLD):
"""
return _AllGather.apply(group, tensor)
def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
"""
Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.

@@ -340,6 +343,7 @@ class _AllGather(Function):
gx = torch.sum(torch.stack(gxs), dim=0)
return (None, gx)
class _AllGatherBase(Function):
@staticmethod
def forward(ctx, output_tensor, input_tensor, group):

@@ -354,16 +358,19 @@ class _AllGatherBase(Function):
out_size = list(grad_output.size())
if out_size[0] % world_size != 0:
raise RuntimeError(
f'Tensor with dimensions: {out_size} does '
f'not have first dimension divisible by world_size: {world_size}'
f"Tensor with dimensions: {out_size} does "
f"not have first dimension divisible by world_size: {world_size}"
)
out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group)
gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype)
gx = torch.empty(
out_size, device=grad_output.device, dtype=grad_output.dtype
)
dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
else:
raise RuntimeError("Backend not supported!")
return (None, gx, None)
class _AlltoAll(Function):
@staticmethod
def forward(ctx, group, out_tensor_list, *tensors):

@@ -391,7 +398,9 @@ class _AlltoAll(Function):
@staticmethod
def backward(ctx, *grad_outputs):
tensor_list = [
torch.empty(size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
torch.empty(
size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype
)
for size in ctx.input_tensor_size_list
]
return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)

@@ -415,7 +424,9 @@ class _AlltoAllSingle(Function):
@staticmethod
def backward(ctx, grad_output):
tensor = torch.empty(ctx.input_size, device=grad_output.device, dtype=grad_output.dtype)
tensor = torch.empty(
ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
)
return (None, None, None, None) + (
_AlltoAllSingle.apply(
ctx.group,
@@ -5,7 +5,7 @@ import logging
import operator
from collections import defaultdict
from enum import Enum
from inspect import Parameter, signature, Signature
from inspect import Parameter, Signature, signature
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

@@ -21,6 +21,7 @@ from torch.export.unflatten import (
)
from torch.fx.node import map_aggregate
from torch.fx.passes.split_module import split_module
from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
from ._utils import PipeInfo

@@ -1176,7 +1177,8 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
predecessor_module = getattr(predecessor_module, atom)
except AttributeError as e:
raise AttributeError(
f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}'
f"Specified target {qualname} referenced "
f'nonexistent module {".".join(atoms[: i + 1])}'
) from e
mod_to_wrap = getattr(predecessor_module, atoms[-1])

@@ -8,7 +8,7 @@ from .schedules import (
)
from .stage import build_stage, PipelineStage
__all__ = [
"Pipe",
"pipe_split",
@@ -47,7 +47,7 @@ class _remote_device:
else:
raise ValueError(PARSE_ERROR)
else:
raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')
raise TypeError(f"Invalid type for remote_device: {type(remote_device)}")
# Do some basic sanity check (no empty string)
if self._worker_name is not None and not self._worker_name:

@@ -96,18 +96,18 @@ class _remote_device:
def __repr__(self):
if self._device is not None:
if self._worker_name is not None:
return f'{self._worker_name}/{self._device}'
return f"{self._worker_name}/{self._device}"
elif self._rank is not None:
return f'rank:{self._rank}/{self._device}'
return f"rank:{self._rank}/{self._device}"
else:
return str(self._device)
else:
if self._worker_name is not None:
return f'{self._worker_name}'
return f"{self._worker_name}"
elif self._rank is not None:
return f'{self._rank}'
return f"{self._rank}"
else:
raise RuntimeError('Invalid state!')
raise RuntimeError("Invalid state!")
def __eq__(self, other):
if not isinstance(other, _remote_device):

@@ -122,8 +122,5 @@ class _remote_device:
return False
def __hash__(self):
return hash(self._worker_name) ^ \
hash(self._device) ^ \
hash(self._rank)
return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank)
@@ -10,7 +10,7 @@ import numbers
import os
import sys
from datetime import timedelta
from typing import Dict, Optional, Callable, Iterator, Tuple
from typing import Callable, Dict, Iterator, Optional, Tuple
from torch.distributed import FileStore, PrefixStore, Store, TCPStore

@@ -21,6 +21,7 @@ _rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]]
__all__ = ["register_rendezvous_handler", "rendezvous"]
def register_rendezvous_handler(scheme, handler):
"""
Register a new rendezvous handler.

@@ -47,16 +48,17 @@ def register_rendezvous_handler(scheme, handler):
"""
global _rendezvous_handlers
if scheme in _rendezvous_handlers:
raise RuntimeError(
f"Rendezvous handler for {scheme}:// already registered"
)
raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered")
_rendezvous_handlers[scheme] = handler
# Query will have format "rank=0&world_size=1" and is
# converted into {"rank": 0, "world_size": 1}
def _query_to_dict(query: str) -> Dict[str, str]:
return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))}
return {
pair[0]: pair[1]
for pair in (pair.split("=") for pair in filter(None, query.split("&")))
}
def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:

@@ -152,7 +154,9 @@ def _torchelastic_use_agent_store() -> bool:
return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)
def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store:
def _create_c10d_store(
hostname, port, rank, world_size, timeout, use_libuv=True
) -> Store:
"""
Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store.

@@ -183,7 +187,13 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True
else:
start_daemon = rank == 0
return TCPStore(
hostname, port, world_size, start_daemon, timeout, multi_tenant=True, use_libuv=use_libuv
hostname,
port,
world_size,
start_daemon,
timeout,
multi_tenant=True,
use_libuv=use_libuv,
)

@@ -208,7 +218,9 @@ def _tcp_rendezvous_handler(
assert result.hostname is not None
store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv)
store = _create_c10d_store(
result.hostname, result.port, rank, world_size, timeout, use_libuv
)
yield (store, rank, world_size)

@@ -250,12 +262,13 @@ def _env_rendezvous_handler(
else:
world_size = int(_get_env_or_raise("WORLD_SIZE"))
master_addr = _get_env_or_raise("MASTER_ADDR")
master_port = int(_get_env_or_raise("MASTER_PORT"))
use_libuv = _get_use_libuv_from_query_dict(query_dict)
store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
store = _create_c10d_store(
master_addr, master_port, rank, world_size, timeout, use_libuv
)
yield (store, rank, world_size)
@@ -397,9 +397,9 @@ import logging
import os
import sys
import uuid
import importlib.metadata as metadata
from argparse import REMAINDER, ArgumentParser
from typing import Callable, List, Tuple, Type, Union, Optional, Set
from argparse import ArgumentParser, REMAINDER
from importlib import metadata
from typing import Callable, List, Optional, Set, Tuple, Type, Union
import torch
from torch.distributed.argparse_util import check_env, env

@@ -408,9 +408,9 @@ from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torch.utils.backend_registration import _get_custom_mod_func
import torch.multiprocessing
logger = get_logger(__name__)

@@ -693,21 +693,26 @@ def determine_local_world_size(nproc_per_node: str):
if torch.cuda.is_available():
num_proc = torch.cuda.device_count()
device_type = "gpu"
elif hasattr(torch, torch._C._get_privateuse1_backend_name()) and \
_get_custom_mod_func("is_available")():
elif (
hasattr(torch, torch._C._get_privateuse1_backend_name())
and _get_custom_mod_func("is_available")()
):
num_proc = _get_custom_mod_func("device_count")()
device_type = torch._C._get_privateuse1_backend_name()
else:
num_proc = os.cpu_count()
device_type = "cpu"
else:
raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e
raise ValueError(
f"Unsupported nproc_per_node value: {nproc_per_node}"
) from e
logger.info(
"Using nproc_per_node=%s,"
" setting to %s since the instance "
"has %s %s",
nproc_per_node, num_proc, os.cpu_count(), device_type
"Using nproc_per_node=%s," " setting to %s since the instance " "has %s %s",
nproc_per_node,
num_proc,
os.cpu_count(),
device_type,
)
return num_proc

@@ -753,9 +758,13 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
logs_specs_cls = entrypoint_list[0].load()
if logs_specs_cls is None:
raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key")
raise ValueError(
f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key"
)
logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls))
logging.info(
"Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)
)
else:
logs_specs_cls = DefaultLogsSpecs

@@ -768,7 +777,11 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str
assert 0 < min_nodes <= max_nodes
assert args.max_restarts >= 0
if hasattr(args, "master_addr") and args.rdzv_backend != "static" and not args.rdzv_endpoint:
if (
hasattr(args, "master_addr")
and args.rdzv_backend != "static"
and not args.rdzv_endpoint
):
logger.warning(
"master_addr is only used for static rdzv_backend and when rdzv_endpoint "
"is not specified."

@@ -784,7 +797,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str
"please further tune the variable for optimal performance in "
"your application as needed. \n"
"*****************************************",
omp_num_threads
omp_num_threads,
)
# This env variable will be passed down to the subprocesses
os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

@@ -888,7 +901,9 @@ def run(args):
"--rdzv-endpoint=%s "
"--rdzv-id=%s\n"
"**************************************\n",
args.rdzv_backend, args.rdzv_endpoint, args.rdzv_id
args.rdzv_backend,
args.rdzv_endpoint,
args.rdzv_id,
)
config, cmd, cmd_args = config_from_args(args)
@@ -21,6 +21,7 @@ from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence
__all__ = [] # type: ignore[var-annotated]