Revert "[BE]: Update mypy to 1.13.0 (#140808)"

This reverts commit 00134d68af2ce50560fa5a74473665ea229e6c9d.

Reverted https://github.com/pytorch/pytorch/pull/140808 on behalf of https://github.com/huydhn because it is failing a distributed test in trunk; target determination missed this test and did not run it on the PR ([comment](https://github.com/pytorch/pytorch/pull/140808#issuecomment-2512788426))
commit daa77f3d9f
parent 54adbbf6b8
Author: PyTorch MergeBot
Date: 2024-12-02 20:47:43 +00:00
31 changed files with 70 additions and 116 deletions


@@ -90,7 +90,7 @@ librosa>=0.6.2 ; python_version < "3.11"
#Pinned versions:
#test that import:
mypy==1.13.0
mypy==1.11.2
# Pin MyPy version because new errors are likely to appear with each release
#Description: linter
#Pinned versions: 1.10.0


@@ -144,7 +144,7 @@ init_command = [
'numpy==1.26.4 ; python_version >= "3.9" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.2.1',
'mypy==1.13.0',
'mypy==1.11.2',
'sympy==1.13.0 ; python_version >= "3.9"',
'types-requests==2.27.25',
'types-PyYAML==6.0.7',


@@ -386,11 +386,9 @@ class Op:
}, f"{type} is not a supported operation"
self.type = type
if type == "send":
assert isinstance(meta, str)
s, d = meta.split("->")
self._src, self._dst = int(s), int(d)
elif type == "recv":
assert isinstance(meta, str)
d, s = meta.split("<-")
self._dst, self._src = int(d), int(s)
else:

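The `assert isinstance(meta, str)` guard shown in this hunk is the usual way to narrow a `Union` at runtime so that mypy accepts the later `str.split` call. A minimal sketch of the pattern, assuming a hypothetical `meta` that may be a string like "src->dst" or something else:

from typing import Tuple, Union

def parse_send(meta: Union[str, int]) -> Tuple[int, int]:
    # The assert narrows `meta` from Union[str, int] to str, so .split() type-checks.
    assert isinstance(meta, str)
    src, dst = meta.split("->")
    return int(src), int(dst)

print(parse_send("0->3"))  # (0, 3)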

@@ -1503,7 +1503,6 @@ def export(
# NB: this is wrong if graph_captured_result has
# data-dependent output size!
ignore_fresh_unbacked = null_context()
assert ambient_fake_mode is not None
if shape_env := ambient_fake_mode.shape_env:
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()


@@ -11,18 +11,7 @@ import sys
import traceback
import weakref
from dataclasses import dataclass
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
import sympy
@@ -632,11 +621,8 @@ class OutputGraph:
"""
Saves to out if it is provided. Else saves to the tracing context's global_state.
"""
global_state = cast(
Dict[str, Tuple[Callable[..., Any], bool]],
out
if out is not None
else self.tracing_context.global_context.global_state,
global_state = (
out if out is not None else self.tracing_context.global_context.global_state
)
# TODO - Consider having a torch level API for torch_function_state. As
@@ -659,11 +645,11 @@ class OutputGraph:
functools.partial(torch.set_autocast_enabled, "cpu"),
torch.is_autocast_enabled("cpu"),
)
global_state["autocast_gpu_dtype"] = ( # type:ignore[assignment]
global_state["autocast_gpu_dtype"] = (
functools.partial(torch.set_autocast_dtype, "cuda"),
torch.get_autocast_dtype("cuda"),
)
global_state["autocast_cpu_dtype"] = ( # type:ignore[assignment]
global_state["autocast_cpu_dtype"] = (
functools.partial(torch.set_autocast_dtype, "cpu"),
torch.get_autocast_dtype("cpu"),
)

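`global_state` maps each name to a `(setter, saved_value)` pair, e.g. a `functools.partial` of `torch.set_autocast_enabled` together with the flag currently in effect, so restoring the saved state is just calling every setter with its stored value. A small sketch of that shape, mirroring the autocast entries above:

import functools
from typing import Any, Callable, Dict, Tuple

import torch

# Each entry pairs a setter with the value it should later be restored to.
global_state: Dict[str, Tuple[Callable[..., Any], Any]] = {
    "autocast_cpu": (
        functools.partial(torch.set_autocast_enabled, "cpu"),
        torch.is_autocast_enabled("cpu"),
    ),
}

# Restoring is just invoking every setter with its saved value.
for setter, value in global_state.values():
    setter(value)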

@@ -1088,7 +1088,7 @@ class ChromiumEventLogger:
a specification of the Chromium Event JSON format.
"""
def get_stack(self) -> List[str]:
def get_stack(self):
"""
The main event stack, with every chromium event.
Logged to tlparse.
@@ -1099,7 +1099,7 @@ class ChromiumEventLogger:
self.tls.stack = []
return self.tls.stack
def get_top(self) -> Optional[str]:
def get_top(self) -> str:
"""
Get the top event name or None if the stack is empty.
"""

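`get_stack` lazily creates a per-thread list of event names and `get_top` returns its last entry, or `None` when the stack is empty, which is what the `Optional[str]` annotation expresses. A minimal sketch of that pattern (not the actual logger):

import threading
from typing import List, Optional

class EventStack:
    def __init__(self) -> None:
        self.tls = threading.local()

    def get_stack(self) -> List[str]:
        # Create the per-thread stack on first use.
        if not hasattr(self.tls, "stack"):
            self.tls.stack = []
        return self.tls.stack

    def get_top(self) -> Optional[str]:
        # None until an event has been pushed on this thread.
        stack = self.get_stack()
        return stack[-1] if stack else None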

@@ -166,8 +166,8 @@ def _dump_dynamic_shapes(
root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined]
if root.__name__ not in dims:
dims[root.__name__] = {
"min": root.min, # type: ignore[attr-defined,union-attr]
"max": root.max, # type: ignore[attr-defined,union-attr]
"min": root.min,
"max": root.max,
"derived": set(),
}


@@ -2425,7 +2425,7 @@ def _dict_to_dataclass(cls, data):
field_type = cls.__annotations__[_type]
return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
elif dataclasses.is_dataclass(cls):
obj = cls(**data) # type: ignore[assignment,operator]
obj = cls(**data) # type: ignore[assignment]
type_hints = typing.get_type_hints(cls)
for f in dataclasses.fields(cls):
name = f.name


@@ -292,7 +292,7 @@ def reduction_prefix_array(
acc_type: str,
reduction_type: str,
dtype: torch.dtype,
len: Union[str, int],
len: int,
init_fn,
):
"""


@@ -308,8 +308,8 @@ def transpose_w(
def expand_bias(
B: Union[ir.IRNode, torch.Tensor, None], X: Union[ir.IRNode, torch.Tensor]
) -> Optional[Union[ir.IRNode, torch.Tensor]]:
B: Union[ir.IRNode, torch.Tensor], X: Union[ir.IRNode, torch.Tensor]
) -> Union[ir.IRNode, torch.Tensor]:
"""
Expand Bias to the same size of X.
"""
@@ -870,7 +870,7 @@ class CppPackedGemmTemplate(CppTemplate):
W = new_inputs[1]
B = new_inputs[2] if has_bias else None
W = transpose_w(W, trans_w)
B = expand_bias(B, X) # type:ignore[arg-type]
B = expand_bias(B, X)
new_inputs[1] = W
if B is not None:
new_inputs[2] = B


@@ -382,7 +382,7 @@ def split_const_gm(
gm,
node,
(
const_result[const_outputs[node.name]] # type:ignore[index]
const_result[const_outputs[node.name]]
if lifted_constant_names is None
else None
),


@@ -1883,9 +1883,9 @@ class GraphLowering(torch.fx.Interpreter):
# Generating random inputs based on self.example_inputs sometimes can be problematic,
# e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
real_inputs = [
materialize(x) # type:ignore[arg-type]
materialize(x)
for x in (
self.example_inputs # type:ignore[union-attr]
self.example_inputs
if isinstance(V.real_inputs, NullHandler)
else V.real_inputs
)


@@ -1612,7 +1612,7 @@ def reduction_dtypes(
# batched_matrix_contiguous_strides and contiguous_strides
def make_contiguous_strides_for(
shape: ShapeType, row_major: bool = True
) -> Tuple[Union[_IntLikeT, int], ...]:
) -> Tuple[int, ...]:
"""
Returns the strides of a contiguous tensor if row_major
If row_major=False, it returns the strides of a contiguous batch of Fortran-contiguous matrices
@@ -1625,13 +1625,11 @@ def make_contiguous_strides_for(
from torch.fx.experimental.symbolic_shapes import is_nested_int
multiplier: Union[_IntLikeT, int] = 1
multiplier = 1
strides = []
for l in reversed(shape):
strides.append(multiplier)
multiplier *= (
l if is_nested_int(l) else sym_max(l, 1)
) # type:ignore[assignment]
multiplier *= l if is_nested_int(l) else sym_max(l, 1)
result = tuple(reversed(strides))

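`make_contiguous_strides_for` walks the shape from the innermost dimension outward, accumulating the running product of the sizes (each clamped to at least 1 via `sym_max`). For plain integer shapes the computation reduces to roughly:

def contiguous_strides(shape):
    # Innermost dimension has stride 1; each outer stride is the product
    # of all inner sizes, clamping each size to at least 1.
    multiplier = 1
    strides = []
    for size in reversed(shape):
        strides.append(multiplier)
        multiplier *= max(size, 1)
    return tuple(reversed(strides))

print(contiguous_strides((2, 3, 4)))  # (12, 4, 1)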

@@ -410,7 +410,7 @@ def _broadcast_shapes(*_shapes):
assert isinstance(shape, Sequence)
# Computes common shape
common_shape: List[Union[int, torch.SymInt]] = [
common_shape = [
1,
] * reduce(max, (len(shape) for shape in shapes))
for arg_idx, shape in enumerate(shapes):

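`_broadcast_shapes` starts `common_shape` as a list of ones as long as the longest input shape and then widens it dimension by dimension. A standalone sketch of NumPy-style shape broadcasting (ignoring `SymInt`s):

from functools import reduce

def broadcast_shapes(*shapes):
    # Align shapes to the right; a dimension of 1 stretches to match the other shape.
    common = [1] * reduce(max, (len(s) for s in shapes))
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            idx = len(common) - 1 - i
            if common[idx] == 1:
                common[idx] = size
            elif size != 1 and size != common[idx]:
                raise ValueError(f"shapes {shapes} are not broadcastable")
    return tuple(common)

print(broadcast_shapes((3, 1, 5), (4, 5)))  # (3, 4, 5)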

@@ -20,7 +20,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
out_channels: int,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Union[str, Tuple[int, ...]],
padding: Tuple[int, ...],
dilation: Tuple[int, ...],
transposed: bool,
output_padding: Tuple[int, ...],


@@ -35,7 +35,7 @@ def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
res = []
for idx, param_name in enumerate(mod._flat_weights_names): # type: ignore[arg-type]
if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
param_value = mod._flat_weights[idx].detach() # type: ignore[index,union-attr]
param_value = mod._flat_weights[idx].detach() # type: ignore[index]
res.append(param_value)
return res
@@ -72,7 +72,7 @@ def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
res = []
for idx, param_name in enumerate(mod._flat_weights_names):
if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
param_value = mod._flat_weights[idx].detach() # type: ignore[index,union-attr]
param_value = mod._flat_weights[idx].detach()
res.append(param_value)
return res
else:


@@ -665,7 +665,7 @@ def _get_output_act_obs_or_fq(
named_modules: Dict[str, torch.nn.Module],
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
) -> Optional[ObserverOrFakeQuantize]:
) -> ObserverOrFakeQuantize:
"""Get the constructor for observer or fake quant object for
the argument in the original graph as the output of previous node,
skipping inserted observers


@@ -105,7 +105,6 @@ def post_localSGD_hook(
# Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
if state.iter < state.start_localSGD_iter:
state.maybe_increase_iter(bucket)
assert isinstance(global_group_to_use, dist.ProcessGroup)
return default._allreduce_fut(global_group_to_use, input_tensor)
# If `post_local_gradient_allreduce` is not set,


@@ -7,7 +7,6 @@ from typing import Dict
import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.utils._typing_utils import not_none
from . import default_hooks as default
@@ -399,10 +398,7 @@ def powerSGD_hook(
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
""" # noqa: B950
process_group = state.process_group
group_to_use = (
process_group if process_group is not None else not_none(dist.group.WORLD)
)
assert isinstance(process_group, dist.ProcessGroup)
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()
# The input tensor is a flattened 1D tensor.
@@ -711,10 +707,7 @@ def batched_powerSGD_hook(
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
""" # noqa: B950
process_group = state.process_group
group_to_use = (
process_group if process_group is not None else not_none(dist.group.WORLD)
)
assert isinstance(group_to_use, dist.ProcessGroup)
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()
# The input tensor is a flattened 1D tensor.

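Both hunks pick the process group to reduce over, falling back to the default world group when `state.process_group` is `None`; because `dist.group.WORLD` is typed as `Optional[ProcessGroup]`, the code needs either a `not_none` helper or an `isinstance` assert before it can call `.size()`. A rough sketch of the fallback pattern (names are illustrative, and a process group must already be initialized):

import torch.distributed as dist

def resolve_group(process_group=None) -> dist.ProcessGroup:
    # Fall back to the default (WORLD) group; the assert both guards against an
    # uninitialized default group and narrows the Optional type for mypy.
    group = process_group if process_group is not None else dist.group.WORLD
    assert isinstance(group, dist.ProcessGroup)
    return group

# world_size = resolve_group().size()  # valid after dist.init_process_group()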

@@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Optional, Union
from typing import Dict, Iterable, Union
import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils
from torch.utils._typing_utils import not_none as _not_none
__all__ = ["ModelAverager", "PeriodicModelAverager"]
@@ -22,9 +21,9 @@ class ModelAverager(ABC):
will be used. (default: ``None``)
"""
def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
def __init__(self, process_group=None):
self.process_group = (
process_group if process_group is not None else _not_none(dist.group.WORLD)
process_group if process_group is not None else dist.group.WORLD
)
self.step = 0
@@ -86,9 +85,7 @@ class PeriodicModelAverager(ModelAverager):
>>> averager.average_parameters(model.parameters())
"""
def __init__(
self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None
):
def __init__(self, period, warmup_steps=0, process_group=None):
super().__init__(process_group)
if warmup_steps < 0:
raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
@@ -123,7 +120,5 @@ class PeriodicModelAverager(ModelAverager):
self.step >= self.warmup_steps
and (self.step - self.warmup_steps) % self.period == 0
):
utils.average_parameters_or_parameter_groups(
params, _not_none(self.process_group)
)
utils.average_parameters_or_parameter_groups(params, self.process_group)
self.step += 1

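`PeriodicModelAverager` only averages once the warmup phase is over and then every `period` steps, which is exactly the condition checked above. A small sketch of that schedule:

def should_average(step: int, warmup_steps: int, period: int) -> bool:
    # Skip the warmup phase, then fire every `period` steps.
    return step >= warmup_steps and (step - warmup_steps) % period == 0

print([s for s in range(20) if should_average(s, warmup_steps=10, period=4)])  # [10, 14, 18]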

@@ -4477,9 +4477,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
@_exception_logger
def barrier(
group: Optional[ProcessGroup] = GroupMember.WORLD, async_op=False, device_ids=None
):
def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
"""
Synchronize all processes.
@@ -4521,11 +4519,7 @@ def barrier(
work.wait()
def monitored_barrier(
group: Optional[ProcessGroup] = GroupMember.WORLD,
timeout=None,
wait_all_ranks=False,
):
def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
"""
Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.
@@ -4595,9 +4589,7 @@ def monitored_barrier(
_check_valid_timeout(timeout)
group_to_use = _get_default_group() if group is None else group
return group_to_use.monitored_barrier( # type:ignore[attr-defined]
timeout, wait_all_ranks=wait_all_ranks
)
return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
def _create_process_group_wrapper(


@@ -630,7 +630,7 @@ def _flatten_optim_state(
assert state_names is not None
# Flatten the state
flat_state: Dict[str, Optional[torch.Tensor]] = {}
flat_state: Dict[str, Any] = {}
for state_name in state_names:
state_values = [
unflat_param_state[state_name] if unflat_param_state is not None else None
@@ -658,7 +658,7 @@ def _flatten_optim_state(
if are_pos_dim_tensors:
flat_tensor = _flatten_tensor_optim_state(
state_name,
state_values, # type: ignore[arg-type]
state_values,
unflat_param_names,
unflat_param_shapes,
handle,
@@ -680,7 +680,7 @@ def _flatten_optim_state(
elif are_zero_dim_tensors:
flat_state[state_name] = _flatten_zero_dim_tensor_optim_state(
state_name,
state_values, # type: ignore[arg-type]
state_values,
unflat_param_names,
)
else:


@@ -1,7 +1,7 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, List, Optional, Sequence, Sized, Tuple
from typing import cast, List, Optional, Sequence, Tuple
import torch
from torch.distributed.device_mesh import DeviceMesh
@@ -769,7 +769,7 @@ def split_rule(op_schema: OpSchema) -> OutputSharding:
),
)
def size_split(N, i) -> List:
def size_split(N, i):
# Last chunk will be smaller if the tensor size N
# along the given dimension dim is not divisible by i.
assert i > 0
@@ -780,7 +780,6 @@ def split_rule(op_schema: OpSchema) -> OutputSharding:
if isinstance(split_size_or_sections, int)
else split_size_or_sections
)
assert isinstance(output_size_list, Sized)
output_spec_list = [
DTensorSpec(
mesh=input_spec.mesh,

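`size_split` mirrors `torch.split` chunking: pieces of size `i` along a dimension of length `N`, with a smaller trailing piece when `N` is not divisible by `i`. For example:

from typing import List

def size_split(N: int, i: int) -> List[int]:
    # Full chunks of size i, plus a remainder chunk when N % i != 0.
    assert i > 0
    full, rem = divmod(N, i)
    return [i] * full + ([rem] if rem else [])

print(size_split(10, 3))  # [3, 3, 3, 1]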

@@ -872,7 +872,7 @@ class Tracer(TracerBase):
nonlocal cnt
cnt += 1
param = sig.parameters[name]
default: Tuple[Any, ...] = (
default = (
() if param.default is inspect.Parameter.empty else (param.default,)
)
out = self.create_proxy(
@@ -913,7 +913,7 @@ class Tracer(TracerBase):
return pytree.tree_map(replace_ph, concrete_args[name])
if name[0] == "*":
default: Tuple[Any, ...] = ()
default = ()
else:
param = sig.parameters[name]
default = ( # type: ignore[assignment]


@@ -1190,7 +1190,7 @@ def wrap_key(
def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
return get_proxy_slot(t, tracer, t, lambda x: x.proxy)
out = f(*tensors) # type:ignore[call-arg]
out = f(*tensors)
out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)
out = pytree.tree_map_only(
_AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out


@@ -233,9 +233,7 @@ class SymIntEqByExpr:
return hash(self._extract())
def _nested_int_aware_sort(
tup: Tuple[Union[SymInt, int], int]
) -> Tuple[int, Union[SymInt, int], int]:
def _nested_int_aware_sort(tup: Tuple[Union[SymInt, int], int]) -> Tuple[int, int, int]:
return (
# Order nested ints by their coefficients.
# 1 here to order nested ints after non-nested-ints.


@@ -344,10 +344,10 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
fn_def = parsed_def.ast.body[0]
if is_classmethod:
arg_name = fn_def.args.args[0].arg # type:ignore[union-attr]
arg_name = fn_def.args.args[0].arg
# Insert a statement that assigns the first argument to the class
assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0]
fn_def.body.insert(0, assign_stmt) # type:ignore[union-attr]
fn_def.body.insert(0, assign_stmt)
# Swap out the function signature and body if it is unused
if should_drop(fn):
@@ -361,16 +361,16 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}"
)
unused_def = unused_fn_def.body[0]
fn_def.body = unused_def.body # type:ignore[union-attr]
fn_def.body = unused_def.body
# kwarg/vararg not supported by `build_def`
fn_def.args.kwarg = fn_def.args.vararg = None # type:ignore[union-attr]
for arg in fn_def.args.args + fn_def.args.kwonlyargs: # type:ignore[union-attr]
fn_def.args.kwarg = fn_def.args.vararg = None
for arg in fn_def.args.args + fn_def.args.kwonlyargs:
# Replace potentially unsupported type annotations by "Any"
arg.annotation = unused_def.args.args[0].annotation
if _is_drop_fn(fn):
# Dropping potentially unsupported return type annotation for jit._drop
fn_def.returns = None # type:ignore[union-attr]
fn_def.type_comment = None # type:ignore[union-attr]
fn_def.returns = None
fn_def.type_comment = None
# If MonkeyType is installed, get all the consolidated type traces
# for the arguments from type_trace_db


@@ -89,7 +89,7 @@ class _ConvNd(Module):
out_channels: int,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Union[str, Tuple[int, ...]],
padding: Tuple[int, ...],
dilation: Tuple[int, ...],
transposed: bool,
output_padding: Tuple[int, ...],

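The `padding: Union[str, Tuple[int, ...]]` annotation seen in this hunk reflects that the conv modules accept the strings "same" and "valid" as well as explicit tuples, for example:

import torch
from torch import nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding="same")  # string padding
out = conv(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 8, 32, 32]) -- spatial size preserved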

@@ -237,8 +237,8 @@ class RNNBase(Module):
# Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
# or the tensors in _flat_weights are of different dtypes
first_fw = self._flat_weights[0] # type: ignore[union-attr]
dtype = first_fw.dtype # type: ignore[union-attr]
first_fw = self._flat_weights[0]
dtype = first_fw.dtype
for fw in self._flat_weights:
if (
not isinstance(fw, Tensor)
@@ -252,9 +252,7 @@ class RNNBase(Module):
# a sufficient check, because overlapping parameter buffers that don't completely
# alias would break the assumptions of the uniqueness check in
# Module.named_parameters().
unique_data_ptrs = {
p.data_ptr() for p in self._flat_weights # type: ignore[union-attr]
}
unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
if len(unique_data_ptrs) != len(self._flat_weights):
return
@@ -269,7 +267,7 @@ class RNNBase(Module):
if self.proj_size > 0:
num_weights += 1
torch._cudnn_rnn_flatten_weight(
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
num_weights,
self.input_size,
rnn.get_cudnn_mode(self.mode),
@@ -299,11 +297,11 @@ class RNNBase(Module):
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
if not torch.jit.is_scripting():
if (
input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr]
input.dtype != self._flat_weights[0].dtype
and not torch._C._is_any_autocast_enabled()
):
raise ValueError(
f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" # type: ignore[union-attr]
f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}"
)
expected_input_dim = 2 if batch_sizes is not None else 3
if input.dim() != expected_input_dim:
@@ -716,7 +714,7 @@ class RNN(RNNBase):
result = _VF.rnn_tanh(
input,
hx,
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -728,7 +726,7 @@ class RNN(RNNBase):
result = _VF.rnn_relu(
input,
hx,
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -742,7 +740,7 @@ class RNN(RNNBase):
input,
batch_sizes,
hx,
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -754,7 +752,7 @@ class RNN(RNNBase):
input,
batch_sizes,
hx,
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -1124,7 +1122,7 @@ class LSTM(RNNBase):
result = _VF.lstm(
input,
hx,
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -1137,7 +1135,7 @@ class LSTM(RNNBase):
input,
batch_sizes,
hx,
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -1393,7 +1391,7 @@ class GRU(RNNBase):
result = _VF.gru(
input,
hx,
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -1406,7 +1404,7 @@ class GRU(RNNBase):
input,
batch_sizes,
hx,
self._flat_weights, # type: ignore[arg-type]
self._flat_weights,
self.bias,
self.num_layers,
self.dropout,


@@ -511,7 +511,7 @@ def _multi_tensor_radam(
del bias_correction1
else:
rect = [
( # type: ignore[misc]
(
(rho_t - 4) # type: ignore[arg-type]
* (rho_t - 2)
* rho_inf


@@ -15,7 +15,6 @@ from functools import wraps
from typing import (
Any,
Callable,
cast,
Dict,
List,
no_type_check,
@@ -203,7 +202,7 @@ def _broadcast_state_dict(rank, state_dict):
olist = [state_dict if rank == 0 else None]
dist.broadcast_object_list(olist)
state_dict = cast(Dict[str, torch.Tensor], olist[0])
state_dict = olist[0]
# Ensure that the state is on DEVICE
for param_name in state_dict.keys():
state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE)
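
`_broadcast_state_dict` ships rank 0's state dict to every rank through `broadcast_object_list`, which fills in the placeholder entries of the list in place on the receiving ranks. A rough usage sketch (assumes an initialized process group):

import torch.distributed as dist

def broadcast_state_dict(rank: int, state_dict):
    # Rank 0 supplies the object; the other ranks pass a placeholder that
    # broadcast_object_list overwrites in place.
    olist = [state_dict if rank == 0 else None]
    dist.broadcast_object_list(olist, src=0)
    return olist[0]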