[BE]: Update mypy to 1.11.2 (#133816)

Updates mypy to 1.11.2 to improve type inference.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
Parent: 38caf10411
Commit: 31715be72a
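Most of the changes below follow one pattern: the newer mypy pin is stricter about method overrides and class-attribute annotations, so each hunk either tightens an annotation or suppresses a specific error code on the offending line. For context, a minimal sketch (not code from this commit) of the kind of report involved and the targeted suppression used throughout:

# Illustration only -- not part of this commit.
class Base:
    def run(self, x: int) -> int:
        return x


class Child(Base):
    # mypy reports: Signature of "run" incompatible with supertype "Base"  [override]
    # Either align the signature with Base.run or silence just that code:
    def run(self, x: int, y: int) -> int:  # type: ignore[override]
        return x + y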
@@ -90,7 +90,7 @@ librosa>=0.6.2 ; python_version < "3.11"
 #Pinned versions:
 #test that import:
 
-mypy==1.10.0
+mypy==1.11.2
 # Pin MyPy version because new errors are likely to appear with each release
 #Description: linter
 #Pinned versions: 1.10.0

@@ -139,7 +139,7 @@ init_command = [
     'numpy==1.24.3 ; python_version == "3.8"',
     'numpy==1.26.0 ; python_version >= "3.9"',
     'expecttest==0.2.1',
-    'mypy==1.10.0',
+    'mypy==1.11.2',
     'sympy==1.12.1 ; python_version == "3.8"',
     'sympy==1.13.0 ; python_version >= "3.9"',
     'types-requests==2.27.25',

@@ -521,7 +521,7 @@ ACCURACY_FAILS: Dict[str, Callable[[nn.Module, Any], bool]] = {
 def repro_minifier_query(options, mod, load_args):
     mod, args = repro_common(options, mod, load_args)
     fail_fn = functools.partial(
-        ACCURACY_FAILS[options.accuracy], check_str=options.check_str
+        ACCURACY_FAILS[options.accuracy], check_str=options.check_str  # type: ignore[call-arg]
     )
     if fail_fn(mod, args):
         sys.exit(1)

@@ -12,6 +12,7 @@ from typing import Union
 
 import torch
 import torch.fx as fx
+from torch._dynamo.backends.registry import CompiledFn
 from torch._dynamo.debug_utils import (
     AccuracyError,
     backend_accuracy_fails,

@@ -271,8 +272,10 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name):
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
 
-@register_debug_backend
-def dynamo_minifier_backend(gm, example_inputs, compiler_name):
+@register_debug_backend  # type: ignore[arg-type]
+def dynamo_minifier_backend(
+    gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn
+):
     from functorch.compile import minifier
 
     compiler_fn = lookup_backend(compiler_name)

@@ -311,7 +314,7 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name):
     return gm
 
 
-@register_debug_backend
+@register_debug_backend  # type: ignore[arg-type]
 def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name):
     from functorch.compile import minifier
 

@@ -471,14 +471,14 @@ class ContinueExecutionCache:
         code,
         lineno,
         offset: int,
-        setup_fn_target_offsets: Tuple[int],  # only used in Python 3.11+
+        setup_fn_target_offsets: Tuple[int, ...],  # only used in Python 3.11+
         nstack: int,
-        argnames: Tuple[str],
-        argnames_null: Tuple[str],
-        setup_fns: Tuple[ReenterWith],
-        stack_ctx_vars: Tuple[int, Tuple[Any]],
-        argnames_ctx_vars: Tuple[str, Tuple[Any]],
-        null_idxes: Tuple[int],
+        argnames: Tuple[str, ...],
+        argnames_null: Tuple[str, ...],
+        setup_fns: Tuple[ReenterWith, ...],
+        stack_ctx_vars: Tuple[Tuple[int, Tuple[Any]], ...],
+        argnames_ctx_vars: Tuple[Tuple[str, Tuple[Any]], ...],
+        null_idxes: Tuple[int, ...],
     ) -> types.CodeType:
         assert offset is not None
         assert not (
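For context on the Tuple changes above, a minimal sketch (not code from this commit): Tuple[int] describes a tuple with exactly one int, while Tuple[int, ...] describes a homogeneous tuple of any length, which is what these parameters actually receive.

# Illustration only -- not part of this commit.
from typing import Tuple

single: Tuple[int] = (1,)               # exactly one element
variadic: Tuple[int, ...] = (1, 2, 3)   # any number of ints

bad: Tuple[int] = (1, 2, 3)  # mypy: Incompatible types in assignment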
@@ -272,7 +272,7 @@ class EphemeralSource(Source):
     def name(self):
         return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
 
-    def make_guard(self):
+    def make_guard(self, fn):
         raise NotImplementedError
 
     def is_ephemeral(self):

@@ -3198,7 +3198,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
     def run_ctx_mgr(self):
         return TracingContext.current_frame(self.parent.frame_summary())
 
-    def STORE_DEREF(self, inst):
+    def STORE_DEREF(self, inst):  # type: ignore[override]
         if inst.argval in self.closure_cells:
             cell = self.closure_cells[inst.argval]
             val = self.pop()

@@ -2915,6 +2915,8 @@ def is_frozen_dataclass(value):
         not object_has_getattribute(value)
         and not class_has_getattribute(value)
         and is_dataclass(value)
+        and hasattr(value, "__dataclass_params__")
+        and hasattr(value.__dataclass_params__, "frozen")
         and value.__dataclass_params__.frozen
     )
 

@@ -195,7 +195,7 @@ class ConstantFolder(torch.fx.Interpreter):
     def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
         self.node_replacements[node] = tensor
 
-    def run(self):
+    def run(self):  # type: ignore[override]
         env = {}
         for n in self.module.graph.find_nodes(op="placeholder"):
             env[n] = self.unknown_value

@@ -77,7 +77,7 @@ class TransformGetItemToIndex(TorchFunctionMode):
     # scalar and create a view. We do not want that behavior in this case, so we
     # use this torchfunctionmode to override that behavior for score_mod
     # wherever we're running it.
-    def __torch_function__(self, func, types, args, kwargs=None):
+    def __torch_function__(self, func, types, args=(), kwargs=None):
         if func == torch.Tensor.__getitem__:
             index_args = pytree.tree_leaves(args[1])
             if all(isinstance(x, torch.Tensor) for x in index_args):

@@ -1480,7 +1480,8 @@ class CppWrapperCpu(WrapperCodeGen):
                 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef<c10::IValue>());'
             )
 
-    def write_triton_header_once(self):
+    @cache_on_self
+    def write_triton_header_once(self) -> None:
         pass
 
     def generate_start_graph(self):

@@ -30,7 +30,7 @@ class CUDACombinedScheduling(BaseScheduling):
         self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
         self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)
 
-    def get_backend_features(self, device):
+    def get_backend_features(self, device):  # type:ignore[override]
         return self._triton_scheduling.get_backend_features(device)
 
     def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:

@@ -578,22 +578,22 @@ def extract_read_writes(
             repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)}
             name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()}
         for entry in fn.memory_usage[MemoryUsageType.LOAD]:
-            inner.load(entry.buffer_name, name_to_index[entry.index_name])
+            inner.load(entry.buffer_name, name_to_index[entry.index_name])  # type: ignore[arg-type]
         for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]:
-            inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name]))
+            inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name]))  # type: ignore[arg-type]
         for entry in fn.memory_usage[MemoryUsageType.STORE]:
             inner.store(
-                entry.buffer_name, name_to_index[entry.index_name], None, entry.mode
+                entry.buffer_name, name_to_index[entry.index_name], None, entry.mode  # type: ignore[arg-type]
             )
         for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
             inner.store_reduction(
-                entry.buffer_name, name_to_index[entry.index_name], None
+                entry.buffer_name, name_to_index[entry.index_name], None  # type: ignore[arg-type]
             )
         for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
             inner.index_expr(name_to_index[entry.index_name], None)
         for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]:
             inner.bucketize(
-                None, entry.buffer_name, name_to_index[entry.index_name], None, None
+                None, entry.buffer_name, name_to_index[entry.index_name], None, None  # type: ignore[arg-type]
             )
         # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
     else:

@@ -2,7 +2,7 @@
 import functools
 import itertools
 import logging
-from typing import cast, List, Tuple
+from typing import cast, Sequence, Tuple
 
 import sympy
 

@@ -28,7 +28,7 @@ def filtered_configs(
     m: int,
     n: int,
     k: int,
-    configs: List[Tuple[int, int, int, int, int]],
+    configs: Sequence[Tuple[int, int, int, int, int]],
     has_int8_tensor=False,
 ):
     """Heuristic to shrink configs when they are bigger than the input size"""
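For context on the List-to-Sequence change above, a minimal sketch (not code from this commit): Sequence is a read-only, covariant view, so a parameter typed Sequence[...] accepts tuples as well as lists, while List[...] accepts only lists.

# Illustration only -- not part of this commit.
from typing import List, Sequence, Tuple

def sum_list(configs: List[Tuple[int, int]]) -> int:
    return sum(a * b for a, b in configs)

def sum_seq(configs: Sequence[Tuple[int, int]]) -> int:
    return sum(a * b for a, b in configs)

pairs = ((2, 3), (4, 5))
sum_seq(pairs)   # accepted: a tuple is a Sequence
sum_list(pairs)  # rejected by mypy: a tuple is not a List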
@@ -101,8 +101,8 @@ class RemoteCache(Generic[_T]):
         self._put(key, value, sample)
         self._log_sample(sample)
 
-    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:
-        return self.serde.decode(data)
+    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:  # type: ignore[override]
+        return self.serde.decode(data)  # type: ignore[arg-type]
 
     def _encode(self, value: _T, sample: Optional[Sample]) -> Any:  # returns _U
         return self.serde.encode(value)

@@ -827,7 +827,7 @@ class CachingAutotuner(KernelInterface):
         )
         return config2launcher.get(best_config)
 
-    def run(self, *args, grid, stream, **kwargs):
+    def run(self, *args, grid, stream, **kwargs):  # type:ignore[override]
         if len(self.launchers) != 1:
             if len(self.launchers) == 0:
                 start_time = time.time_ns()

@@ -958,7 +958,7 @@ class DebugAutotuner(CachingAutotuner):
         super().__init__(*args, **kwargs)
         self.cached = None
 
-    def run(self, *args, grid, stream):
+    def run(self, *args, grid, stream):  # type: ignore[override]
         possible_names = _find_names(self)
         kernel_name = f"{max(possible_names, key=len)}"
         if not re.match(self.regex_filter, kernel_name):

@@ -2857,9 +2857,7 @@ class Scheduler:
             return False
 
         # Pick the largest buffer to guide the loop reordering
-        numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[
-            0
-        ]
+        numel, lhs_dep, rhs_dep = max(candidates, key=lambda x: x[0])
 
         if lhs_dep.num_vars != rhs_dep.num_vars:
            # this can happen due to we don't merge loops.
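For context on the Scheduler change above, a minimal sketch (not code from this commit): taking the first element of a descending sort and taking max with the same key select the same candidate, but max is O(n) instead of O(n log n) and avoids the wrapped subscript.

# Illustration only -- not part of this commit.
candidates = [(16, "lhs_a", "rhs_a"), (64, "lhs_b", "rhs_b"), (4, "lhs_c", "rhs_c")]

via_sorted = sorted(candidates, reverse=True, key=lambda x: x[0])[0]
via_max = max(candidates, key=lambda x: x[0])
assert via_sorted == via_max  # both pick the 64-numel entry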
@@ -268,4 +268,4 @@ def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.L
     elif len(type_args) == 2 and type_args[1] is Ellipsis:  # type: ignore[valid-type]
         return typing.List[type_args[0]]  # type: ignore[valid-type]
     else:
-        return typing.List[typing.Union[tuple(type_args)]]  # type: ignore[misc]
+        return typing.List[typing.Union[tuple(type_args)]]  # type: ignore[misc, return-value]

@@ -221,7 +221,7 @@ class FunctionalTensor(torch.Tensor):
                 "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()"
             )
 
-    def __repr__(self):
+    def __repr__(self) -> str:  # type: ignore[override]
         return f"FunctionalTensor({repr(self.elem)})"
 
     @staticmethod

@@ -302,7 +302,7 @@ class FunctionalTensor(torch.Tensor):
     long = _conversion_method_template(dtype=torch.int64)
 
     # TODO(sparse-team): fixes #133174 but can we do without the relay?
-    def to_dense(self):
+    def to_dense(self):  # type: ignore[override]
         return self.elem.to_dense()
 
     @property

@@ -1353,7 +1353,7 @@ class Tensor(torch._C.TensorBase):
             [name for name in names if not is_ellipsis(name)], ellipsis_idx
         )
 
-    def unflatten(self, dim, sizes):
+    def unflatten(self, dim, sizes):  # type: ignore[override]
         r"""
         unflatten(dim, sizes) -> Tensor
 

@@ -228,7 +228,7 @@ class ConvAdd2d(_FusedModule):
         super().__init__(conv)
         self.add = add
 
-    def forward(self, x1, x2):
+    def forward(self, x1, x2):  # type: ignore[override]
         return self.add(self[0](x1), x2)
 
 

@@ -241,5 +241,5 @@ class ConvAddReLU2d(_FusedModule):
         self.add = add
         self.relu = relu
 
-    def forward(self, x1, x2):
+    def forward(self, x1, x2):  # type: ignore[override]
         return self.relu(self.add(self[0](x1), x2))

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import math
-from typing import TypeVar
+from typing import ClassVar, Optional, Type
 
 import torch
 import torch.ao.nn.intrinsic as nni

@@ -33,12 +33,9 @@ _BN_CLASS_MAP = {
 }
 
 
-MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
-
-
 class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
     _version = 2
-    _FLOAT_MODULE = MOD
+    _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]]
 
     def __init__(
         self,
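For context on the _FLOAT_MODULE changes above, a minimal sketch (not code from this commit): the old pattern assigned a module-level TypeVar object as the attribute's placeholder value, which the stricter mypy pin no longer accepts; the new pattern declares the attribute as a ClassVar of a concrete type and lets subclasses fill it in.

# Illustration only -- not part of this commit.
from typing import ClassVar, Optional, Type

import torch.nn as nn

class _ConvBnSketch:
    # declared, not assigned: concrete subclasses provide the value
    _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]]
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None

class ConvBn2dSketch(_ConvBnSketch):
    _FLOAT_MODULE = nn.Conv2d  # Type[nn.Conv2d] is a subtype of Type[_ConvNd]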
@@ -365,7 +362,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
         assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
         assert mod.qconfig, "Input float module must have a valid qconfig"
         qconfig = mod.qconfig
-        conv, bn = mod[0], mod[1]
+        conv, bn = mod[0], mod[1]  # type: ignore[index]
         qat_convbn = cls(
             conv.in_channels,
             conv.out_channels,

@@ -434,7 +431,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
         return conv
 
 
-class ConvBn1d(_ConvBnNd, nn.Conv1d):
+class ConvBn1d(_ConvBnNd, nn.Conv1d):  # type: ignore[misc]
     r"""
     A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
     attached with FakeQuantize modules for weight,

@@ -451,10 +448,10 @@ class ConvBn1d(_ConvBnNd, nn.Conv1d):
         weight_fake_quant: fake quant module for weight
 
     """
-    _FLOAT_BN_MODULE = nn.BatchNorm1d
-    _FLOAT_RELU_MODULE: None = None
-    _FLOAT_MODULE = nni.ConvBn1d
-    _FLOAT_CONV_MODULE = nn.Conv1d
+    _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBn1d  # type: ignore[assignment,misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
 
     def __init__(
         self,

@@ -520,12 +517,12 @@ class ConvBnReLU1d(ConvBn1d):
 
     """
     # base class defines _FLOAT_MODULE as "ConvBn1d"
-    _FLOAT_MODULE = nni.ConvBnReLU1d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv1d
-    _FLOAT_BN_MODULE = nn.BatchNorm1d
-    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBnReLU1d  # type: ignore[assignment,misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
+    _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU
     # module class after fusing bn into conv
-    _FUSED_FLOAT_MODULE = nni.ConvReLU1d
+    _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d
 
     def __init__(
         self,

@@ -585,10 +582,10 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
         weight_fake_quant: fake quant module for weight
 
     """
-    _FLOAT_MODULE = nni.ConvReLU1d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv1d
-    _FLOAT_BN_MODULE: None = None
-    _FLOAT_RELU_MODULE = nn.ReLU
+    _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU1d]] = nni.ConvReLU1d  # type: ignore[assignment, misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
+    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU
 
     def __init__(
         self,

@@ -631,7 +628,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
         )
 
 
-class ConvBn2d(_ConvBnNd, nn.Conv2d):
+class ConvBn2d(_ConvBnNd, nn.Conv2d):  # type: ignore[misc]
     r"""
     A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
     attached with FakeQuantize modules for weight,

@@ -648,10 +645,10 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d):
         weight_fake_quant: fake quant module for weight
 
     """
-    _FLOAT_MODULE = nni.ConvBn2d
-    _FLOAT_CONV_MODULE = nn.Conv2d
-    _FLOAT_BN_MODULE = nn.BatchNorm2d
-    _FLOAT_RELU_MODULE: None = None
+    _FLOAT_MODULE: ClassVar[Type[nni.ConvBn2d]] = nni.ConvBn2d  # type: ignore[assignment,misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
+    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm2d
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
 
     def __init__(
         self,

@@ -717,12 +714,12 @@ class ConvBnReLU2d(ConvBn2d):
 
     """
     # base class defines _FLOAT_MODULE as "ConvBn2d"
-    _FLOAT_MODULE = nni.ConvBnReLU2d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv2d
-    _FLOAT_BN_MODULE = nn.BatchNorm2d
-    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d  # type: ignore[assignment, misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
+    _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm2d]] = nn.BatchNorm2d
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU  # type: ignore[assignment,misc]
     # module class after fusing bn into conv
-    _FUSED_FLOAT_MODULE = nni.ConvReLU2d
+    _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU2d]]] = nni.ConvReLU2d
 
     def __init__(
         self,

@@ -782,10 +779,10 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
         weight_fake_quant: fake quant module for weight
 
     """
-    _FLOAT_MODULE = nni.ConvReLU2d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv2d
-    _FLOAT_BN_MODULE: None = None
-    _FLOAT_RELU_MODULE = nn.ReLU
+    _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvReLU2d  # type: ignore[assignment, misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
+    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU
 
     def __init__(
         self,

@@ -828,7 +825,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
         )
 
 
-class ConvBn3d(_ConvBnNd, nn.Conv3d):
+class ConvBn3d(_ConvBnNd, nn.Conv3d):  # type: ignore[misc]
     r"""
     A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
     attached with FakeQuantize modules for weight,

@@ -845,10 +842,10 @@ class ConvBn3d(_ConvBnNd, nn.Conv3d):
         weight_fake_quant: fake quant module for weight
 
     """
-    _FLOAT_MODULE = nni.ConvBn3d
-    _FLOAT_CONV_MODULE = nn.Conv3d
-    _FLOAT_BN_MODULE = nn.BatchNorm3d
-    _FLOAT_RELU_MODULE: None = None
+    _FLOAT_MODULE: ClassVar[Type[nni.ConvBn3d]] = nni.ConvBn3d  # type: ignore[assignment,misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
+    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm3d
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
 
     def __init__(
         self,

@@ -913,12 +910,12 @@ class ConvBnReLU3d(ConvBn3d):
         weight_fake_quant: fake quant module for weight
 
     """
-    _FLOAT_MODULE = nni.ConvBnReLU3d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv3d
-    _FLOAT_BN_MODULE = nn.BatchNorm3d
-    _FLOAT_RELU_MODULE = nn.ReLU  # type: ignore[assignment]
+    _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d  # type: ignore[assignment, misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
+    _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm3d]] = nn.BatchNorm3d
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.ReLU]]] = nn.ReLU  # type: ignore[assignment, misc]
     # module class after fusing bn into conv
-    _FUSED_FLOAT_MODULE = nni.ConvReLU3d
+    _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU3d]]] = nni.ConvReLU3d
 
     def __init__(
         self,

@@ -980,10 +977,10 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
         weight_fake_quant: fake quant module for weight
 
     """
-    _FLOAT_MODULE = nni.ConvReLU3d  # type: ignore[assignment]
-    _FLOAT_CONV_MODULE = nn.Conv3d
-    _FLOAT_BN_MODULE: None = None
-    _FLOAT_RELU_MODULE = nn.ReLU
+    _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU3d]] = nni.ConvReLU3d  # type: ignore[assignment,misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
+    _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU
 
     def __init__(
         self,
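For context on the Optional class attributes above, a minimal sketch (not code from this commit): once an attribute is typed ClassVar[Optional[Type[nn.Module]]], mypy requires a None check before the stored class can be instantiated.

# Illustration only -- not part of this commit.
from typing import ClassVar, Optional, Type

import torch.nn as nn

class FusedSketch:
    _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU

    @classmethod
    def make_relu(cls) -> Optional[nn.Module]:
        if cls._FLOAT_RELU_MODULE is None:
            return None
        return cls._FLOAT_RELU_MODULE()  # narrowed to Type[nn.Module] here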
@@ -49,7 +49,7 @@ class ConvAdd2d(nnq.Conv2d):
             dtype=dtype,
         )
 
-    def forward(self, input, extra_input):
+    def forward(self, input, extra_input):  # type: ignore[override]
         # Temporarily using len(shape) instead of ndim due to JIT issue
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 4:

@@ -117,7 +117,7 @@ class ConvAddReLU2d(nnq.Conv2d):
             dtype=dtype,
         )
 
-    def forward(self, input, extra_input):
+    def forward(self, input, extra_input):  # type: ignore[override]
         # Temporarily using len(shape) instead of ndim due to JIT issue
         # https://github.com/pytorch/pytorch/issues/23890
         if len(input.shape) != 4:

@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Tuple, TypeVar, Union
+from typing import ClassVar, Tuple, Type, Union
 
 import torch
 import torch.nn as nn

@@ -10,11 +10,9 @@ from torch.nn.modules.utils import _pair, _single, _triple
 
 __all__ = ["Conv1d", "Conv2d", "Conv3d"]
 
-MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
-
 
 class _ConvNd(nn.modules.conv._ConvNd):
-    _FLOAT_MODULE = MOD
+    _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]]
 
     def __init__(
         self,

@@ -136,8 +134,8 @@ class Conv1d(_ConvNd, nn.Conv1d):
     Attributes:
         weight_fake_quant: fake quant module for weight
     """
-    _FLOAT_MODULE = nn.Conv1d
-    _FLOAT_CONV_MODULE = nn.Conv1d
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d  # type: ignore[assignment,misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
 
     def __init__(
         self,

@@ -197,8 +195,8 @@ class Conv2d(_ConvNd, nn.Conv2d):
     Attributes:
         weight_fake_quant: fake quant module for weight
     """
-    _FLOAT_MODULE = nn.Conv2d
-    _FLOAT_CONV_MODULE = nn.Conv2d
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d  # type: ignore[assignment,misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
 
     def __init__(
         self,

@@ -261,8 +259,8 @@ class Conv3d(_ConvNd, nn.Conv3d):
     Attributes:
         weight_fake_quant: fake quant module for weight
     """
-    _FLOAT_MODULE = nn.Conv3d
-    _FLOAT_CONV_MODULE = nn.Conv3d
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d  # type: ignore[assignment,misc]
+    _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
 
     def __init__(
         self,
@@ -2,6 +2,7 @@
 r"""Dynamically quantized convolution modules."""
 
 import warnings
+from typing import ClassVar, Optional, Type
 
 import torch
 import torch.ao.nn.quantized as nnq

@@ -47,9 +48,9 @@ class Conv1d(nnq.Conv1d):
 
     """
 
-    _FLOAT_MODULE = nn.Conv1d
-    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
-    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
+    _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
 
     def __init__(
         self,

@@ -132,9 +133,9 @@ class Conv2d(nnq.Conv2d):
         >>> output = m(input)
 
     """
-    _FLOAT_MODULE = nn.Conv2d
-    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
-    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
+    _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
 
     def __init__(
         self,

@@ -216,9 +217,9 @@ class Conv3d(nnq.Conv3d):
         >>> output = m(input)
 
     """
-    _FLOAT_MODULE = nn.Conv3d
-    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
-    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
+    _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
 
     def __init__(
         self,

@@ -308,7 +309,7 @@ class ConvTranspose1d(nnq.ConvTranspose1d):
         torch.Size([1, 16, 12])
     """
 
-    _FLOAT_MODULE = nn.ConvTranspose1d
+    _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d
 
     def __init__(
         self,

@@ -390,7 +391,7 @@ class ConvTranspose2d(nnq.ConvTranspose2d):
         torch.Size([1, 16, 12, 12])
     """
 
-    _FLOAT_MODULE = nn.ConvTranspose2d
+    _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d
 
     def __init__(
         self,

@@ -472,7 +473,7 @@ class ConvTranspose3d(nnq.ConvTranspose3d):
         torch.Size([1, 16, 12, 12, 12])
     """
 
-    _FLOAT_MODULE = nn.ConvTranspose3d
+    _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d
 
     def __init__(
         self,
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 r"""Quantized convolution modules."""
 
-from typing import List, Optional, TypeVar
+from typing import ClassVar, List, Optional, Type
 
 import torch
 import torch.ao.nn.intrinsic as nni

@@ -386,11 +386,11 @@ class Conv1d(_ConvNd):
 
     """
 
-    _FLOAT_MODULE = nn.Conv1d
-    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
-    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
-    _NNI_CONV_ADD_MODULE: None = None
-    _NNI_CONV_ADD_RELU_MODULE: None = None
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d
+    _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn1d
+    _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d
+    _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
 
     def __init__(
         self,

@@ -518,11 +518,11 @@ class Conv2d(_ConvNd):
         >>> output = m(q_input)
 
     """
-    _FLOAT_MODULE = nn.Conv2d
-    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
-    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
-    _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
-    _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d
+    _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn2d
+    _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU2d
+    _NNI_CONV_ADD_MODULE: ClassVar[Type[nni.ConvAdd2d]] = nni.ConvAdd2d
+    _NNI_CONV_ADD_RELU_MODULE: ClassVar[Type[nni.ConvAddReLU2d]] = nni.ConvAddReLU2d
 
     def __init__(
         self,

@@ -647,11 +647,11 @@ class Conv3d(_ConvNd):
         >>> output = m(q_input)
 
     """
-    _FLOAT_MODULE = nn.Conv3d
-    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
-    _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
-    _NNI_CONV_ADD_MODULE: None = None
-    _NNI_CONV_ADD_RELU_MODULE: None = None
+    _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d
+    _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn3d
+    _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU3d
+    _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
+    _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None
 
     def __init__(
         self,

@@ -740,11 +740,10 @@ class Conv3d(_ConvNd):
 
 
 # === Transposed Convolutions ===
-MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
 
 
 class _ConvTransposeNd(_ConvNd):
-    _FLOAT_MODULE = MOD
+    _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]]
 
     def __init__(
         self,

@@ -914,7 +913,7 @@ class ConvTranspose1d(_ConvTransposeNd):
         torch.Size([1, 16, 12])
     """
 
-    _FLOAT_MODULE = nn.ConvTranspose1d
+    _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d
 
     def __init__(
         self,

@@ -1037,7 +1036,7 @@ class ConvTranspose2d(_ConvTransposeNd):
         torch.Size([1, 16, 12, 12])
     """
 
-    _FLOAT_MODULE = nn.ConvTranspose2d
+    _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d
 
     def __init__(
         self,

@@ -1162,7 +1161,7 @@ class ConvTranspose3d(_ConvTransposeNd):
         torch.Size([1, 16, 12, 12, 12])
     """
 
-    _FLOAT_MODULE = nn.ConvTranspose3d
+    _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d
 
     def __init__(
         self,
@@ -198,7 +198,7 @@ class ShadowLogger(Logger):
         self.stats["float"] = []
         self.stats["quantized"] = []
 
-    def forward(self, x, y):
+    def forward(self, x, y):  # type: ignore[override]
         # fmt: off
         """
         """  # blank docblock to make autodoc happy

@@ -261,7 +261,7 @@ class OutputComparisonLogger(OutputLogger):
         self.comparisons = []
         # precalculated comparisons function
 
-    def forward(self, x, x_ref):
+    def forward(self, x, x_ref):  # type: ignore[override]
         # fmt: off
         """
         """  # blank docblock to make autodoc happy

@@ -410,7 +410,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
     # Add function swaps from default lowering path
     #
 
-    for source, (
+    for source, (  # type:ignore[assignment]
         target1,
         target2,
     ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():

@@ -422,7 +422,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
         _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
         quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
     ):
-        for source, target in source_to_target.items():
+        for source, target in source_to_target.items():  # type:ignore[assignment]
             new_connections.append((source, target))
 
     #

@@ -432,7 +432,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
     for source_to_target in (
         quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
     ):
-        for source, target in source_to_target.items():
+        for source, target in source_to_target.items():  # type:ignore[assignment]
             new_connections.append((source, target))
 
     # add the new connections from backend_config

@@ -71,7 +71,7 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
         # add data with default config here
         [self.add_data(name, data, **self.defaults) for name, data in data_list]
 
-    def prepare(self):
+    def prepare(self, model, config):
         raise NotImplementedError("this function is undefined for this class")
 
     def _extract_weight(self, data):

@@ -266,7 +266,7 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
             "_container": self._container.state_dict(),
         }
 
-    def __repr__(self):
+    def __repr__(self):  # type:ignore[override]
         format_string = self.__class__.__name__ + " ("
         for name, sparse_args in self.data_groups.items():
             format_string += "\n"

@@ -299,7 +299,7 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
             self._container, name, leave_parametrized=leave_parametrized
         )
 
-    def step(self):
+    def step(self):  # type:ignore[override]
         if not self.enable_mask_update:
             return
         with torch.no_grad():

@@ -156,7 +156,7 @@ class DataNormSparsifier(BaseDataSparsifier):
         ]  # squeeze only the first 2 dimension
         return mask
 
-    def update_mask(
+    def update_mask(  # type: ignore[override]
         self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs
     ):
         values_per_block = reduce(operator.mul, sparse_block_shape)
@@ -75,7 +75,9 @@ class FPGMPruner(BaseStructuredSparsifier):
 
         return distance
 
-    def update_mask(self, module, tensor_name, sparsity_level, **kwargs):
+    def update_mask(  # type: ignore[override]
+        self, module, tensor_name, sparsity_level, **kwargs
+    ):
         tensor_weight = getattr(module, tensor_name)
         mask = getattr(module.parametrizations, tensor_name)[0].mask
 

@@ -33,7 +33,9 @@ class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier):
         defaults = {"nearliness": nearliness}
         super().__init__(defaults=defaults)
 
-    def update_mask(self, module, tensor_name, nearliness, **kwargs):
+    def update_mask(  # type:ignore[override]
+        self, module, tensor_name, nearliness, **kwargs
+    ):
         mask = getattr(module.parametrizations, tensor_name)[0].mask
         mask.data = torch.zeros_like(mask)
         if nearliness <= 0:

@@ -208,7 +208,7 @@ class WeightNormSparsifier(BaseSparsifier):
         mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous()
         return mask
 
-    def update_mask(
+    def update_mask(  # type: ignore[call-override, override]
         self,
         module,
         tensor_name,

@@ -216,5 +216,5 @@ class _DerivedObserverOrFakeQuantize(ObserverBase):
     def forward(self, x: Tensor) -> Tensor:
         return x
 
-    def calculate_qparams(self):
+    def calculate_qparams(self):  # type:ignore[override]
         return self.derive_qparams_fn(self.obs_or_fqs)

@@ -61,7 +61,7 @@ class MeanShadowLogger(ns.Logger):
         self.float_sum = None
         self.quant_sum = None
 
-    def forward(self, x, y):
+    def forward(self, x, y):  # type: ignore[override]
         """Compute the average of quantized and floating-point data from modules.
 
         The inputs x,y are output data from the quantized and floating-point modules.

@@ -34,7 +34,7 @@ class APoTObserver(ObserverBase):
 
     # min_val and max_val are optional args to override
     # the min_val and max_val observed by forward
-    def calculate_qparams(self, signed):
+    def calculate_qparams(self, signed):  # type:ignore[override]
         return self._calculate_qparams(signed, self.min_val, self.max_val)
 
     r""" Calculates nonuniform quantization parameters according to APoT paper:

@@ -225,7 +225,7 @@ def _map_module_function_to_aten_operator_type():
         ),
     )
     for map_item in map_list:
-        module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[call-overload]
+        module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[arg-type, call-overload]
     return module_function_to_aten_operator
 
 

@@ -600,7 +600,7 @@ class AsyncCollectiveTensor(torch.Tensor):
         elem = inner_tensors["elem"]
         return AsyncCollectiveTensor(elem)
 
-    def __repr__(self):
+    def __repr__(self) -> str:  # type: ignore[override]
         return f"AsyncCollectiveTensor({self.trigger_wait()})"
 
     def trigger_wait(self):

@@ -653,7 +653,7 @@ class AsyncCollectiveTensor(torch.Tensor):
 
         return out
 
-    def numpy(self):
+    def numpy(self):  # type: ignore[override]
         return self.wait().numpy()
 
 

@@ -1200,7 +1200,7 @@ class ShardedTensor(ShardedTensorBase):
     def __hash__(self):
         return id(self)
 
-    def __repr__(self):
+    def __repr__(self) -> str:  # type: ignore[override]
         return f"ShardedTensor({self._metadata})"
 
     @dataclass

@@ -136,7 +136,7 @@ def _low_precision_hook(
     prec: torch.dtype,
     state: LowPrecisionState,
     grad: torch.Tensor,
-    output: torch.Tensor,
+    output: Optional[torch.Tensor],
 ):
     if grad.dtype != prec:
         grad.data = grad.data.to(prec)

@@ -96,7 +96,7 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         )
         self.averager.step = 0
 
-    def step(self):
+    def step(self):  # type: ignore[override]
         r"""
         Performs a single optimization step (parameter update).
         """

@@ -364,7 +364,7 @@ class DetachExecutor(fx.Interpreter):
         super().__init__(module, garbage_collect_values)
         self.value_remap = {}
 
-    def run(self, *args, initial_env=None):
+    def run(self, *args, initial_env=None):  # type: ignore[override]
         self.value_remap = {}
         return super().run(*args, initial_env=initial_env)
 

@@ -932,8 +932,7 @@ class Pipe(torch.nn.Module):
                 if node.op == "get_attr":
                     # get_attr might get access deeper level attribute
                     fqn = scope + "." + node.target if scope else node.target
-                    if fqn in unused_attributes:  # used, remove it
-                        unused_attributes.remove(fqn)
+                    unused_attributes.discard(fqn)
             for _name, _submod in _mod.named_children():
                 stack.append((scope + "." + _name if scope else _name, _submod))
         # delete unused attributes
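For context on the unused_attributes change above, a minimal sketch (not code from this commit): set.discard is the no-error variant of set.remove, so the membership guard is no longer needed.

# Illustration only -- not part of this commit.
attrs = {"weight", "bias"}

attrs.discard("running_mean")      # absent key: silently a no-op
try:
    attrs.remove("running_mean")   # absent key: raises KeyError
except KeyError:
    pass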
@@ -283,7 +283,7 @@ class DTensor(torch.Tensor):
 
     # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
     # pyre-fixme[3]: Return type must be annotated.
-    def __repr__(self):
+    def __repr__(self):  # type: ignore[override]
         # TODO: consider all_gather the local tensors for better debugging
         return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"
 

@@ -309,7 +309,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
 
     # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently.
     # pyre-fixme[3]: Return type must be annotated.
-    def __repr__(self):
+    def __repr__(self) -> str:  # type: ignore[override]
         return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"
 
     def __str__(self) -> str:

@@ -39,6 +39,9 @@ def input_reshard(
     Return:
         A :class:`nn.Module` object registered with TP input resharding.
     """
+    if input_reshard_dim is None:
+        return module
+
     cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None
 
     def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None:

@@ -56,8 +59,6 @@ def input_reshard(
         nonlocal cx
         cx.__exit__()  # type: ignore[name-defined, union-attr]
 
-    if input_reshard_dim is None:
-        return module
     module.register_forward_pre_hook(input_reshard_forward_pre_hook)
     module.register_forward_hook(input_reshard_backward_hook)
     return module

@@ -204,7 +204,7 @@ class _DependentProperty(property, _Dependent):
         self._is_discrete = is_discrete
         self._event_dim = event_dim
 
-    def __call__(self, fn):
+    def __call__(self, fn):  # type: ignore[override]
         """
         Support for syntax to customize static attributes::
 

@@ -177,7 +177,7 @@ class VonMises(Distribution):
             self._loc, self._concentration, self._proposal_r, x
         ).to(self.loc.dtype)
 
-    def expand(self, batch_shape):
+    def expand(self, batch_shape, _instance=None):
         try:
             return super().expand(batch_shape)
         except NotImplementedError:

@@ -258,7 +258,7 @@ class MaskedTensor(torch.Tensor):
         self._masked_mask = mask
         self._validate_members()
 
-    def __repr__(self):
+    def __repr__(self):  # type: ignore[override]
         formatter = "{0:8.4f}"
         if self.dim() == 0:
             scalar_data = self.get_data().item()

@@ -350,7 +350,7 @@ class MaskedTensor(torch.Tensor):
     def is_sparse_coo(self):
         return self.layout == torch.sparse_coo
 
-    def is_sparse_csr(self):
+    def is_sparse_csr(self):  # type: ignore[override]
         return self.layout == torch.sparse_csr
 
     # Update later to support more sparse layouts

@@ -74,7 +74,7 @@ class SharedCache(dict):
     def _after_fork(self):
         self.lock = threading.Lock()
 
-    def get(self, key):
+    def get(self, key):  # type: ignore[override]
         with self.lock:
             return dict.get(self, key)
 

@@ -209,7 +209,7 @@ class NestedTensor(torch.Tensor):
     def _min_seqlen(self):
         return self._get_min_seqlen()
 
-    def __repr__(self):
+    def __repr__(self):  # type: ignore[override]
         # We should implement this in torch/_tensor_str.py instead
         grad_fn_str = (
             f", requires_grad={self.requires_grad}" if self.requires_grad else ""

@@ -289,7 +289,7 @@ class CausalBias(torch.Tensor):
         )
         return cls._dispatch(*args, **kwargs)
 
-    def __repr__(self):
+    def __repr__(self):  # type:ignore[override]
         return self._materialize().__repr__()
 
 
@@ -910,7 +910,7 @@ class SequentialLR(LRScheduler):
 
         self._last_lr = schedulers[0].get_last_lr()
 
-    def step(self):
+    def step(self):  # type: ignore[override]
         """Perform a step."""
         self.last_epoch += 1
         idx = bisect_right(self._milestones, self.last_epoch)

@@ -1179,7 +1179,7 @@ class ChainedScheduler(LRScheduler):
             group["lr"] for group in self._schedulers[-1].optimizer.param_groups
         ]
 
-    def step(self):
+    def step(self):  # type: ignore[override]
         """Perform a step."""
         for scheduler in self._schedulers:
             scheduler.step()

@@ -295,7 +295,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
         else:
             return dense_input
 
-    def to_dense(self):
+    def to_dense(self):  # type:ignore[override]
         col = self.shape[-1]
         return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
 

@@ -420,7 +420,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
             requires_grad=original_tensor.requires_grad,
         )
 
-    def to_dense(self):
+    def to_dense(self):  # type: ignore[override]
         assert self.meta is not None and self.packed is not None
         return (
             sparse_semi_structured_to_dense_cutlass(

@@ -310,11 +310,11 @@ class ModularIndexing(sympy.Function):
         if isinstance(base, FloorDiv):
             return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
 
-    def _eval_is_nonnegative(self):
+    def _eval_is_nonnegative(self):  # type:ignore[override]
         p, q = self.args[:2]
         return fuzzy_eq(p.is_nonnegative, q.is_nonnegative)  # type: ignore[attr-defined]
 
-    def _eval_is_positive(self):
+    def _eval_is_positive(self):  # type:ignore[override]
         p, q = self.args[:2]
         return fuzzy_eq(p.is_positive, q.is_positive)  # type: ignore[attr-defined]
 

@@ -329,14 +329,14 @@ class Where(sympy.Function):
     def _eval_is_integer(self):
         return True if self.args[1].is_integer and self.args[2].is_integer else None  # type: ignore[attr-defined]
 
-    def _eval_is_nonnegative(self):
+    def _eval_is_nonnegative(self):  # type:ignore[override]
         return (
             True
             if self.args[1].is_nonnegative and self.args[2].is_nonnegative  # type: ignore[attr-defined]
             else None
         )
 
-    def _eval_is_positive(self):
+    def _eval_is_positive(self):  # type:ignore[override]
         return True if self.args[1].is_positive and self.args[2].is_positive else None  # type: ignore[attr-defined]
 
     @classmethod

@@ -397,7 +397,7 @@ class PythonMod(sympy.Function):
            return S.Zero
 
     # NB: args[1] for PythonMod
-    def _eval_is_nonnegative(self):
+    def _eval_is_nonnegative(self):  # type:ignore[override]
         return True if self.args[1].is_positive else None  # type: ignore[attr-defined]
 
     def _eval_is_nonpositive(self):

@@ -823,13 +823,13 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
     zero = S.Infinity
     identity = S.NegativeInfinity
 
-    def _eval_is_positive(self):
+    def _eval_is_positive(self):  # type:ignore[override]
         return fuzzy_or(a.is_positive for a in self.args)  # type: ignore[attr-defined]
 
-    def _eval_is_nonnegative(self):
+    def _eval_is_nonnegative(self):  # type:ignore[override]
         return fuzzy_or(a.is_nonnegative for a in self.args)  # type: ignore[attr-defined]
 
-    def _eval_is_negative(self):
+    def _eval_is_negative(self):  # type:ignore[override]
         return fuzzy_and(a.is_negative for a in self.args)
 
 

@@ -841,13 +841,13 @@ class Min(MinMaxBase, Application): # type: ignore[misc]
     zero = S.NegativeInfinity
     identity = S.Infinity
 
-    def _eval_is_positive(self):
+    def _eval_is_positive(self):  # type:ignore[override]
         return fuzzy_and(a.is_positive for a in self.args)  # type: ignore[attr-defined]
 
-    def _eval_is_nonnegative(self):
+    def _eval_is_nonnegative(self):  # type:ignore[override]
         return fuzzy_and(a.is_nonnegative for a in self.args)  # type: ignore[attr-defined]
 
-    def _eval_is_negative(self):
+    def _eval_is_negative(self):  # type:ignore[override]
         return fuzzy_or(a.is_negative for a in self.args)
 
 

@@ -263,7 +263,7 @@ class WeakIdKeyDictionary(MutableMapping):
     def setdefault(self, key, default=None):
         return self.data.setdefault(self.ref_type(key, self._remove), default)  # CHANGED
 
-    def update(self, dict=None, **kwargs):
+    def update(self, dict=None, **kwargs):  # type: ignore[override]
         d = self.data
         if dict is not None:
             if not hasattr(dict, "items"):