Mirror of https://github.com/pytorch/pytorch.git
[BE]: Update mypy to 1.13.0 (#140808)
Update mypy to 1.13.0. This should reduce linting time, and 1.13 adds support for orjson cache serialization, which should improve mypy cache performance when orjson is installed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140808
Approved by: https://github.com/ezyang, https://github.com/malfet
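A quick way to check whether that faster cache path can apply in a given environment is to probe for orjson; a minimal sketch (assuming nothing beyond the package being published as `orjson`):

# Check whether orjson is importable; with mypy 1.13+ installed alongside it,
# mypy can serialize its incremental cache via orjson instead of stdlib json.
import importlib.util

if importlib.util.find_spec("orjson") is not None:
    print("orjson found: mypy 1.13+ can use it for faster cache serialization")
else:
    print("orjson not found: mypy falls back to the stdlib json cache")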
Commit: 00134d68af
Parent: 9012e7a62f
Committed by: PyTorch MergeBot
@@ -105,6 +105,7 @@ def post_localSGD_hook(
     # Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
     if state.iter < state.start_localSGD_iter:
         state.maybe_increase_iter(bucket)
+        assert isinstance(global_group_to_use, dist.ProcessGroup)
         return default._allreduce_fut(global_group_to_use, input_tensor)
 
     # If `post_local_gradient_allreduce` is not set,
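The added assert is the standard isinstance-narrowing pattern for mypy: `dist.group.WORLD` is typed `Optional[ProcessGroup]`, so values derived from it stay Optional until narrowed. A minimal, self-contained sketch of the pattern (stand-in classes, not the real PyTorch types):

from typing import Optional


class ProcessGroup:  # stand-in for torch.distributed.ProcessGroup
    def size(self) -> int:
        return 1


WORLD: Optional[ProcessGroup] = ProcessGroup()  # stand-in for dist.group.WORLD


def hook(group: Optional[ProcessGroup]) -> int:
    group_to_use = group if group is not None else WORLD
    # mypy still infers Optional[ProcessGroup] here, because WORLD is Optional;
    # the assert narrows the type so the .size() call below type-checks.
    assert isinstance(group_to_use, ProcessGroup)
    return group_to_use.size()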
@@ -7,6 +7,7 @@ from typing import Dict
 import torch
 import torch.distributed as dist
 from torch.distributed import distributed_c10d
+from torch.utils._typing_utils import not_none
 
 from . import default_hooks as default
 
@@ -398,7 +399,10 @@ def powerSGD_hook(
         >>> ddp_model.register_comm_hook(state, powerSGD_hook)
     """  # noqa: B950
     process_group = state.process_group
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    group_to_use = (
+        process_group if process_group is not None else not_none(dist.group.WORLD)
+    )
+    assert isinstance(process_group, dist.ProcessGroup)
     world_size = group_to_use.size()
 
     # The input tensor is a flattened 1D tensor.
@@ -707,7 +711,10 @@ def batched_powerSGD_hook(
         >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
     """  # noqa: B950
     process_group = state.process_group
-    group_to_use = process_group if process_group is not None else dist.group.WORLD
+    group_to_use = (
+        process_group if process_group is not None else not_none(dist.group.WORLD)
+    )
+    assert isinstance(group_to_use, dist.ProcessGroup)
     world_size = group_to_use.size()
 
     # The input tensor is a flattened 1D tensor.
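`not_none` achieves the same narrowing as the asserts but survives `python -O`, which strips assert statements. A minimal sketch of what such a helper looks like; the real `torch.utils._typing_utils.not_none` may differ in its exact error message:

from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(obj: Optional[T]) -> T:
    # Raising (rather than asserting) keeps the check under `python -O`,
    # and the `-> T` return type tells mypy the result is never None.
    if obj is None:
        raise TypeError("Invariant violation: value was None")
    return obj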
@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
 import warnings
 from abc import ABC, abstractmethod
-from typing import Dict, Iterable, Union
+from typing import Dict, Iterable, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed.algorithms.model_averaging.utils as utils
+from torch.utils._typing_utils import not_none as _not_none
 
 
 __all__ = ["ModelAverager", "PeriodicModelAverager"]
@@ -21,9 +22,9 @@ class ModelAverager(ABC):
                        will be used. (default: ``None``)
     """
 
-    def __init__(self, process_group=None):
+    def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
         self.process_group = (
-            process_group if process_group is not None else dist.group.WORLD
+            process_group if process_group is not None else _not_none(dist.group.WORLD)
         )
         self.step = 0
 
@@ -85,7 +86,9 @@ class PeriodicModelAverager(ModelAverager):
         >>> averager.average_parameters(model.parameters())
     """
 
-    def __init__(self, period, warmup_steps=0, process_group=None):
+    def __init__(
+        self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None
+    ):
         super().__init__(process_group)
         if warmup_steps < 0:
             raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
@@ -120,5 +123,7 @@ class PeriodicModelAverager(ModelAverager):
             self.step >= self.warmup_steps
             and (self.step - self.warmup_steps) % self.period == 0
         ):
-            utils.average_parameters_or_parameter_groups(params, self.process_group)
+            utils.average_parameters_or_parameter_groups(
+                params, _not_none(self.process_group)
+            )
         self.step += 1
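Annotating `process_group` as `Optional[dist.ProcessGroup]`, rather than leaving it untyped, means mypy now checks callers of these constructors. A hedged caller-side sketch (`make_averager` is a hypothetical helper; actually using the averager still requires torch.distributed to be initialized):

import torch.distributed as dist
from torch.distributed.algorithms.model_averaging.averagers import (
    PeriodicModelAverager,
)


def make_averager(pg: dist.ProcessGroup) -> PeriodicModelAverager:
    # OK under mypy: `pg` matches the new Optional[dist.ProcessGroup] parameter.
    return PeriodicModelAverager(period=4, warmup_steps=100, process_group=pg)

With the old untyped signature, mypy treated the constructor as dynamically typed and would have accepted, say, a string for `process_group` without complaint.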