Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
BE: Type previously untyped decorators (#154515)
Summary: Cloned #153726 from Skylion007 and fixed internal typing issues.
Test Plan: Unit tests pass.
Differential Revision: D75477355
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154515
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: ba0a91b3ea
Commit: 946a4c2bdc
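For context, the diff below applies the standard ParamSpec pattern for typing decorators that return the wrapped function unchanged, so the decorated function keeps its exact signature under mypy. A minimal, self-contained sketch of that pattern (the `@export` usage and the string modifier value are illustrative, not the actual PyTorch source):

```python
from typing import Callable, TypeVar

from typing_extensions import ParamSpec  # also available from typing on Python 3.10+

_P = ParamSpec("_P")
_R = TypeVar("_R")


def export(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    # Attaching a marker attribute is invisible to the type system, hence the
    # explicit ignore; the function object itself is returned unmodified.
    fn._torchscript_modifier = "export"  # type: ignore[attr-defined]
    return fn


@export
def add(x: int, y: int) -> int:
    return x + y


# mypy now sees `add` as Callable[[int, int], int] instead of an untyped callable.
```

Because the decorators are no longer untyped, functions they wrap stop being treated as `Any`, which is broadly what surfaces the new `# type: ignore[...]` sites throughout the rest of the diff.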
@@ -31,8 +31,10 @@ from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by tor
List,
Optional,
Tuple,
TypeVar,
Union,
)
+from typing_extensions import ParamSpec

import torch
@@ -47,6 +49,9 @@ from torch._sources import fake_range, get_source_lines_and_file, parse_def
from torch.futures import Future


+_P = ParamSpec("_P")
+_R = TypeVar("_R")

IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)

BuiltinUnionType: Union[type, tuple[type, ...]]
@@ -665,7 +670,7 @@ class FunctionModifiers:
_DROP = "_drop (function is fully ignored, declaration can be unscriptable)"


-def export(fn):
+def export(fn: Callable[_P, _R]) -> Callable[_P, _R]:
"""
This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
:class:`ScriptModule` and should be compiled.
@@ -707,11 +712,11 @@ def export(fn):
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())
"""
-fn._torchscript_modifier = FunctionModifiers.EXPORT
+fn._torchscript_modifier = FunctionModifiers.EXPORT # type:ignore[attr-defined]
return fn


-def unused(fn):
+def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
"""
This decorator indicates to the compiler that a function or method should
be ignored and replaced with the raising of an exception. This allows you
@@ -764,7 +769,7 @@ def unused(fn):

return prop

-fn._torchscript_modifier = FunctionModifiers.UNUSED
+fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
return fn
@@ -882,13 +887,13 @@ def ignore(drop=False, **kwargs):
return decorator


-def _drop(fn):
-fn._torchscript_modifier = FunctionModifiers._DROP
+def _drop(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+fn._torchscript_modifier = FunctionModifiers._DROP # type: ignore[attr-defined]
return fn


-def _copy_to_script_wrapper(fn):
-fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
+def _copy_to_script_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER # type: ignore[attr-defined]
return fn
@@ -169,13 +169,13 @@ class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
False,
qconfig,
)
-qat_linearbn.weight = linear.weight
-qat_linearbn.bias = linear.bias
-qat_linearbn.bn.weight = bn.weight
-qat_linearbn.bn.bias = bn.bias
-qat_linearbn.bn.running_mean = bn.running_mean
-qat_linearbn.bn.running_var = bn.running_var
-qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
+qat_linearbn.weight = linear.weight # type: ignore[assignment]
+qat_linearbn.bias = linear.bias # type: ignore[assignment]
+qat_linearbn.bn.weight = bn.weight # type: ignore[assignment]
+qat_linearbn.bn.bias = bn.bias # type: ignore[assignment]
+qat_linearbn.bn.running_mean = bn.running_mean # type: ignore[assignment]
+qat_linearbn.bn.running_var = bn.running_var # type: ignore[assignment]
+qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[assignment]
return qat_linearbn

def to_float(self):
@@ -99,7 +99,7 @@ class LinearLeakyReLU(nnq.Linear):
activation_post_process = mod.activation_post_process
leaky_relu = mod[1]
mod = mod[0]
-weight_post_process = mod.qconfig.weight()
+weight_post_process = mod.qconfig.weight() # type: ignore[union-attr, operator]
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
@@ -108,7 +108,7 @@ class LinearLeakyReLU(nnq.Linear):
qlinear_leaky_relu = cls(
mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype
)
-qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
+qlinear_leaky_relu.set_weight_bias(qweight, mod.bias) # type: ignore[arg-type]
qlinear_leaky_relu.scale = float(act_scale)
qlinear_leaky_relu.zero_point = int(act_zp)
return qlinear_leaky_relu
@@ -164,14 +164,14 @@ class LinearTanh(nnq.Linear):
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
activation_post_process = mod.activation_post_process
mod = mod[0]
-weight_post_process = mod.qconfig.weight()
+weight_post_process = mod.qconfig.weight() # type: ignore[union-attr,operator]
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype)
-qlinear_tanh.set_weight_bias(qweight, mod.bias)
+qlinear_tanh.set_weight_bias(qweight, mod.bias) # type: ignore[arg-type]
qlinear_tanh.scale = float(act_scale)
qlinear_tanh.zero_point = int(act_zp)
return qlinear_tanh
@@ -52,7 +52,7 @@ def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
return mod.weight.detach()
elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
-return mod[0].weight.detach()
+return mod[0].weight.detach() # type: ignore[operator]
else:
return mod._weight_bias()[0] # type: ignore[operator]
@@ -61,7 +61,7 @@ def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
if isinstance(mod, nn.Linear):
return mod.weight.detach()
elif isinstance(mod, nni.LinearReLU):
-return mod[0].weight.detach()
+return mod[0].weight.detach() # type: ignore[operator]
else:
return mod._weight_bias()[0] # type: ignore[operator]
@@ -79,8 +79,12 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
res = []
for weight_value in mod._all_weight_values:
-res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
-res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
+res.append(
+weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0] # type: ignore[index]
+)
+res.append(
+weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0] # type: ignore[index]
+)
return res
@@ -60,7 +60,7 @@ class AdaroundFakeQuantizer(FakeQuantize):
self.use_soft_rounding = True

@torch.jit.export
-def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
+def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
return self.scale, self.zero_point

@torch.jit.export
@@ -392,7 +392,7 @@ class FusedMovingAvgObsFakeQuantize(FakeQuantize):
)

@torch.jit.export
-def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
+def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
return self.activation_post_process.calculate_qparams()

@torch.jit.export
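The `# type: ignore[override]` comments above silence mypy's method-override compatibility check on `calculate_qparams`. One common trigger for that error code, shown here on hypothetical observer classes rather than the actual PyTorch hierarchy, is an override that drops arguments the base method accepts:

```python
class Observer:
    def calculate_qparams(self, **kwargs: object) -> tuple[float, float]:
        return 1.0, 0.0


class StrictObserver(Observer):
    # Dropping the **kwargs the base accepts means a StrictObserver is not
    # substitutable for an Observer at every call site, so mypy reports
    # [override] unless the line is silenced.
    def calculate_qparams(self) -> tuple[float, float]:  # type: ignore[override]
        return 0.5, 0.0
```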
@@ -104,7 +104,8 @@ def _get_lstm_with_individually_observed_parts(
# Insert observers into each LSTM cell
# TODO: maybe make this work for layer_bw as well
for layer in quantizable_lstm.layers:
-cell = layer.layer_fw.cell
+cell = layer.layer_fw.cell # type: ignore[union-attr]
+assert isinstance(cell, torch.nn.Module), "cell should be a nn.Module"
cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
# HACK: Manually replace the activation_post_process following these ops.
# This is needed for FloatFunctional ops because there is currently no way
@@ -154,7 +155,7 @@ def _get_lstm_with_individually_observed_parts(
setattr(
cell, activation_post_process_name, activation_post_process_ctr()
)
-layer.layer_fw.cell = cell
+layer.layer_fw.cell = cell # type: ignore[union-attr]
return quantizable_lstm
@@ -216,5 +217,5 @@ def _get_reference_quantized_lstm_module(
node.replace_input_with(arg, arg.args[0])
cell.graph.eliminate_dead_code()
cell.recompile()
-layer.layer_fw.cell = cell
+layer.layer_fw.cell = cell # type: ignore[union-attr]
return quantized_lstm
@@ -172,9 +172,9 @@ class AdaptiveLogSoftmaxWithLoss(Module):

def reset_parameters(self) -> None:
self.head.reset_parameters()
-for i2h, h2o in self.tail:
-i2h.reset_parameters()
-h2o.reset_parameters()
+for i2h, h2o in self.tail: # type: ignore[misc]
+i2h.reset_parameters() # type: ignore[has-type]
+h2o.reset_parameters() # type: ignore[has-type]

def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
targ_dim = target_.dim()
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from __future__ import annotations
@@ -29,6 +28,7 @@ __all__ = [
]

T = TypeVar("T", bound=Module)
+_V = TypeVar("_V")


# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
@@ -121,7 +121,7 @@ class Sequential(Module):
for idx, module in enumerate(args):
self.add_module(str(idx), module)

-def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
+def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V:
"""Get the idx-th item of the iterator."""
size = len(self)
idx = operator.index(idx)
@@ -131,7 +131,7 @@ class Sequential(Module):
return next(islice(iterator, idx, None))

@_copy_to_script_wrapper
-def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, T]:
+def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, Module]:
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
@@ -227,7 +227,7 @@ class Sequential(Module):
return self

@_copy_to_script_wrapper
-def __dir__(self):
+def __dir__(self) -> list[str]:
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
@@ -410,7 +410,7 @@ class ModuleList(Module):
combined.add_module(str(i), module)
return combined

-def __repr__(self):
+def __repr__(self) -> str:
"""Return a custom repr for ModuleList that compresses repeated module representations."""
list_of_reprs = [repr(item) for item in self]
if len(list_of_reprs) == 0:
@@ -443,7 +443,7 @@ class ModuleList(Module):
return main_str

@_copy_to_script_wrapper
-def __dir__(self):
+def __dir__(self) -> list[str]:
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
@@ -580,17 +580,17 @@ class ModuleDict(Module):
return v

@_copy_to_script_wrapper
-def keys(self) -> Iterable[str]:
+def keys(self) -> container_abcs.KeysView[str]:
r"""Return an iterable of the ModuleDict keys."""
return self._modules.keys()

@_copy_to_script_wrapper
-def items(self) -> Iterable[tuple[str, Module]]:
+def items(self) -> container_abcs.ItemsView[str, Module]:
r"""Return an iterable of the ModuleDict key/value pairs."""
return self._modules.items()

@_copy_to_script_wrapper
-def values(self) -> Iterable[Module]:
+def values(self) -> container_abcs.ValuesView[Module]:
r"""Return an iterable of the ModuleDict values."""
return self._modules.values()
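The `keys`/`items`/`values` signatures above (and the ParameterDict one further down) now return the concrete dict-view ABCs instead of bare `Iterable`, which exposes the views' extra capabilities to type checkers. A small standalone illustration on a plain dict, not a ModuleDict:

```python
from collections.abc import ItemsView, KeysView

d: dict[str, int] = {"weight": 1, "bias": 2}

keys: KeysView[str] = d.keys()
items: ItemsView[str, int] = d.items()

# Dict views support membership tests and set algebra, which a plain
# Iterable[str] annotation would hide from the type checker.
print("weight" in keys)          # True
print(keys & {"bias", "gamma"})  # {'bias'}
print(("bias", 2) in items)      # True
```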
@@ -716,7 +716,7 @@ class ParameterList(Module):
def __iadd__(self, parameters: Iterable[Any]) -> Self:
return self.extend(parameters)

-def __dir__(self):
+def __dir__(self) -> list[str]:
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
@@ -930,7 +930,7 @@ class ParameterDict(Module):
"""
return ParameterDict((k, default) for k in keys)

-def keys(self) -> Iterable[str]:
+def keys(self) -> container_abcs.KeysView[str]:
r"""Return an iterable of the ParameterDict keys."""
return self._keys.keys()
@@ -594,9 +594,9 @@ def register_parametrization(

# add the new parametrization to the parametrization list
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
-module.parametrizations[tensor_name].append(parametrization)
+module.parametrizations[tensor_name].append(parametrization) # type: ignore[operator]
# If unsafe was True in previous parametrization, keep it enabled
-module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr]
+module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr, operator]
elif tensor_name in module._buffers or tensor_name in module._parameters:
# Set the parametrization mechanism
# Fetch the original buffer or parameter
@@ -686,6 +686,7 @@ def remove_parametrizations(
parametrizations = module.parametrizations[tensor_name]
if parametrizations.is_tensor:
original = parametrizations.original
+assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
if leave_parametrized:
with torch.no_grad():
t = getattr(module, tensor_name)
@@ -792,7 +793,9 @@ def transfer_parametrizations_and_params(
)

# apply the params's parametrizations to to_module
-for param_func in from_module.parametrizations[parameter_name]:
+for param_func in from_module.parametrizations[ # type: ignore[attr-defined]
+parameter_name
+]:
register_parametrization(to_module, parameter_name, param_func)
assert isinstance(to_module.parametrizations, ModuleDict) # for mypy
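The `assert isinstance(...)` lines in the parametrization hunks above are the usual way to narrow a union-typed value for mypy before attribute access or indexing; the remaining dynamic accesses still need explicit `# type: ignore[...]` comments. A hedged sketch of the narrowing idiom with made-up types, not the PyTorch classes:

```python
from typing import Optional


class Registry:
    def __init__(self) -> None:
        # Starts out absent; populated lazily elsewhere.
        self.entries: Optional[dict[str, int]] = None


def bump(reg: Registry, name: str) -> int:
    # Without the assert, mypy flags reg.entries[name] because the attribute
    # may be None; the isinstance check narrows it to dict[str, int].
    assert isinstance(reg.entries, dict), "entries must be initialized first"
    reg.entries[name] = reg.entries.get(name, 0) + 1
    return reg.entries[name]
```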