Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
mypy 1.16.0 (#155821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821 Approved by: https://github.com/ezyang, https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: ce79056471
Commit: e95e8eed0a
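Nearly every hunk below follows the same pattern: mypy 1.16 flags lines that 1.15 accepted (mostly [override], [assignment], [arg-type], and [union-attr] errors), and the change silences each report with a targeted # type: ignore[...] comment on the offending line instead of restructuring the code. As a minimal sketch of the recurring [override] case, using hypothetical class names that do not appear in this diff:

    class Quantizer:
        def calculate_qparams(self) -> tuple[float, int]:
            return 1.0, 0


    class NoopQuantizer(Quantizer):
        # The return type differs from the base method, so mypy reports an
        # incompatible override; the scoped comment suppresses only that
        # error code on this line, mirroring the suppressions below.
        def calculate_qparams(self) -> None:  # type: ignore[override]
            raise NotImplementedError

Keeping the suppressions per-line makes the upgrade mechanical, which is why the pinned mypy version in the requirements and linter init_command hunks can move from 1.15.0 to 1.16.0 without a wider refactor.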
@@ -90,10 +90,10 @@ librosa>=0.6.2 ; python_version < "3.11"
 #Pinned versions:
 #test that import:
 
-mypy==1.15.0
+mypy==1.16.0
 # Pin MyPy version because new errors are likely to appear with each release
 #Description: linter
-#Pinned versions: 1.14.0
+#Pinned versions: 1.16.0
 #test that import: test_typing.py, test_type_hints.py
 
 networkx==2.8.8
@@ -154,7 +154,7 @@ init_command = [
 'numpy==1.26.4 ; python_version >= "3.9" and python_version <= "3.11"',
 'numpy==2.1.0 ; python_version >= "3.12"',
 'expecttest==0.3.0',
-'mypy==1.15.0',
+'mypy==1.16.0',
 'sympy==1.13.3',
 'types-requests==2.27.25',
 'types-PyYAML==6.0.7',
@@ -42,7 +42,7 @@ def extract(step: Step) -> Script | None:
 "bash": f"#!/usr/bin/env bash\nset -eo pipefail\n{run}",
 "sh": f"#!/usr/bin/env sh\nset -e\n{run}",
 }.get(shell, run)
-return {"extension": extension, "script": script}
+return {"extension": extension, "script": script} # type: ignore[typeddict-item]
 elif is_gh_script and gh_script is not None:
 return {"extension": ".js", "script": gh_script}
 else:
@@ -199,12 +199,12 @@ class DeviceGuard:
 
 
 class CudaInterface(DeviceInterface):
-device = torch.cuda.device
+device = torch.cuda.device # type: ignore[assignment]
 
 # register Event and Stream class into the backend interface
 # make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream
-Event = torch.cuda.Event
-Stream = torch.cuda.Stream
+Event = torch.cuda.Event # type: ignore[assignment]
+Stream = torch.cuda.Stream # type: ignore[assignment]
 
 class Worker:
 @staticmethod
@@ -297,9 +297,9 @@ else:
 
 
 class XpuInterface(DeviceInterface):
-device = torch.xpu.device
-Event = torch.xpu.Event
-Stream = torch.xpu.Stream
+device = torch.xpu.device # type: ignore[assignment]
+Event = torch.xpu.Event # type: ignore[assignment]
+Stream = torch.xpu.Stream # type: ignore[assignment]
 
 class Worker:
 @staticmethod
@@ -1113,7 +1113,7 @@ class OutputGraph(OutputGraphGuardsState):
 
 # A small codegen optimization because we might have different
 # VariableTrackers that share the same source.
-list_idx = x.source.index
+list_idx = x.source.index # type: ignore[attr-defined]
 if list_idx not in visited:
 alias_name = self.new_var(
 f"{list_name}_ref"
@@ -663,7 +663,7 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
 trace_structured_artifact(
 f"get_{ty}_code_state",
 "string",
-lambda: render_code_state(_CODE_STATE),
+lambda: render_code_state(_CODE_STATE), # type: ignore[arg-type]
 )
 set_feature_use("pgo", True)
 _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE)
@@ -238,7 +238,7 @@ def write_view_information_to_args(
 write_single_view(
 f"_{arg_name}",
 kwargs[arg_name],
-arg_to_base_index.get(arg_name, None),
+arg_to_base_index.get(arg_name, None), # type: ignore[arg-type]
 )
 else:
 raise RuntimeError(f"Unsupported type {arg_type}")
@@ -389,7 +389,7 @@ class AutoFunctionalizedV2(HigherOrderOperator):
 if isinstance(_mutable_op, HigherOrderOperator):
 _op_to_check = HopInstance(
 _mutable_op,
-SchemaHolder.from_tree_spec(kwargs.get("_op_schema", None)).schema,
+SchemaHolder.from_tree_spec(kwargs.get("_op_schema", None)).schema, # type: ignore[arg-type]
 )
 else:
 _op_to_check = _mutable_op
@@ -948,7 +948,7 @@ def auto_functionalized_v2_proxy(
 if _only_clone_these_bases is None:
 _only_clone_these_bases = tuple(range(len(all_bases)))
 
-schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema
+schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema # type: ignore[arg-type]
 new_kwargs, _ = _generate_new_op_kwargs_from_bases(
 schema,
 {k: v for k, v in kwargs.items() if k not in ("_all_bases", "_op_schema")},
@@ -4986,7 +4986,7 @@ class CppScheduling(BaseScheduling):
 layout=local_buffer_layout,
 )
 local_buffers.append(local_buffer_used)
-local_to_global_buffers[local_buffer_used.name] = []
+local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index]
 local_to_global_buffers[local_buffer_used.name].append(
 global_buffer,
 )
@@ -2742,13 +2742,19 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
 assert reduction_type == "welford_reduce"
 result_mean, result_m2, result_weight = result_var
 peer_mean = self.codegen_cooperative_reduction_peer_combine(
-result_mean, upcast_acc_dtype(src_dtype), default[0]
+result_mean,
+upcast_acc_dtype(src_dtype),
+default[0], # type: ignore[index]
 )
 peer_m2 = self.codegen_cooperative_reduction_peer_combine(
-result_m2, upcast_acc_dtype(src_dtype), default[1]
+result_m2,
+upcast_acc_dtype(src_dtype),
+default[1], # type: ignore[index]
 )
 peer_weight = self.codegen_cooperative_reduction_peer_combine(
-result_weight, upcast_acc_dtype(src_dtype), default[2]
+result_weight,
+upcast_acc_dtype(src_dtype),
+default[2], # type: ignore[index]
 )
 self.welford_reduce_final_reduction(
 self.post_loop_store,
@@ -1650,8 +1650,8 @@ def cudagraphify(
 nonlocal compiled_fn
 if compiled_fn is None:
 with dynamo_utils.preserve_rng_state():
-compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
-return compiled_fn(new_inputs)
+compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) # type: ignore[arg-type]
+return compiled_fn(new_inputs) # type: ignore[arg-type]
 
 return run
 
@@ -169,7 +169,7 @@ class VecNEON(VecISA):
 return "neon"
 return "asimd" # detects the presence of advanced SIMD on armv8-a kernels
 
-__hash__: Callable[[VecISA], Any] = VecISA.__hash__
+__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
 
 
 @dataclasses.dataclass
@@ -191,7 +191,7 @@ class VecSVE256(VecISA):
 return "neon"
 return "asimd"
 
-__hash__: Callable[[VecISA], Any] = VecISA.__hash__
+__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
 
 
 @dataclasses.dataclass
@@ -208,7 +208,7 @@ class VecAVX512(VecISA):
 def __str__(self) -> str:
 return "avx512"
 
-__hash__: Callable[[VecISA], Any] = VecISA.__hash__
+__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
 
 
 @dataclasses.dataclass
@@ -263,7 +263,7 @@ class VecAVX2(VecISA):
 def __str__(self) -> str:
 return "avx2"
 
-__hash__: Callable[[VecISA], Any] = VecISA.__hash__
+__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
 
 
 @dataclasses.dataclass
@@ -280,7 +280,7 @@ class VecZVECTOR(VecISA):
 def __str__(self) -> str:
 return "zvector"
 
-__hash__: Callable[[VecISA], Any] = VecISA.__hash__
+__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
 
 
 @dataclasses.dataclass
@@ -293,7 +293,7 @@ class VecVSX(VecISA):
 def __str__(self) -> str:
 return "vsx"
 
-__hash__: Callable[[VecISA], Any] = VecISA.__hash__
+__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
 
 
 class InvalidVecISA(VecISA):
@@ -308,7 +308,7 @@ class InvalidVecISA(VecISA):
 def __bool__(self) -> bool: # type: ignore[override]
 return False
 
-__hash__: Callable[[VecISA], Any] = VecISA.__hash__
+__hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment]
 
 
 def x86_isa_checker() -> list[str]:
@@ -150,7 +150,7 @@ class ErasedTensor(torch.Tensor):
 self.owning_mod_ref = weakref.ref(mod)
 
 @classmethod
-def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
 erased_tensors = [
 e
 for e in pytree.arg_tree_leaves(*args, **kwargs)
@@ -253,7 +253,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
 
 def handle_views(node: torch.fx.Node):
 inp = node.args[0]
-node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type]
+node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type, assignment]
 node_to_view_op[node] = [
 *node_to_view_op[inp], # type: ignore[index]
 ViewOp(
@@ -2977,7 +2977,7 @@ class View(GenericView):
 return idx
 
 @classmethod
-def create(cls, x, new_size): # type: ignore[no-untyped-def]
+def create(cls, x, new_size): # type: ignore[no-untyped-def, override]
 assert isinstance(new_size, (tuple, list))
 old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)
 
@@ -3305,7 +3305,7 @@ class SliceView(View):
 return start, end
 
 @classmethod
-def create(cls, x, dim, start, end, step=1, clamp=True): # type: ignore[no-untyped-def]
+def create(cls, x, dim, start, end, step=1, clamp=True): # type: ignore[no-untyped-def, override]
 step = sympy.expand(step)
 assert isinstance(step, sympy.Expr) or step > 0
 try:
@@ -3906,7 +3906,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
 def stride(self) -> list[Expr]:
 return self.real_layout().stride
 
-@stride.setter
+@stride.setter # type: ignore[override]
 def stride(self, value: Never) -> None:
 pass # ignore setting of stride
 
@@ -1558,7 +1558,7 @@ def register_replacement(
 normalize_args=normalize_args,
 )
 pattern.register(pass_dicts)
-return pattern.pattern
+return pattern.pattern # type: ignore[return-value]
 
 
 _serialized_patterns: OrderedSet[str] = OrderedSet()
@@ -160,7 +160,7 @@ class FunctionalTensor(torch.Tensor):
 assert out._inference_mode_base is not None
 return out
 
-def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[override]
 unrecognized_types = [
 t
 for t in types
@@ -291,7 +291,7 @@ class FunctionalTensor(torch.Tensor):
 return self.elem.to_dense()
 
 @property
-def layout(self):
+def layout(self): # type: ignore[override]
 return self.elem.layout
 
 def __bool__(self):
@@ -633,7 +633,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return super().from_float(
 mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -833,7 +833,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return super().from_float(
 mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -1034,7 +1034,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return super().from_float(
 mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -57,5 +57,5 @@ class LinearReLU(nnqd.Linear):
 )
 
 @classmethod
-def from_reference(cls, ref_qlinear_relu):
+def from_reference(cls, ref_qlinear_relu): # type: ignore[override]
 return super().from_reference(ref_qlinear_relu[0])
@@ -47,7 +47,7 @@ class BNReLU2d(nnq.BatchNorm2d):
 return "QuantizedBNReLU2d"
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 # TODO: Add qat support for BNReLU2d
 return super().from_float(
 mod, use_precomputed_fake_quant=use_precomputed_fake_quant
@@ -96,7 +96,7 @@ class BNReLU3d(nnq.BatchNorm3d):
 return "QuantizedBNReLU3d"
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 # TODO: Add qat support for BNReLU3d
 return super().from_float(
 mod, use_precomputed_fake_quant=use_precomputed_fake_quant
@@ -68,7 +68,7 @@ class ConvAdd2d(nnq.Conv2d):
 return "QuantizedConvAdd2d"
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return super().from_float(
 mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -137,7 +137,7 @@ class ConvAddReLU2d(nnq.Conv2d):
 return "QuantizedConvAddReLU2d"
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return super().from_float(
 mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -78,7 +78,7 @@ class ConvReLU1d(nnq.Conv1d):
 return "QuantizedConvReLU1d"
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
 assert mod.bn.running_var is not None and mod.bn.running_mean is not None
 mod.weight, mod.bias = fuse_conv_bn_weights(
@@ -159,7 +159,7 @@ class ConvReLU2d(nnq.Conv2d):
 return "QuantizedConvReLU2d"
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
 assert mod.bn.running_var is not None and mod.bn.running_mean is not None
 mod.weight, mod.bias = fuse_conv_bn_weights(
@@ -242,7 +242,7 @@ class ConvReLU3d(nnq.Conv3d):
 return "QuantizedConvReLU3d"
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
 assert mod.bn.running_var is not None and mod.bn.running_mean is not None
 mod.weight, mod.bias = fuse_conv_bn_weights(
@@ -175,7 +175,7 @@ class Conv1d(_ConvNd, nn.Conv1d):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return super().from_float(
 cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -240,7 +240,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
 return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return super().from_float(
 cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -305,7 +305,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
 return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return super().from_float(
 cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -147,7 +147,7 @@ class Linear(nnq.Linear):
 return qlinear
 
 @classmethod
-def from_reference(cls, ref_qlinear):
+def from_reference(cls, ref_qlinear): # type: ignore[override]
 """Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized
 module
 Args:
@@ -83,7 +83,7 @@ class BatchNorm2d(_BatchNorm):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return _BatchNorm.from_float(
 cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -122,7 +122,7 @@ class BatchNorm3d(_BatchNorm):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 return _BatchNorm.from_float(
 cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
 )
@@ -467,7 +467,7 @@ class Conv1d(_ConvNd):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 r"""Creates a quantized module from a float module or qparams_dict.
 
 Args:
@@ -597,7 +597,7 @@ class Conv2d(_ConvNd):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 r"""Creates a quantized module from a float module or qparams_dict.
 
 Args:
@@ -728,7 +728,7 @@ class Conv3d(_ConvNd):
 )
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 r"""Creates a quantized module from a float module or qparams_dict.
 
 Args:
@@ -794,7 +794,7 @@ class _ConvTransposeNd(_ConvNd):
 return res
 
 @classmethod
-def from_float(cls, mod, use_precomputed_fake_quant=False):
+def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
 r"""Creates a quantized module from a float module or qparams_dict.
 Args:
 mod (Module): a float module, either produced by torch.ao.quantization
@@ -841,7 +841,7 @@ class _ConvTransposeNd(_ConvNd):
 return qconv
 
 @staticmethod
-def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
+def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override]
 r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
 Args:
 ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization
@@ -989,7 +989,7 @@ class ConvTranspose1d(_ConvTransposeNd):
 )
 
 @classmethod
-def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
+def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override]
 return _ConvTransposeNd.from_reference(
 cls, ref_qconvt, output_scale, output_zero_point
 )
@@ -1112,7 +1112,7 @@ class ConvTranspose2d(_ConvTransposeNd):
 )
 
 @classmethod
-def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
+def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override]
 return _ConvTransposeNd.from_reference(
 cls, ref_qconvt, output_scale, output_zero_point
 )
@@ -1237,7 +1237,7 @@ class ConvTranspose3d(_ConvTransposeNd):
 )
 
 @classmethod
-def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
+def from_reference(cls, ref_qconvt, output_scale, output_zero_point): # type: ignore[override]
 return _ConvTransposeNd.from_reference(
 cls, ref_qconvt, output_scale, output_zero_point
 )
@@ -110,7 +110,7 @@ class Conv1d(_ConvNd, nn.Conv1d):
 return "QuantizedConv1d(Reference)"
 
 @classmethod
-def from_float(cls, float_conv, weight_qparams):
+def from_float(cls, float_conv, weight_qparams): # type: ignore[override]
 return _ConvNd.from_float(cls, float_conv, weight_qparams)
 
 
@@ -173,7 +173,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
 return "QuantizedConv2d(Reference)"
 
 @classmethod
-def from_float(cls, float_conv, weight_qparams):
+def from_float(cls, float_conv, weight_qparams): # type: ignore[override]
 return _ConvNd.from_float(cls, float_conv, weight_qparams)
 
 
@@ -236,7 +236,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
 return "QuantizedConv3d(Reference)"
 
 @classmethod
-def from_float(cls, float_conv, weight_qparams):
+def from_float(cls, float_conv, weight_qparams): # type: ignore[override]
 return _ConvNd.from_float(cls, float_conv, weight_qparams)
 
 
@@ -346,7 +346,7 @@ class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d):
 return "QuantizedConvTranspose1d(Reference)"
 
 @classmethod
-def from_float(cls, float_conv, weight_qparams):
+def from_float(cls, float_conv, weight_qparams): # type: ignore[override]
 return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
 
 
@@ -427,7 +427,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
 return "QuantizedConvTranspose2d(Reference)"
 
 @classmethod
-def from_float(cls, float_conv, weight_qparams):
+def from_float(cls, float_conv, weight_qparams): # type: ignore[override]
 return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
 
 
@@ -507,5 +507,5 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
 return "QuantizedConvTranspose3d(Reference)"
 
 @classmethod
-def from_float(cls, float_conv, weight_qparams):
+def from_float(cls, float_conv, weight_qparams): # type: ignore[override]
 return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
@@ -310,7 +310,7 @@ class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
 self.update_mask(name, data, **config)
 
 @abc.abstractmethod
-def update_mask(self, name, data, **kwargs):
+def update_mask(self, name, data, **kwargs): # type: ignore[override]
 pass
 
 def _delete_data(self, name):
@@ -145,7 +145,7 @@ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
 print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}")
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator]
 scale = self.scale.detach()
 zero_point = (
@@ -671,23 +671,23 @@ class BackendPatternConfig:
 for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
 conf.add_dtype_config(_get_dtype_config(d))
 conf.set_root_module(
-backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None)
+backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None) # type: ignore[arg-type]
 )
-conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))
+conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) # type: ignore[arg-type]
 conf.set_reference_quantized_module(
-backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None)
+backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None) # type: ignore[arg-type]
 )
 conf.set_fused_module(
-backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None)
+backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None) # type: ignore[arg-type]
 )
 conf.set_fuser_method(
-backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None)
+backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None) # type: ignore[arg-type]
 )
 conf._set_root_node_getter(
-backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None)
+backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None) # type: ignore[arg-type]
 )
 conf._set_extra_inputs_getter(
-backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None)
+backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None) # type: ignore[arg-type]
 )
 conf._set_num_tensor_args_to_observation_type(
 backend_pattern_config_dict.get(
@@ -218,7 +218,7 @@ class FakeQuantize(FakeQuantizeBase):
 self.is_per_channel = _is_per_channel(self.qscheme)
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 return self.activation_post_process.calculate_qparams()
 
 def forward(self, X):
@@ -342,7 +342,7 @@ class FixedQParamsFakeQuantize(FakeQuantize):
 )
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 return self.scale, self.zero_point
 
 @torch.jit.export
@@ -364,7 +364,7 @@ def get_op_node_and_weight_eq_obs(
 maybe_equalization_node_name_to_config # type: ignore[assignment]
 )
 assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
-weight_eq_obs = equalization_node_name_to_qconfig.get(
+weight_eq_obs = equalization_node_name_to_qconfig.get( # type: ignore[union-attr]
 op_node.name, None
 ).weight()
 
@@ -845,7 +845,7 @@ def convert_eq_obs(
 
 # Erase the weight equalization observer node
 prev_node = weight_eq_obs_node.args[0]
-remove_node(model, weight_eq_obs_node, prev_node)
+remove_node(model, weight_eq_obs_node, prev_node) # type: ignore[arg-type]
 else:
 raise ValueError(
 "Expected operation node to be 'call_module' or 'call_function"
@@ -279,7 +279,7 @@ class ModelReportObserver(ObserverBase):
 self.constant_channels = torch.tensor([], device=device)
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 raise Exception( # noqa: TRY002
 "calculate_qparams should not be called for ModelReportObserver"
 )
@@ -564,7 +564,7 @@ class MinMaxObserver(UniformQuantizationObserverBase):
 return x_orig
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 r"""Calculates the quantization parameters."""
 return self._calculate_qparams(self.min_val, self.max_val)
 
@@ -787,7 +787,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
 return x_orig
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 return self._calculate_qparams(self.min_val, self.max_val)
 
 def extra_repr(self):
@@ -1335,7 +1335,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
 return x_orig
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 is_uninitialized = self.min_val == float("inf") and self.max_val == float(
 "-inf"
 )
@@ -1448,7 +1448,7 @@ class FixedQParamsObserver(ObserverBase):
 return X
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 return self.scale, self.zero_point
 
 
@@ -1517,7 +1517,7 @@ class PlaceholderObserver(ObserverBase):
 return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}"
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 raise Exception( # noqa: TRY002
 "calculate_qparams should not be called for PlaceholderObserver"
 )
@@ -1544,7 +1544,7 @@ class RecordingObserver(ObserverBase):
 return x
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 raise Exception( # noqa: TRY002
 "calculate_qparams should not be called for RecordingObserver"
 )
@@ -1577,7 +1577,7 @@ class NoopObserver(ObserverBase):
 return x
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 raise Exception( # noqa: TRY002
 "calculate_qparams should not be called for NoopObserver"
 )
@@ -1604,7 +1604,7 @@ class ReuseInputObserver(ObserverBase):
 return x
 
 @torch.jit.export
-def calculate_qparams(self):
+def calculate_qparams(self): # type: ignore[override]
 raise Exception( # noqa: TRY002
 "calculate_qparams should not be called for ReuseInputObserver"
 )
@@ -33,7 +33,7 @@ def _maybe_duplicate_dq(
 gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
 ):
 annotation = user.meta.get("quantization_annotation", None)
-if not _is_valid_annotation(annotation):
+if not _is_valid_annotation(annotation): # type: ignore[arg-type]
 return
 with gm.graph.inserting_after(dq_node):
 new_node = gm.graph.node_copy(dq_node)
@@ -138,7 +138,7 @@ def _make_grads(
 shape_matches = expect_true(sym_eq(out_size, first_grad.size()))
 
 if not shape_matches:
-out = cast(Union[torch.Tensor, graph.GradientEdge], out)
+out = cast(Union[torch.Tensor, graph.GradientEdge], out) # type: ignore[redundant-cast]
 out_shape, grad_shape = _calculate_shape(
 out, first_grad, is_grads_batched
 )
@@ -815,7 +815,7 @@ class NestedIOFunction(Function):
 self._to_save_nested = args
 
 @property
-def saved_tensors(self):
+def saved_tensors(self): # type: ignore[override]
 r"""
 See :meth:`Function.saved_tensors`.
 """
@@ -635,7 +635,7 @@ class AsyncCollectiveTensor(torch.Tensor):
 return self.elem
 
 @classmethod
-def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
 if func == torch.ops.aten.view.default:
 # Fast handle aten.view as a lot of view related op goes to aten.view
 # eventually, this avoids pytree slowdown
@@ -184,7 +184,7 @@ class ShardedTensorBase(torch.Tensor):
 return sharded_tensor_base
 
 @classmethod
-def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
 raise RuntimeError(
 f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} "
 "but the there is no custom __torch_dispatch__ implementation for it."
@@ -425,14 +425,14 @@ def _write_files_from_queue(
 transforms,
 stream,
 tensor,
-write_item,
+write_item, # type: ignore[arg-type]
 storage_key,
 serialization_format,
 )
 )
-tensor_dict[write_item.index.fqn] = tensor
-metadata_dict[write_item.index.fqn] = {
-"saved_offsets": write_item.tensor_data.chunk.offsets
+tensor_dict[write_item.index.fqn] = tensor # type: ignore[attr-defined]
+metadata_dict[write_item.index.fqn] = { # type: ignore[attr-defined]
+"saved_offsets": write_item.tensor_data.chunk.offsets # type: ignore[attr-defined]
 }
 
 if serialization_format == SerializationFormat.SAFETENSORS:
@@ -621,7 +621,7 @@ else:
 f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
 f"in {subgroup_ranks}!"
 )
-dim_group_names.append(dim_group.group_name)
+dim_group_names.append(dim_group.group_name) # type: ignore[union-attr]
 self._dim_group_names = dim_group_names
 
 def __enter__(self) -> "DeviceMesh":
@@ -61,7 +61,7 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
 self.averager = averager
 
 @property
-def state(self):
+def state(self): # type: ignore[override]
 return self.optim.state
 
 def __repr__(self):
@@ -346,7 +346,7 @@ class DTensor(torch.Tensor):
 @torch._disable_dynamo
 # pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
-def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
 return DTensor._op_dispatcher.dispatch(
 func,
 args,
@@ -84,7 +84,7 @@ class LocalShardsWrapper(torch.Tensor):
 
 # necessary for ops dispatching from this subclass to its local shards
 @classmethod
-def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
 kwargs = kwargs or {}
 
 # TODO: we shall continually extend this function to support more ops if needed
@@ -53,13 +53,13 @@ def _remove_effect_tokens_from_graph_helper(
 assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))
 
 if func == torch.ops.higher_order.call_torchbind:
-custom_obj_meta = node.args[2].meta["val"]
+custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr]
 assert isinstance(custom_obj_meta, CustomObjArgument)
 if custom_obj_meta.fake_val:
 custom_obj = custom_obj_meta.fake_val
-elif node.args[2].name in inputs_to_lifted_custom_objs:
+elif node.args[2].name in inputs_to_lifted_custom_objs: # type: ignore[union-attr]
 custom_obj = ep.constants[
-inputs_to_lifted_custom_objs[node.args[2].name]
+inputs_to_lifted_custom_objs[node.args[2].name] # type: ignore[union-attr]
 ]
 else:
 raise RuntimeError(f"Unable to find custom obj for node {node}")
@@ -199,11 +199,11 @@ class _StaticDim(Dim):
 self.value = value
 
 @property
-def min(self):
+def min(self): # type: ignore[override]
 return self.value # type: ignore[attr-defined]
 
 @property
-def max(self):
+def max(self): # type: ignore[override]
 return self.value # type: ignore[attr-defined]
 
 
@@ -229,7 +229,7 @@ class _DerivedDim(Dim):
 self.fn = fn
 
 @property
-def min(self):
+def min(self): # type: ignore[override]
 # assume that self.fn is an increasing function
 # TODO(avik): use sympy value range analysis instead?
 from sympy import Integer
@@ -249,7 +249,7 @@ class _DerivedDim(Dim):
 return int(_min_symint)
 
 @property
-def max(self):
+def max(self): # type: ignore[override]
 # assume that self.fn is an increasing function
 # TODO(avik): use sympy value range analysis instead?
 from sympy import Integer
@@ -567,7 +567,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
 
 # TODO(zhxhchen17) Return the new graph_signature directly.
 fake_mode = detect_fake_mode(fake_args)
-fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode
+fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode # type: ignore[assignment]
 custom_triton_ops_decomposition_ctx = (
 contextlib.nullcontext
 if decompose_custom_triton_ops
@@ -143,7 +143,7 @@ class InterpreterModule(_SubmoduleBase, torch.nn.Module):
 super().__init__()
 self.graph = graph
 self._ty = ty
-self.graph.owning_module = self
+self.graph.owning_module = self # type: ignore[assignment]
 self._run_with_interpreter = RUN_WITH_INTERPRETER
 
 def forward(self, *args, **kwargs):
@@ -296,7 +296,7 @@ class UnflattenedModule(torch.nn.Module):
 export_graph = deepcopy(export_module.graph)
 self.graph_signature = deepcopy(export_module.graph_signature)
 self.graph = torch.fx.Graph()
-self.graph.owning_module = self
+self.graph.owning_module = self # type: ignore[assignment]
 self.module_call_graph = deepcopy(export_module.module_call_graph)
 self.flat_args_adapter = flat_args_adapter
 
@@ -126,7 +126,7 @@ class MetaAttribute(MetaProxy):
 self._node = None
 
 @property
-def node(self):
+def node(self): # type: ignore[override]
 # the node for attributes is added lazily, since most will just be method calls
 # which do not rely on the getitem call
 if self._node is None:
@@ -471,7 +471,7 @@ def optimize_for_inference(
 if not use_mkl_heuristic(graph):
 for node in graph.start_nodes + graph.end_nodes:
 prv = node.args[0]
-node.replace_all_uses_with(prv)
+node.replace_all_uses_with(prv) # type: ignore[arg-type]
 fx_graph.erase_node(node)
 reset_modules(graph.nodes, modules, old_modules)
 
@@ -1011,7 +1011,7 @@ class _SymNodeDict:
 ) -> _PySymProxyType:
 # dict.get()'s annotation doesn't accept `None` when the value type
 # isn't Optional.
-return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type]
+return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type, return-value]
 
 def __iter__(self) -> Any:
 raise NotImplementedError
@@ -1415,7 +1415,7 @@ def _make_node_magic(method, func):
 out,
 self.shape_env,
 pytype,
-out_hint,
+out_hint, # type: ignore[arg-type]
 fx_node=fx_node,
 optimized_summation=optimized_summation, # see Note [optimized_summation]
 )
@@ -354,7 +354,7 @@ class Dispatcher:
 self._cache = {}
 
 @property
-def __doc__(self):
+def __doc__(self): # type: ignore[override]
 docs = [f"Multiply dispatched method: {self.name}"]
 
 if self.doc:
@@ -794,7 +794,7 @@ def _sparse_csr_segment_reduction_helper(
 0,
 )
 new_nnz = new_crow_indices[-1]
-new_col_indices = col_indices.new_zeros(new_nnz)
+new_col_indices = col_indices.new_zeros(new_nnz) # type: ignore[call-overload]
 new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined]
 new_shape = [mask_input.size(0), 1]
 else:
@@ -304,7 +304,7 @@ class MaskedTensor(torch.Tensor):
 return MaskedTensor(fn(data), mask)
 
 @classmethod
-def __torch_dispatch__(cls, func, types, args, kwargs):
+def __torch_dispatch__(cls, func, types, args, kwargs): # type: ignore[override]
 func = func.overloadpacket
 
 from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
@@ -355,5 +355,5 @@ class MaskedTensor(torch.Tensor):
 
 # Update later to support more sparse layouts
 @property
-def is_sparse(self):
+def is_sparse(self): # type: ignore[override]
 return self.is_sparse_coo() or self.is_sparse_csr()
@@ -319,7 +319,7 @@ class NestedTensor(torch.Tensor):
 )
 
 @classmethod
-def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
 # If you're wondering why there's a nested tensor with one of its
 # size = -1, see note: [NJT outer_size in AOTDispatcher]
 kwargs = {} if kwargs is None else kwargs
@@ -184,7 +184,7 @@ def replicate(
 # so setattr them as non-parameter attributes
 setattr(replica, key, param_copy)
 # expose the parameter for DDP
-replica._former_parameters[key] = param_copy
+replica._former_parameters[key] = param_copy # type: ignore[operator, index]
 for key, buf in module._buffers.items(): # type: ignore[assignment]
 if buf is None:
 for j in range(num_replicas):
@@ -152,23 +152,23 @@ class ExpandedWeight(torch.Tensor):
 )
 
 @property
-def dtype(self):
+def dtype(self): # type: ignore[override]
 return self.orig_weight.dtype
 
 @property
-def data(self):
+def data(self): # type: ignore[override]
 return self.orig_weight.data
 
 @property
-def shape(self):
+def shape(self): # type: ignore[override]
 return self.orig_weight.shape
 
 @property
-def device(self):
+def device(self): # type: ignore[override]
 return self.orig_weight.device
 
 @property
-def is_cuda(self):
+def is_cuda(self): # type: ignore[override]
 return self.orig_weight.is_cuda
 
 def data_ptr(self):
@@ -417,7 +417,7 @@ class Identity(BasePruningMethod):
 return mask
 
 @classmethod
-def apply(cls, module, name):
+def apply(cls, module, name): # type: ignore[override]
 r"""Add pruning on the fly and reparametrization of a tensor.
 
 Adds the forward pre-hook that enables pruning on the fly and
@@ -472,7 +472,7 @@ class RandomUnstructured(BasePruningMethod):
 return mask
 
 @classmethod
-def apply(cls, module, name, amount):
+def apply(cls, module, name, amount): # type: ignore[override]
 r"""Add pruning on the fly and reparametrization of a tensor.
 
 Adds the forward pre-hook that enables pruning on the fly and
@@ -531,7 +531,7 @@ class L1Unstructured(BasePruningMethod):
 return mask
 
 @classmethod
-def apply(cls, module, name, amount, importance_scores=None):
+def apply(cls, module, name, amount, importance_scores=None): # type: ignore[override]
 r"""Add pruning on the fly and reparametrization of a tensor.
 
 Adds the forward pre-hook that enables pruning on the fly and
@@ -642,7 +642,7 @@ class RandomStructured(BasePruningMethod):
 return mask
 
 @classmethod
-def apply(cls, module, name, amount, dim=-1):
+def apply(cls, module, name, amount, dim=-1): # type: ignore[override]
 r"""Add pruning on the fly and reparametrization of a tensor.
 
 Adds the forward pre-hook that enables pruning on the fly and
@@ -758,7 +758,7 @@ class LnStructured(BasePruningMethod):
 return mask
 
 @classmethod
-def apply(cls, module, name, amount, n, dim, importance_scores=None):
+def apply(cls, module, name, amount, n, dim, importance_scores=None): # type: ignore[override]
 r"""Add pruning on the fly and reparametrization of a tensor.
 
 Adds the forward pre-hook that enables pruning on the fly and
@@ -805,7 +805,7 @@ class CustomFromMask(BasePruningMethod):
 return mask
 
 @classmethod
-def apply(cls, module, name, mask):
+def apply(cls, module, name, mask): # type: ignore[override]
 r"""Add pruning on the fly and reparametrization of a tensor.
 
 Adds the forward pre-hook that enables pruning on the fly and
@@ -7,7 +7,11 @@ from __future__ import annotations
 
 from typing import Optional, TYPE_CHECKING
 
-from onnxscript.onnx_opset import opset20 as op20, opset21 as op21, opset23 as op23
+from onnxscript.onnx_opset import ( # type: ignore[attr-defined]
+opset20 as op20,
+opset21 as op21,
+opset23 as op23,
+)
 
 import torch
 from torch.onnx._internal._lazy_import import onnxscript_ir as ir
@@ -299,7 +299,7 @@ class LBFGS(Optimizer):
 return loss, flat_grad
 
 @torch.no_grad()
-def step(self, closure):
+def step(self, closure): # type: ignore[override]
 """Perform a single optimization step.
 
 Args:
@@ -825,7 +825,7 @@ def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener:
 container = _open_zipfile_writer_file
 else:
 container = _open_zipfile_writer_buffer
-return container(name_or_buffer)
+return container(name_or_buffer) # type: ignore[arg-type]
 
 
 def _is_compressed_file(f) -> bool:
@@ -197,10 +197,10 @@ class SparseSemiStructuredTensor(torch.Tensor):
 requires_grad=requires_grad,
 )
 
-__torch_function__ = torch._C._disabled_torch_function_impl
+__torch_function__ = torch._C._disabled_torch_function_impl # type: ignore[assignment]
 
 @classmethod
-def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
+def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: # type: ignore[override]
 if func._overloadpacket not in cls.SPARSE_DISPATCH:
 raise NotImplementedError(
 f"{cls.__name__} only supports a specific set of operations, "
@@ -1523,7 +1523,7 @@ class _LegacyStorageMeta(type):
 
 class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
 @classmethod
-def _new_shared(cls, size):
+def _new_shared(cls, size): # type: ignore[override]
 """Create a new storage in shared memory with the same data type."""
 untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size())
 return cls(wrap_storage=untyped_storage)
@@ -1180,7 +1180,7 @@ class FSDPTest(MultiProcessTestCase):
 return run_subtests(self, *args, **kwargs)
 
 @classmethod
-def _run(cls, rank, test_name, file_name, pipe, **kwargs):
+def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[override]
 self = cls(test_name)
 self.rank = rank
 self.file_name = file_name
@@ -3900,7 +3900,7 @@ class TestCase(expecttest.TestCase):
 ((0, 0), [(1, 2)], [()]),
 ]:
 for blocksize in blocksizes:
-for densesize in densesizes:
+for densesize in densesizes: # type: ignore[attr-defined]
 if layout == torch.strided:
 indices = () # type: ignore[assignment]
 values = torch.empty((basesize + densesize), device=device, dtype=dtype)