[BE]: Update mypy to 1.13.0 (#140808)

Update mypy to 1.13.0. This should reduce linting time. mypy 1.13 adds support for orjson cache serialization, which should improve mypy cache performance when orjson is installed.
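For context: mypy only uses the faster orjson-based cache serialization when orjson can be imported in the environment that runs it; otherwise it falls back to the stdlib json module. A quick, illustrative check of a lint environment (not part of this PR) could look like:

# Illustrative only -- not part of this PR.
# mypy 1.13 serializes its incremental cache with orjson when available,
# and falls back to the standard-library json module otherwise.
try:
    import orjson  # noqa: F401
    print("orjson found: mypy can use the faster cache serialization")
except ImportError:
    print("orjson not installed: mypy falls back to the json-based cache")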

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140808
Approved by: https://github.com/ezyang, https://github.com/malfet
Author: Aaron Gokaslan
Date: 2024-12-02 18:47:54 +00:00
Committed by: PyTorch MergeBot
Parent: 9012e7a62f
Commit: 00134d68af
31 changed files with 116 additions and 70 deletions

View File

@@ -105,6 +105,7 @@ def post_localSGD_hook(
     # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
     if state.iter < state.start_localSGD_iter:
         state.maybe_increase_iter(bucket)
+        assert isinstance(global_group_to_use, dist.ProcessGroup)
         return default._allreduce_fut(global_group_to_use, input_tensor)
 
     # If `post_local_gradient_allreduce` is not set,

View File

@@ -7,6 +7,7 @@ from typing import Dict
 import torch
 import torch.distributed as dist
 from torch.distributed import distributed_c10d
+from torch.utils._typing_utils import not_none
 
 from . import default_hooks as default
@@ -398,7 +399,10 @@ def powerSGD_hook(
         >>> ddp_model.register_comm_hook(state, powerSGD_hook)
     """  # noqa: B950
     process_group = state.process_group
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    group_to_use = (
+        process_group if process_group is not None else not_none(dist.group.WORLD)
+    )
+    assert isinstance(process_group, dist.ProcessGroup)
 
     world_size = group_to_use.size()
     # The input tensor is a flattened 1D tensor.
@@ -707,7 +711,10 @@ def batched_powerSGD_hook(
         >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
     """  # noqa: B950
     process_group = state.process_group
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    group_to_use = (
+        process_group if process_group is not None else not_none(dist.group.WORLD)
+    )
+    assert isinstance(group_to_use, dist.ProcessGroup)
 
     world_size = group_to_use.size()
     # The input tensor is a flattened 1D tensor.

View File

@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
 import warnings
 from abc import ABC, abstractmethod
-from typing import Dict, Iterable, Union
+from typing import Dict, Iterable, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed.algorithms.model_averaging.utils as utils
+from torch.utils._typing_utils import not_none as _not_none
 
 
 __all__ = ["ModelAverager", "PeriodicModelAverager"]
@@ -21,9 +22,9 @@ class ModelAverager(ABC):
             will be used. (default: ``None``)
     """
 
-    def __init__(self, process_group=None):
+    def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
         self.process_group = (
-            process_group if process_group is not None else dist.group.WORLD
+            process_group if process_group is not None else _not_none(dist.group.WORLD)
         )
         self.step = 0
@@ -85,7 +86,9 @@ class PeriodicModelAverager(ModelAverager):
         >>> averager.average_parameters(model.parameters())
     """
 
-    def __init__(self, period, warmup_steps=0, process_group=None):
+    def __init__(
+        self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None
+    ):
         super().__init__(process_group)
         if warmup_steps < 0:
             raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
@@ -120,5 +123,7 @@ class PeriodicModelAverager(ModelAverager):
             self.step >= self.warmup_steps
             and (self.step - self.warmup_steps) % self.period == 0
         ):
-            utils.average_parameters_or_parameter_groups(params, self.process_group)
+            utils.average_parameters_or_parameter_groups(
+                params, _not_none(self.process_group)
+            )
         self.step += 1
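
The change repeated across these files is the same: values typed as Optional[ProcessGroup] (such as dist.group.WORLD and self.process_group) are narrowed with a not_none helper so that mypy 1.13 accepts calls like group_to_use.size() without union-attr errors, while keeping a runtime check. A minimal sketch of such a narrowing helper, assuming it roughly matches the shape of torch.utils._typing_utils.not_none (the actual PyTorch implementation may differ):

from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(obj: Optional[T]) -> T:
    # Narrow Optional[T] to T for the type checker while keeping a runtime check.
    # If the invariant is violated (e.g. the default process group was never
    # initialized), fail loudly instead of hitting an AttributeError later.
    if obj is None:
        raise TypeError("Expected a non-None value")
    return obj


# Hypothetical standalone usage mirroring the diff above:
# group_to_use = process_group if process_group is not None else not_none(dist.group.WORLD)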