Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[BE]: Update mypy to 1.13.0 (#140808)"
This reverts commit 00134d68af2ce50560fa5a74473665ea229e6c9d. Reverted https://github.com/pytorch/pytorch/pull/140808 on behalf of https://github.com/huydhn due to This is failing a distributed test in trunk, target determination missed this test and did not run it on PR ([comment](https://github.com/pytorch/pytorch/pull/140808#issuecomment-2512788426))
@@ -90,7 +90,7 @@ librosa>=0.6.2 ; python_version < "3.11"
#Pinned versions:
#test that import:

-mypy==1.13.0
+mypy==1.11.2
# Pin MyPy version because new errors are likely to appear with each release
#Description: linter
#Pinned versions: 1.10.0
@@ -144,7 +144,7 @@ init_command = [
'numpy==1.26.4 ; python_version >= "3.9" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.2.1',
-'mypy==1.13.0',
+'mypy==1.11.2',
'sympy==1.13.0 ; python_version >= "3.9"',
'types-requests==2.27.25',
'types-PyYAML==6.0.7',
@@ -386,11 +386,9 @@ class Op:
}, f"{type} is not a supported operation"
self.type = type
if type == "send":
-assert isinstance(meta, str)
s, d = meta.split("->")
self._src, self._dst = int(s), int(d)
elif type == "recv":
-assert isinstance(meta, str)
d, s = meta.split("<-")
self._dst, self._src = int(d), int(s)
else:
@@ -1503,7 +1503,6 @@ def export(
# NB: this is wrong if graph_captured_result has
# data-dependent output size!
ignore_fresh_unbacked = null_context()
-assert ambient_fake_mode is not None
if shape_env := ambient_fake_mode.shape_env:
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()

@@ -11,18 +11,7 @@ import sys
import traceback
import weakref
from dataclasses import dataclass
-from typing import (
-Any,
-Callable,
-cast,
-Dict,
-List,
-Optional,
-Set,
-Tuple,
-TYPE_CHECKING,
-Union,
-)
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import sympy

@@ -632,11 +621,8 @@ class OutputGraph:
"""
Saves to out if it is provided. Else saves to the tracing context's global_state.
"""
-global_state = cast(
-Dict[str, Tuple[Callable[..., Any], bool]],
-out
-if out is not None
-else self.tracing_context.global_context.global_state,
+global_state = (
+out if out is not None else self.tracing_context.global_context.global_state
)

# TODO - Consider having a torch level API for torch_function_state. As
@@ -659,11 +645,11 @@ class OutputGraph:
functools.partial(torch.set_autocast_enabled, "cpu"),
torch.is_autocast_enabled("cpu"),
)
-global_state["autocast_gpu_dtype"] = ( # type:ignore[assignment]
+global_state["autocast_gpu_dtype"] = (
functools.partial(torch.set_autocast_dtype, "cuda"),
torch.get_autocast_dtype("cuda"),
)
-global_state["autocast_cpu_dtype"] = ( # type:ignore[assignment]
+global_state["autocast_cpu_dtype"] = (
functools.partial(torch.set_autocast_dtype, "cpu"),
torch.get_autocast_dtype("cpu"),
)
@@ -1088,7 +1088,7 @@ class ChromiumEventLogger:
a specification of the Chromium Event JSON format.
"""

-def get_stack(self) -> List[str]:
+def get_stack(self):
"""
The main event stack, with every chromium event.
Logged to tlparse.
@@ -1099,7 +1099,7 @@ class ChromiumEventLogger:
self.tls.stack = []
return self.tls.stack

-def get_top(self) -> Optional[str]:
+def get_top(self) -> str:
"""
Get the top event name or None if the stack is empty.
"""
@@ -166,8 +166,8 @@ def _dump_dynamic_shapes(
root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined]
if root.__name__ not in dims:
dims[root.__name__] = {
-"min": root.min, # type: ignore[attr-defined,union-attr]
-"max": root.max, # type: ignore[attr-defined,union-attr]
+"min": root.min,
+"max": root.max,
"derived": set(),
}

@@ -2425,7 +2425,7 @@ def _dict_to_dataclass(cls, data):
field_type = cls.__annotations__[_type]
return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
elif dataclasses.is_dataclass(cls):
-obj = cls(**data) # type: ignore[assignment,operator]
+obj = cls(**data) # type: ignore[assignment]
type_hints = typing.get_type_hints(cls)
for f in dataclasses.fields(cls):
name = f.name
@@ -292,7 +292,7 @@ def reduction_prefix_array(
acc_type: str,
reduction_type: str,
dtype: torch.dtype,
-len: Union[str, int],
+len: int,
init_fn,
):
"""
@@ -308,8 +308,8 @@ def transpose_w(


def expand_bias(
-B: Union[ir.IRNode, torch.Tensor, None], X: Union[ir.IRNode, torch.Tensor]
-) -> Optional[Union[ir.IRNode, torch.Tensor]]:
+B: Union[ir.IRNode, torch.Tensor], X: Union[ir.IRNode, torch.Tensor]
+) -> Union[ir.IRNode, torch.Tensor]:
"""
Expand Bias to the same size of X.
"""
@@ -870,7 +870,7 @@ class CppPackedGemmTemplate(CppTemplate):
W = new_inputs[1]
B = new_inputs[2] if has_bias else None
W = transpose_w(W, trans_w)
-B = expand_bias(B, X) # type:ignore[arg-type]
+B = expand_bias(B, X)
new_inputs[1] = W
if B is not None:
new_inputs[2] = B
@@ -382,7 +382,7 @@ def split_const_gm(
gm,
node,
(
-const_result[const_outputs[node.name]] # type:ignore[index]
+const_result[const_outputs[node.name]]
if lifted_constant_names is None
else None
),
@@ -1883,9 +1883,9 @@ class GraphLowering(torch.fx.Interpreter):
# Generating random inputs based on self.example_inputs sometimes can be problematic,
# e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
real_inputs = [
-materialize(x) # type:ignore[arg-type]
+materialize(x)
for x in (
-self.example_inputs # type:ignore[union-attr]
+self.example_inputs
if isinstance(V.real_inputs, NullHandler)
else V.real_inputs
)
@@ -1612,7 +1612,7 @@ def reduction_dtypes(
# batched_matrix_contiguous_strides and contiguous_strides
def make_contiguous_strides_for(
shape: ShapeType, row_major: bool = True
-) -> Tuple[Union[_IntLikeT, int], ...]:
+) -> Tuple[int, ...]:
"""
Returns the strides of a contiguous tensor if row_major
If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices
@@ -1625,13 +1625,11 @@ def make_contiguous_strides_for(

from torch.fx.experimental.symbolic_shapes import is_nested_int

-multiplier: Union[_IntLikeT, int] = 1
+multiplier = 1
strides = []
for l in reversed(shape):
strides.append(multiplier)
-multiplier *= (
-l if is_nested_int(l) else sym_max(l, 1)
-) # type:ignore[assignment]
+multiplier *= l if is_nested_int(l) else sym_max(l, 1)

result = tuple(reversed(strides))

@@ -410,7 +410,7 @@ def _broadcast_shapes(*_shapes):
assert isinstance(shape, Sequence)

# Computes common shape
-common_shape: List[Union[int, torch.SymInt]] = [
+common_shape = [
1,
] * reduce(max, (len(shape) for shape in shapes))
for arg_idx, shape in enumerate(shapes):
@@ -20,7 +20,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
out_channels: int,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
-padding: Union[str, Tuple[int, ...]],
+padding: Tuple[int, ...],
dilation: Tuple[int, ...],
transposed: bool,
output_padding: Tuple[int, ...],
@@ -35,7 +35,7 @@ def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
res = []
for idx, param_name in enumerate(mod._flat_weights_names): # type: ignore[arg-type]
if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
-param_value = mod._flat_weights[idx].detach() # type: ignore[index,union-attr]
+param_value = mod._flat_weights[idx].detach() # type: ignore[index]
res.append(param_value)
return res

@@ -72,7 +72,7 @@ def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
res = []
for idx, param_name in enumerate(mod._flat_weights_names):
if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
-param_value = mod._flat_weights[idx].detach() # type: ignore[index,union-attr]
+param_value = mod._flat_weights[idx].detach()
res.append(param_value)
return res
else:
@@ -665,7 +665,7 @@ def _get_output_act_obs_or_fq(
named_modules: Dict[str, torch.nn.Module],
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
-) -> Optional[ObserverOrFakeQuantize]:
+) -> ObserverOrFakeQuantize:
"""Get the constructor for observer or fake quant object for
the argument in the original graph as the output of previous node,
skipping inserted observers
@@ -105,7 +105,6 @@ def post_localSGD_hook(
# Run allreduce using `global_group_to_use` in the first `start_localSGD_iter` iterations.
if state.iter < state.start_localSGD_iter:
state.maybe_increase_iter(bucket)
-assert isinstance(global_group_to_use, dist.ProcessGroup)
return default._allreduce_fut(global_group_to_use, input_tensor)

# If `post_local_gradient_allreduce` is not set,
@@ -7,7 +7,6 @@ from typing import Dict
import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
-from torch.utils._typing_utils import not_none

from . import default_hooks as default

@@ -399,10 +398,7 @@ def powerSGD_hook(
>>> ddp_model.register_comm_hook(state, powerSGD_hook)
""" # noqa: B950
process_group = state.process_group
-group_to_use = (
-process_group if process_group is not None else not_none(dist.group.WORLD)
-)
-assert isinstance(process_group, dist.ProcessGroup)
+group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()

# The input tensor is a flattened 1D tensor.
@@ -711,10 +707,7 @@ def batched_powerSGD_hook(
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
""" # noqa: B950
process_group = state.process_group
-group_to_use = (
-process_group if process_group is not None else not_none(dist.group.WORLD)
-)
-assert isinstance(group_to_use, dist.ProcessGroup)
+group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()

# The input tensor is a flattened 1D tensor.
@@ -1,12 +1,11 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
-from typing import Dict, Iterable, Optional, Union
+from typing import Dict, Iterable, Union

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils
-from torch.utils._typing_utils import not_none as _not_none


__all__ = ["ModelAverager", "PeriodicModelAverager"]
@@ -22,9 +21,9 @@ class ModelAverager(ABC):
will be used. (default: ``None``)
"""

-def __init__(self, process_group: Optional[dist.ProcessGroup] = None):
+def __init__(self, process_group=None):
self.process_group = (
-process_group if process_group is not None else _not_none(dist.group.WORLD)
+process_group if process_group is not None else dist.group.WORLD
)
self.step = 0

@@ -86,9 +85,7 @@ class PeriodicModelAverager(ModelAverager):
>>> averager.average_parameters(model.parameters())
"""

-def __init__(
-self, period, warmup_steps=0, process_group: Optional[dist.ProcessGroup] = None
-):
+def __init__(self, period, warmup_steps=0, process_group=None):
super().__init__(process_group)
if warmup_steps < 0:
raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
@@ -123,7 +120,5 @@ class PeriodicModelAverager(ModelAverager):
self.step >= self.warmup_steps
and (self.step - self.warmup_steps) % self.period == 0
):
-utils.average_parameters_or_parameter_groups(
-params, _not_none(self.process_group)
-)
+utils.average_parameters_or_parameter_groups(params, self.process_group)
self.step += 1
@@ -4477,9 +4477,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False


@_exception_logger
-def barrier(
-group: Optional[ProcessGroup] = GroupMember.WORLD, async_op=False, device_ids=None
-):
+def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
"""
Synchronize all processes.

@@ -4521,11 +4519,7 @@ def barrier(
work.wait()


-def monitored_barrier(
-group: Optional[ProcessGroup] = GroupMember.WORLD,
-timeout=None,
-wait_all_ranks=False,
-):
+def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
"""
Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.

@@ -4595,9 +4589,7 @@ def monitored_barrier(
_check_valid_timeout(timeout)

group_to_use = _get_default_group() if group is None else group
-return group_to_use.monitored_barrier( # type:ignore[attr-defined]
-timeout, wait_all_ranks=wait_all_ranks
-)
+return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)


def _create_process_group_wrapper(
@@ -630,7 +630,7 @@ def _flatten_optim_state(
assert state_names is not None

# Flatten the state
-flat_state: Dict[str, Optional[torch.Tensor]] = {}
+flat_state: Dict[str, Any] = {}
for state_name in state_names:
state_values = [
unflat_param_state[state_name] if unflat_param_state is not None else None
@@ -658,7 +658,7 @@ def _flatten_optim_state(
if are_pos_dim_tensors:
flat_tensor = _flatten_tensor_optim_state(
state_name,
-state_values, # type: ignore[arg-type]
+state_values,
unflat_param_names,
unflat_param_shapes,
handle,
@@ -680,7 +680,7 @@ def _flatten_optim_state(
elif are_zero_dim_tensors:
flat_state[state_name] = _flatten_zero_dim_tensor_optim_state(
state_name,
-state_values, # type: ignore[arg-type]
+state_values,
unflat_param_names,
)
else:
@@ -1,7 +1,7 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
-from typing import cast, List, Optional, Sequence, Sized, Tuple
+from typing import cast, List, Optional, Sequence, Tuple

import torch
from torch.distributed.device_mesh import DeviceMesh
@@ -769,7 +769,7 @@ def split_rule(op_schema: OpSchema) -> OutputSharding:
),
)

-def size_split(N, i) -> List:
+def size_split(N, i):
# Last chunk will be smaller if the tensor size N
# along the given dimension dim is not divisible by i.
assert i > 0
@@ -780,7 +780,6 @@ def split_rule(op_schema: OpSchema) -> OutputSharding:
if isinstance(split_size_or_sections, int)
else split_size_or_sections
)
-assert isinstance(output_size_list, Sized)
output_spec_list = [
DTensorSpec(
mesh=input_spec.mesh,
@@ -872,7 +872,7 @@ class Tracer(TracerBase):
nonlocal cnt
cnt += 1
param = sig.parameters[name]
-default: Tuple[Any, ...] = (
+default = (
() if param.default is inspect.Parameter.empty else (param.default,)
)
out = self.create_proxy(
@@ -913,7 +913,7 @@ class Tracer(TracerBase):

return pytree.tree_map(replace_ph, concrete_args[name])
if name[0] == "*":
-default: Tuple[Any, ...] = ()
+default = ()
else:
param = sig.parameters[name]
default = ( # type: ignore[assignment]
@@ -1190,7 +1190,7 @@ def wrap_key(
def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
return get_proxy_slot(t, tracer, t, lambda x: x.proxy)

-out = f(*tensors) # type:ignore[call-arg]
+out = f(*tensors)
out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)
out = pytree.tree_map_only(
_AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out
@@ -233,9 +233,7 @@ class SymIntEqByExpr:
return hash(self._extract())


-def _nested_int_aware_sort(
-tup: Tuple[Union[SymInt, int], int]
-) -> Tuple[int, Union[SymInt, int], int]:
+def _nested_int_aware_sort(tup: Tuple[Union[SymInt, int], int]) -> Tuple[int, int, int]:
return (
# Order nested ints by their coefficients.
# 1 here to order nested ints after non-nested-ints.
@@ -344,10 +344,10 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
fn_def = parsed_def.ast.body[0]

if is_classmethod:
-arg_name = fn_def.args.args[0].arg # type:ignore[union-attr]
+arg_name = fn_def.args.args[0].arg
# Insert a statement that assigns the first argument to the class
assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0]
-fn_def.body.insert(0, assign_stmt) # type:ignore[union-attr]
+fn_def.body.insert(0, assign_stmt)

# Swap out the function signature and body if it is unused
if should_drop(fn):
@@ -361,16 +361,16 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}"
)
unused_def = unused_fn_def.body[0]
-fn_def.body = unused_def.body # type:ignore[union-attr]
+fn_def.body = unused_def.body
# kwarg/vararg not supported by `build_def`
-fn_def.args.kwarg = fn_def.args.vararg = None # type:ignore[union-attr]
-for arg in fn_def.args.args + fn_def.args.kwonlyargs: # type:ignore[union-attr]
+fn_def.args.kwarg = fn_def.args.vararg = None
+for arg in fn_def.args.args + fn_def.args.kwonlyargs:
# Replace potentially unsupported type annotations by "Any"
arg.annotation = unused_def.args.args[0].annotation
if _is_drop_fn(fn):
# Dropping potentially unsupported return type annotation for jit._drop
-fn_def.returns = None # type:ignore[union-attr]
-fn_def.type_comment = None # type:ignore[union-attr]
+fn_def.returns = None
+fn_def.type_comment = None

# If MonkeyType is installed, get all the consolidated type traces
# for the arguments from type_trace_db
@@ -89,7 +89,7 @@ class _ConvNd(Module):
out_channels: int,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
-padding: Union[str, Tuple[int, ...]],
+padding: Tuple[int, ...],
dilation: Tuple[int, ...],
transposed: bool,
output_padding: Tuple[int, ...],
@@ -237,8 +237,8 @@ class RNNBase(Module):
# Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
# or the tensors in _flat_weights are of different dtypes

-first_fw = self._flat_weights[0] # type: ignore[union-attr]
-dtype = first_fw.dtype # type: ignore[union-attr]
+first_fw = self._flat_weights[0]
+dtype = first_fw.dtype
for fw in self._flat_weights:
if (
not isinstance(fw, Tensor)
@@ -252,9 +252,7 @@ class RNNBase(Module):
# a sufficient check, because overlapping parameter buffers that don't completely
# alias would break the assumptions of the uniqueness check in
# Module.named_parameters().
-unique_data_ptrs = {
-p.data_ptr() for p in self._flat_weights # type: ignore[union-attr]
-}
+unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
if len(unique_data_ptrs) != len(self._flat_weights):
return

@@ -269,7 +267,7 @@ class RNNBase(Module):
if self.proj_size > 0:
num_weights += 1
torch._cudnn_rnn_flatten_weight(
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
num_weights,
self.input_size,
rnn.get_cudnn_mode(self.mode),
@@ -299,11 +297,11 @@ class RNNBase(Module):
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
if not torch.jit.is_scripting():
if (
-input.dtype != self._flat_weights[0].dtype # type: ignore[union-attr]
+input.dtype != self._flat_weights[0].dtype
and not torch._C._is_any_autocast_enabled()
):
raise ValueError(
-f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" # type: ignore[union-attr]
+f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}"
)
expected_input_dim = 2 if batch_sizes is not None else 3
if input.dim() != expected_input_dim:
@@ -716,7 +714,7 @@ class RNN(RNNBase):
result = _VF.rnn_tanh(
input,
hx,
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -728,7 +726,7 @@ class RNN(RNNBase):
result = _VF.rnn_relu(
input,
hx,
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -742,7 +740,7 @@ class RNN(RNNBase):
input,
batch_sizes,
hx,
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -754,7 +752,7 @@ class RNN(RNNBase):
input,
batch_sizes,
hx,
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -1124,7 +1122,7 @@ class LSTM(RNNBase):
result = _VF.lstm(
input,
hx,
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -1137,7 +1135,7 @@ class LSTM(RNNBase):
input,
batch_sizes,
hx,
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -1393,7 +1391,7 @@ class GRU(RNNBase):
result = _VF.gru(
input,
hx,
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -1406,7 +1404,7 @@ class GRU(RNNBase):
input,
batch_sizes,
hx,
-self._flat_weights, # type: ignore[arg-type]
+self._flat_weights,
self.bias,
self.num_layers,
self.dropout,
@@ -511,7 +511,7 @@ def _multi_tensor_radam(
del bias_correction1
else:
rect = [
-( # type: ignore[misc]
+(
(rho_t - 4) # type: ignore[arg-type]
* (rho_t - 2)
* rho_inf
@@ -15,7 +15,6 @@ from functools import wraps
from typing import (
Any,
Callable,
-cast,
Dict,
List,
no_type_check,
@@ -203,7 +202,7 @@ def _broadcast_state_dict(rank, state_dict):

olist = [state_dict if rank == 0 else None]
dist.broadcast_object_list(olist)
-state_dict = cast(Dict[str, torch.Tensor], olist[0])
+state_dict = olist[0]
# Ensure that the state is on DEVICE
for param_name in state_dict.keys():
state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE)