Revert "PEP585 update - torch/distributed (#145164)"

This reverts commit 6cb186e279bc179a6bb63f0226e24ab42a07b394.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
PyTorch MergeBot
2025-01-20 16:46:46 +00:00
parent 57b2b64acf
commit 6374332d33
79 changed files with 861 additions and 806 deletions
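The reverted PR had switched these modules to PEP 585 built-in generics; this revert restores the pre-PEP 585 `typing` aliases. A minimal sketch of the two annotation spellings involved (the function names and bodies below are invented purely for illustration):

from typing import Any, Dict, List, Optional  # pre-PEP 585 aliases (what this revert restores)


def pep585_style(state: dict[str, Any]) -> list[Optional[int]]:
    # Built-in generics in annotations need Python 3.9+ (or deferred annotation evaluation).
    return [state.get("step")]


def typing_style(state: Dict[str, Any]) -> List[Optional[int]]:
    # typing.Dict / typing.List also work on older interpreters and older tooling.
    return [state.get("step")]


print(pep585_style({"step": 1}) == typing_style({"step": 1}))  # True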

View File

@@ -1,10 +1,9 @@
 # mypy: allow-untyped-defs
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Iterator
 from enum import auto, Enum
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, Iterator, Optional
 import torch
 import torch.nn as nn
@@ -70,10 +69,10 @@ class ActivationWrapper(torch.nn.Module, ABC):
     @staticmethod
     def _post_state_dict_hook(
         module: nn.Module,
-        state_dict: dict[str, Any],
+        state_dict: Dict[str, Any],
         prefix: str,
         *args: Any,
-    ) -> dict[str, Any]:
+    ) -> Dict[str, Any]:
         """
         _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed.
@@ -88,7 +87,7 @@ class ActivationWrapper(torch.nn.Module, ABC):
     @staticmethod
     def _pre_load_state_dict_hook(
         module: nn.Module,
-        state_dict: dict[str, Any],
+        state_dict: Dict[str, Any],
         prefix: str,
         *args: Any,
     ) -> None:
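The hunks above only change the annotation on the `state_dict` mapping. For orientation, here is a hedged sketch of what a hook with this signature can do; the prefix constant and key rewriting are hypothetical examples, not the FSDP implementation:

from typing import Any, Dict

import torch.nn as nn

_EXAMPLE_PREFIX = "_wrapped_module."  # hypothetical wrapper prefix, for illustration only


def example_post_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    # Rewrite keys so the (hypothetical) wrapper prefix does not leak into checkpoints.
    for key in list(state_dict.keys()):
        if _EXAMPLE_PREFIX in key:
            state_dict[key.replace(_EXAMPLE_PREFIX, "")] = state_dict.pop(key)
    return state_dict


wrapped = {"_wrapped_module.weight": 1.0}
print(example_post_state_dict_hook(nn.Identity(), wrapped, ""))  # {'weight': 1.0}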

View File

@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import inspect
 from abc import ABC, abstractmethod
+from typing import Dict, Type
 from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
 from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
@@ -14,7 +15,7 @@ from torch.optim import Optimizer
 # Contains the mappings between the regular and overlapped optimizer types.
-_registered_overlapped_optims: dict[type, type] = {}
+_registered_overlapped_optims: Dict[Type, Type] = {}
 def register_overlapped(optim_cls):
@@ -32,7 +33,7 @@ def register_overlapped(optim_cls):
 class OverlappedOptimizer(ABC):
-    def __init__(self, optim_cls: type) -> None:
+    def __init__(self, optim_cls: Type) -> None:
         """
         Initialize the OverlappedOptimizer.
@@ -60,7 +61,7 @@ class OverlappedOptimizer(ABC):
 class _OverlappedStandardOptimizer(OverlappedOptimizer):
     """Overlaps a regular ``Optimizer``."""
-    def __init__(self, optim_cls: type, params, *optim_args, **optim_kwargs) -> None:
+    def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
         super().__init__(optim_cls)
         f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
         self._opt_hook_state = _OptimizerHookState(f_optim, params)
@@ -81,7 +82,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer):
         )
-def _as_overlapped_optim(optim_cls: type, params, *args, **kwargs):
+def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
     """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``."""
     for clz in inspect.getmro(optim_cls):
         try:
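`_registered_overlapped_optims` above is a decorator-populated registry that maps an optimizer class to its overlapped wrapper. A self-contained sketch of that registry pattern, using hypothetical class names and only the standard library:

from typing import Dict, Type

_registry: Dict[Type, Type] = {}  # regular class -> wrapper class (hypothetical registry)


def register(key_cls: Type):
    def decorator(wrapper_cls: Type) -> Type:
        _registry[key_cls] = wrapper_cls
        return wrapper_cls

    return decorator


class Plain:
    pass


@register(Plain)
class Wrapped:
    pass


print(_registry[Plain] is Wrapped)  # True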

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import weakref
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 import torch
 import torch.distributed as dist
@@ -46,7 +46,7 @@ def _perform_local_step(
     # expects `None` in a list position to indicate that the corresponding
     # parameter should not be updated
     num_local_optim_params = len(zero.optim.param_groups[0]["params"])
-    gradients: list[Optional[torch.Tensor]] = [
+    gradients: List[Optional[torch.Tensor]] = [
         _NO_PARAM_UPDATE for _ in range(num_local_optim_params)
     ]
     assert (
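The comment in the hunk above explains that the gradients list uses `None` placeholders for parameters that should not be updated. A tiny standalone illustration of that convention (sizes and values are made up):

from typing import List, Optional

import torch

# One slot per local optimizer parameter; None means "leave this parameter untouched".
num_local_optim_params = 3
gradients: List[Optional[torch.Tensor]] = [None for _ in range(num_local_optim_params)]

# Only the parameters whose gradients actually arrived get a real tensor.
gradients[1] = torch.ones(4)

for idx, grad in enumerate(gradients):
    status = "update" if grad is not None else "skip"
    print(f"param {idx}: {status}")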

View File

@@ -1,14 +1,14 @@
 # mypy: allow-untyped-defs
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, no_type_check
+from typing import Any, Callable, List, no_type_check
 import torch
 import torch.distributed as dist
 from torch.autograd import Variable
-__all__: list[str] = []
+__all__: List[str] = []
 _FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"

View File

@@ -2,6 +2,7 @@
 import logging
 import math
 from collections import defaultdict
+from typing import Dict
 import torch
 import torch.distributed as dist
@@ -251,9 +252,9 @@ class PowerSGDState:
         self.rng = np.random.RandomState(random_seed)
         # Since there is only a single state instance for all the input buckets,
         # need to maintain a dictionary that maps each bucket index to the local error.
-        self.error_dict: dict[int, torch.Tensor] = {}
-        self.p_memory_dict: dict[int, torch.Tensor] = {}
-        self.q_memory_dict: dict[int, torch.Tensor] = {}
+        self.error_dict: Dict[int, torch.Tensor] = {}
+        self.p_memory_dict: Dict[int, torch.Tensor] = {}
+        self.q_memory_dict: Dict[int, torch.Tensor] = {}
         # Iteration/step in the training loop.
         self.iter = 0
         # Compression stats accumulators
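The comment above notes that a single shared state keeps one error tensor per bucket index, which is the error-feedback part of PowerSGD. A toy sketch of that bookkeeping with a `Dict[int, torch.Tensor]`; the `sign()` step below is a stand-in for PowerSGD's low-rank approximation, not the real hook:

from typing import Dict

import torch

error_dict: Dict[int, torch.Tensor] = {}  # bucket index -> local compression error


def compress_with_error_feedback(bucket_index: int, grad: torch.Tensor) -> torch.Tensor:
    # Add the error remembered from the previous iteration before compressing.
    if bucket_index in error_dict:
        grad = grad + error_dict[bucket_index]
    compressed = grad.sign()  # stand-in for the real low-rank approximation
    # Remember what compression lost so it can be re-applied next time.
    error_dict[bucket_index] = grad - compressed
    return compressed


g = torch.tensor([0.3, -1.2, 0.7])
print(compress_with_error_feedback(0, g))
print(error_dict[0])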

View File

@@ -2,7 +2,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from types import TracebackType
-from typing import Any, NamedTuple, Optional
+from typing import Any, List, NamedTuple, Optional, Type
 import torch
 import torch.distributed as dist
@@ -165,7 +165,7 @@ class Join:
     def __init__(
         self,
-        joinables: list[Joinable],
+        joinables: List[Joinable],
         enable: bool = True,
         throw_on_early_termination: bool = False,
         **kwargs,
@@ -228,7 +228,7 @@ class Join:
     def __exit__(
         self,
-        type: Optional[type[BaseException]],
+        type: Optional[Type[BaseException]],
         value: Optional[BaseException],
         traceback: Optional[TracebackType],
     ):

View File

@@ -1,8 +1,7 @@
 # mypy: allow-untyped-defs
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Dict, Iterable, Optional, Union
 import torch
 import torch.distributed as dist
@@ -108,7 +107,7 @@ class PeriodicModelAverager(ModelAverager):
     def average_parameters(
         self,
         params: Union[
-            Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
+            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
         ],
     ):
         """
"""

View File

@@ -3,8 +3,7 @@
 import logging
 import warnings
 from collections import OrderedDict
-from collections.abc import Iterable
-from typing import Union
+from typing import Dict, Iterable, Union
 import torch
 import torch.distributed as dist
@@ -160,7 +159,7 @@ class HierarchicalModelAverager(averagers.ModelAverager):
     def average_parameters(
         self,
         params: Union[
-            Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
+            Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
         ],
     ):
         """

View File

@@ -1,8 +1,7 @@
 # mypy: allow-untyped-defs
 # flake8: noqa C101
 import itertools
-from collections.abc import Iterable, Iterator
-from typing import Union
+from typing import Dict, Iterable, Iterator, Union
 import torch
 import torch.distributed as dist
@@ -52,7 +51,7 @@ def average_parameters(
 def get_params_to_average(
-    params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]]
+    params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]
 ):
     """
     Return a list of parameters that need to average.
@@ -82,7 +81,7 @@ def get_params_to_average(
 def average_parameters_or_parameter_groups(
     params: Union[
-        Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
+        Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
     ],
     process_group: ProcessGroup,
 ):
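Both averaging entry points above accept either a flat iterable of parameters or an iterable of optimizer-style parameter-group dicts. A hedged sketch of how such input can be flattened into one list; the helper name and the "params" key convention are assumptions made for this illustration:

from typing import Dict, Iterable, List, Union

import torch


def flatten_params(
    params: Union[
        Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
    ],
) -> List[torch.nn.Parameter]:
    flat: List[torch.nn.Parameter] = []
    for entry in params:
        if isinstance(entry, dict):
            # Optimizer-style parameter group: pull the tensors out of "params".
            flat.extend(entry["params"])
        else:
            flat.append(entry)
    return flat


layer = torch.nn.Linear(2, 2)
print(len(flatten_params(layer.parameters())))                       # 2
print(len(flatten_params([{"params": list(layer.parameters())}])))   # 2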