Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279bc179a6bb63f0226e24ab42a07b394. Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
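For context, the reverted change had switched these torch/distributed files from the `typing` generic aliases (`Dict`, `List`, `Type`, `Iterable`, `Iterator`) to the PEP 585 built-in generics (`dict`, `list`, `type`, and the `collections.abc` equivalents); the revert restores the earlier `typing` imports. Below is a minimal sketch of the two annotation styles, using illustrative names that do not appear in the diff:

    from typing import Dict, List

    # Style restored by this revert: generic aliases imported from `typing`.
    def bucket_sizes_old(buckets: Dict[int, List[float]]) -> List[int]:
        return [len(v) for v in buckets.values()]

    # PEP 585 style introduced by the reverted PR (Python 3.9+): built-in
    # types subscripted directly, with no `typing` import required.
    def bucket_sizes_new(buckets: dict[int, list[float]]) -> list[int]:
        return [len(v) for v in buckets.values()]

Both spellings behave the same at runtime on Python 3.9+; the revert simply returns the annotations to the earlier `typing`-alias form.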
@@ -1,10 +1,9 @@
 # mypy: allow-untyped-defs
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Iterator
 from enum import auto, Enum
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, Iterator, Optional

 import torch
 import torch.nn as nn
@@ -70,10 +69,10 @@ class ActivationWrapper(torch.nn.Module, ABC):
     @staticmethod
     def _post_state_dict_hook(
         module: nn.Module,
-        state_dict: dict[str, Any],
+        state_dict: Dict[str, Any],
         prefix: str,
         *args: Any,
-    ) -> dict[str, Any]:
+    ) -> Dict[str, Any]:
         """
         _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed.

@@ -88,7 +87,7 @@ class ActivationWrapper(torch.nn.Module, ABC):
     @staticmethod
     def _pre_load_state_dict_hook(
         module: nn.Module,
-        state_dict: dict[str, Any],
+        state_dict: Dict[str, Any],
         prefix: str,
         *args: Any,
     ) -> None:

@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import inspect
 from abc import ABC, abstractmethod
+from typing import Dict, Type

 from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
 from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
@@ -14,7 +15,7 @@ from torch.optim import Optimizer


 # Contains the mappings between the regular and overlapped optimizer types.
-_registered_overlapped_optims: dict[type, type] = {}
+_registered_overlapped_optims: Dict[Type, Type] = {}


 def register_overlapped(optim_cls):
@@ -32,7 +33,7 @@ def register_overlapped(optim_cls):


 class OverlappedOptimizer(ABC):
-    def __init__(self, optim_cls: type) -> None:
+    def __init__(self, optim_cls: Type) -> None:
         """
         Initialize the OverlappedOptimizer.

@@ -60,7 +61,7 @@ class OverlappedOptimizer(ABC):
 class _OverlappedStandardOptimizer(OverlappedOptimizer):
     """Overlaps a regular ``Optimizer``."""

-    def __init__(self, optim_cls: type, params, *optim_args, **optim_kwargs) -> None:
+    def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
         super().__init__(optim_cls)
         f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
         self._opt_hook_state = _OptimizerHookState(f_optim, params)
@@ -81,7 +82,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer):
         )


-def _as_overlapped_optim(optim_cls: type, params, *args, **kwargs):
+def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
     """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``."""
     for clz in inspect.getmro(optim_cls):
         try:

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import weakref
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional

 import torch
 import torch.distributed as dist
@@ -46,7 +46,7 @@ def _perform_local_step(
     # expects `None` in a list position to indicate that the corresponding
     # parameter should not be updated
     num_local_optim_params = len(zero.optim.param_groups[0]["params"])
-    gradients: list[Optional[torch.Tensor]] = [
+    gradients: List[Optional[torch.Tensor]] = [
         _NO_PARAM_UPDATE for _ in range(num_local_optim_params)
     ]
     assert (

@@ -1,14 +1,14 @@
 # mypy: allow-untyped-defs
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, no_type_check
+from typing import Any, Callable, List, no_type_check

 import torch
 import torch.distributed as dist
 from torch.autograd import Variable


-__all__: list[str] = []
+__all__: List[str] = []

 _FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"

@@ -2,6 +2,7 @@
 import logging
 import math
 from collections import defaultdict
+from typing import Dict

 import torch
 import torch.distributed as dist
@@ -251,9 +252,9 @@ class PowerSGDState:
         self.rng = np.random.RandomState(random_seed)
         # Since there is only a single state instance for all the input buckets,
         # need to maintain a dictionary that maps each bucket index to the local error.
-        self.error_dict: dict[int, torch.Tensor] = {}
-        self.p_memory_dict: dict[int, torch.Tensor] = {}
-        self.q_memory_dict: dict[int, torch.Tensor] = {}
+        self.error_dict: Dict[int, torch.Tensor] = {}
+        self.p_memory_dict: Dict[int, torch.Tensor] = {}
+        self.q_memory_dict: Dict[int, torch.Tensor] = {}
         # Iteration/step in the training loop.
         self.iter = 0
         # Compression stats accumulators

@@ -2,7 +2,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from types import TracebackType
-from typing import Any, NamedTuple, Optional
+from typing import Any, List, NamedTuple, Optional, Type

 import torch
 import torch.distributed as dist
@@ -165,7 +165,7 @@ class Join:

     def __init__(
         self,
-        joinables: list[Joinable],
+        joinables: List[Joinable],
         enable: bool = True,
         throw_on_early_termination: bool = False,
         **kwargs,
@@ -228,7 +228,7 @@ class Join:

     def __exit__(
         self,
-        type: Optional[type[BaseException]],
+        type: Optional[Type[BaseException]],
         value: Optional[BaseException],
         traceback: Optional[TracebackType],
     ):

@@ -1,8 +1,7 @@
 # mypy: allow-untyped-defs
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Dict, Iterable, Optional, Union

 import torch
 import torch.distributed as dist
@@ -108,7 +107,7 @@ class PeriodicModelAverager(ModelAverager):
     def average_parameters(
         self,
         params: Union[
-            Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
+            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
         ],
     ):
         """

@@ -3,8 +3,7 @@
 import logging
 import warnings
 from collections import OrderedDict
-from collections.abc import Iterable
-from typing import Union
+from typing import Dict, Iterable, Union

 import torch
 import torch.distributed as dist
@@ -160,7 +159,7 @@ class HierarchicalModelAverager(averagers.ModelAverager):
     def average_parameters(
         self,
         params: Union[
-            Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
+            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
         ],
     ):
         """

@@ -1,8 +1,7 @@
 # mypy: allow-untyped-defs
 # flake8: noqa C101
 import itertools
-from collections.abc import Iterable, Iterator
-from typing import Union
+from typing import Dict, Iterable, Iterator, Union

 import torch
 import torch.distributed as dist
@@ -52,7 +51,7 @@ def average_parameters(


 def get_params_to_average(
-    params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]]
+    params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]
 ):
     """
     Return a list of parameters that need to average.
@@ -82,7 +81,7 @@ def get_params_to_average(

 def average_parameters_or_parameter_groups(
     params: Union[
-        Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
+        Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
     ],
     process_group: ProcessGroup,
 ):