Update mypy to 1.4.1 (#91983)

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
This commit is contained in:
Nikita Shulga
2023-07-13 16:30:33 +00:00
committed by PyTorch MergeBot
parent f73757d551
commit 634659e262
69 changed files with 177 additions and 171 deletions

View File

@ -75,10 +75,10 @@ librosa>=0.6.2 ; python_version < "3.11"
#Pinned versions:
#test that import:
mypy==0.960
mypy==1.4.1
# Pin MyPy version because new errors are likely to appear with each release
#Description: linter
#Pinned versions: 0.960
#Pinned versions: 1.4.1
#test that import: test_typing.py, test_type_hints.py
networkx==2.8.8

View File

@ -152,7 +152,7 @@ init_command = [
'--dry-run={{DRYRUN}}',
'numpy==1.24.3',
'expecttest==0.1.3',
'mypy==0.960',
'mypy==1.4.1',
'types-requests==2.27.25',
'types-PyYAML==6.0.7',
'types-tabulate==0.8.8',

View File

@ -49,7 +49,7 @@ try:
# use faster C loader if available
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
def write(filename, s):

View File

@ -106,7 +106,7 @@ def graph_def_to_event(step, graph_def):
wall_time=step, step=step, graph_def=graph_def.SerializeToString())
@cli.command("tensorboard-graphs")
@cli.command("tensorboard-graphs") # type: ignore[arg-type, attr-defined]
@click.option("--c2-netdef", type=click.Path(exists=True, dir_okay=False),
multiple=True)
@click.option("--tf-dir", type=click.Path(exists=True))
@ -129,7 +129,7 @@ def tensorboard_graphs(c2_netdef, tf_dir):
log.info("Wrote %s graphs to logdir %s", len(events), tf_dir)
@cli.command("tensorboard-events")
@cli.command("tensorboard-events") # type: ignore[arg-type, attr-defined]
@click.option("--c2-dir", type=click.Path(exists=True, file_okay=False),
help="Root directory of the Caffe2 run")
@click.option("--tf-dir", type=click.Path(writable=True),
@ -209,4 +209,4 @@ def tensorboard_events(c2_dir, tf_dir):
if __name__ == "__main__":
cli()
cli() # type: ignore[misc]

View File

@ -21,7 +21,7 @@ class TestFuture(TestCase):
error_msg = "Intentional Value Error"
value_error = ValueError(error_msg)
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
# Set exception
f.set_exception(value_error)
# Exception should throw on wait
@ -29,7 +29,7 @@ class TestFuture(TestCase):
f.wait()
# Exception should also throw on value
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
f.set_exception(value_error)
with self.assertRaisesRegex(ValueError, "Intentional"):
f.value()
@ -37,7 +37,7 @@ class TestFuture(TestCase):
def cb(fut):
fut.value()
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
f.set_exception(value_error)
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
@ -54,7 +54,7 @@ class TestFuture(TestCase):
with self.assertRaisesRegex(ValueError, "Intentional"):
f.wait()
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
t = threading.Thread(target=wait_future, args=(f, ))
t.start()
f.set_exception(value_error)
@ -68,7 +68,7 @@ class TestFuture(TestCase):
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
fut.wait()
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
t = threading.Thread(target=then_future, args=(f, ))
t.start()
f.set_exception(value_error)

View File

@ -29,7 +29,7 @@ from itertools import product, combinations, permutations
from functools import partial
from torch import multiprocessing as mp
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest,
@ -8469,7 +8469,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
except RuntimeError:
pass
with mp.Pool(1) as pool:
out: list = pool.map(method, [arg])
out = pool.map(method, [arg])
self.assertTrue(out[0])
def _test_multinomial_invalid_probs(probs):

View File

@ -1,11 +1,11 @@
#!/usr/bin/env python3
import time
from package.oss.cov_json import get_json_report
from package.oss.init import initialization
from package.tool.summarize_jsons import summarize_jsons
from package.util.setting import TestPlatform
from package.util.utils import print_time
from package.oss.cov_json import get_json_report # type: ignore[import]
from package.oss.init import initialization # type: ignore[import]
from package.tool.summarize_jsons import summarize_jsons # type: ignore[import]
from package.util.setting import TestPlatform # type: ignore[import]
from package.util.utils import print_time # type: ignore[import]
def report_coverage() -> None:

View File

@ -45,7 +45,7 @@ def transform_file_name(
return file_path[file_path.find(folder) :]
# remove pytorch base folder path
if platform == TestPlatform.OSS:
from package.oss.utils import get_pytorch_folder
from package.oss.utils import get_pytorch_folder # type: ignore[import]
pytorch_foler = get_pytorch_folder()
assert file_path.startswith(pytorch_foler)

View File

@ -89,7 +89,9 @@ def get_raw_profiles_folder() -> str:
def detect_compiler_type(platform: TestPlatform) -> CompilerType:
if platform == TestPlatform.OSS:
from package.oss.utils import detect_compiler_type # type: ignore[misc]
from package.oss.utils import ( # type: ignore[assignment, import, misc]
detect_compiler_type,
)
cov_type = detect_compiler_type() # type: ignore[call-arg]
else:
@ -100,7 +102,7 @@ def detect_compiler_type(platform: TestPlatform) -> CompilerType:
cov_type = detect_compiler_type()
check_compiler_type(cov_type)
return cov_type
return cov_type # type: ignore[no-any-return]
def get_test_name_from_whole_path(path: str) -> str:

View File

@ -20,7 +20,7 @@ from yaml.nodes import MappingNode
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[misc]
from yaml import Loader # type: ignore[assignment, misc]
H_NAME = "spv.h"
CPP_NAME = "spv.cpp"

View File

@ -16,7 +16,7 @@ from yaml import dump, load
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
class LintSeverity(str, Enum):

View File

@ -11,7 +11,7 @@ from torchgen.selective_build.selector import SelectiveBuilder
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {

View File

@ -10,7 +10,7 @@ try:
# use faster C loader if available
from yaml import CSafeLoader as YamlLoader
except ImportError:
from yaml import SafeLoader as YamlLoader # type: ignore[misc]
from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc]
NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
TAGS_PATH = "aten/src/ATen/native/tags.yaml"

View File

@ -429,6 +429,7 @@ def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
def _get_upgraders_map_size() -> _int: ...
def _get_upgraders_entry_map() -> Dict[str, str]: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...

View File

@ -417,7 +417,7 @@ def sym_int(a):
if isinstance(a, SymInt):
return a
elif isinstance(a, SymFloat):
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type]
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type, call-overload]
return py_int(a) # type: ignore[operator]
def sym_max(a, b):
@ -1316,7 +1316,7 @@ if TYPE_CHECKING:
# Some type signatures pulled in from _VariableFunctions here clash with
# signatures already imported. For now these clashes are ignored; see
# PR #43339 for details.
from torch._C._VariableFunctions import * # type: ignore[misc] # noqa: F403
from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
# Fixup segment_reduce visibility
_segment_reduce = segment_reduce
del segment_reduce

View File

@ -10,20 +10,20 @@ class _Union:
@classmethod
def create(cls, **kwargs):
assert len(kwargs) == 1
return cls(**{**{f.name: None for f in fields(cls)}, **kwargs})
return cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type]
def __post_init__(self):
assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1
assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1 # type: ignore[arg-type, misc]
@property
def value(self):
val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None)
val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
assert val is not None
return val
@property
def type(self):
val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None)
val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
assert val_type is not None
return val_type

View File

@ -9,7 +9,7 @@ import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
import sympy
@ -749,7 +749,7 @@ class GraphModuleDeserializer:
self.module = torch.nn.Module()
@contextmanager
def save_graph_module(self) -> None:
def save_graph_module(self) -> Iterator[None]:
saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta
self.graph = torch.fx.Graph()
self.module = torch.nn.Module()
@ -773,7 +773,7 @@ class GraphModuleDeserializer:
if vr := self.symbol_name_to_range.get(val.expr_str):
symbolic_shapes._constrain_symbol_range(
self.shape_env, sym, vr.lower, vr.upper
self.shape_env, sym, vr.lower, vr.upper # type: ignore[arg-type]
)
return self.shape_env.create_symintnode(sym, hint=val.hint)
@ -855,6 +855,7 @@ class GraphModuleDeserializer:
output_node.meta["val"] = tuple(
arg.meta["val"] for arg in output_node.args[0]
)
return output_node
def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
if target.__module__ == "_operator": # TODO(zhxchen17) Follow up on this.
@ -1050,7 +1051,7 @@ class GraphModuleDeserializer:
self.serialized_name_to_node[fx_node.name] = fx_node
def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
ret = {}
ret: Dict[str, Any] = {}
if stack_trace := metadata.get("stack_trace"):
ret["stack_trace"] = stack_trace

View File

@ -33,7 +33,7 @@ def get_upgraders() -> Dict[str, Tuple[str, str]]:
"""Getting upgraders entry map and operator version map and merge them into one dict."""
upgraders = torch._C._get_upgraders_entry_map()
op_version_map = torch._C._get_operator_version_map()
output = defaultdict(tuple)
output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type]
for opname, entry_list in op_version_map.items():
if not entry_list:
raise RuntimeError(f"Op version map has an empty entry for opname {opname}")

View File

@ -27,7 +27,7 @@ def load_tensor_reader(loc):
def register_debug_prims():
@custom_op("debugprims::load_tensor")
def load_tensor(
def load_tensor( # type: ignore[empty-body]
name: str,
size: Sequence[int],
stride: Sequence[int],

View File

@ -10,7 +10,7 @@ from torch._prims_common import (
import torch._prims_common as utils
from torch.utils._pytree import tree_flatten, tree_unflatten
from typing import Callable, Sequence, Tuple, NamedTuple, overload
from typing import Callable, Sequence, Tuple, NamedTuple, Optional, overload
import inspect
from functools import wraps
import warnings
@ -97,7 +97,7 @@ class elementwise_type_promotion_wrapper:
self,
*,
type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
type_promoting_args: Sequence[str] = None,
type_promoting_args: Optional[Sequence[str]] = None,
):
self.type_promoting_arg_names = type_promoting_args
self.type_promotion_kind = type_promotion_kind

View File

@ -856,7 +856,7 @@ def prepare_n_shadows_model(
create_n_transformed_and_logged_copies_of_subgraph(
mt, subgraph_idx, match_name, nodes_in_this_subgraph,
qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig,
custom_prepare_fn, custom_prepare_kwargs
custom_prepare_fn, custom_prepare_kwargs # type: ignore[arg-type]
)
return mt

View File

@ -443,7 +443,7 @@ def create_one_transformed_and_logged_copy_of_subgraph(
example_inputs: Any,
last_added_shadow_node_list: List[Optional[Node]],
custom_prepare_fn: Optional[Callable] = None,
custom_prepare_kwargs: Dict[str, Any] = None,
custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Given a subgraph in `mt` and a subgraph candidate idx, inserts the
@ -575,7 +575,7 @@ def create_n_transformed_and_logged_copies_of_subgraph(
qconfig_mappings: List[QConfigMapping],
list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
custom_prepare_fn: Optional[Callable] = None,
custom_prepare_kwargs: Dict[str, Any] = None,
custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Given a model `mt` and a subgraph_idx, creates the needed copies
@ -756,7 +756,7 @@ def create_add_loggers_graph(
create_n_transformed_and_logged_copies_of_subgraph(
model, cur_subgraph_idx, match_name, maybe_subgraph,
[qconfig_mapping], [node_name_to_qconfig],
None, None
None, None # type: ignore[arg-type]
)
# find the created shadow module and record it so we
# can find it easily in step 2

View File

@ -78,7 +78,7 @@ def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
res.append(param_value)
return res
else:
assert isinstance(mod, nnqd.LSTM), f"type {type(res)} not handled yet"
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
res = []
for weight_value in mod._all_weight_values:
res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])

View File

@ -3,7 +3,7 @@ import torch
from dlrm_s_pytorch import unpack_batch # type: ignore[import]
import numpy as np # type: ignore[import]
import sklearn # type: ignore[import]
from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model
from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model # type: ignore[import]
import pandas as pd # type: ignore[import]
import argparse

View File

@ -458,8 +458,10 @@ def _update_special_qspecs_after_replacement(
if isinstance(edge_or_node, Node):
_node = edge_or_node
return original_to_replacement_node.get(_node, _node)
elif isinstance(edge_or_node, Tuple[Node, Node]):
src, dest = edge_or_node
# TODO: It's really should be
# isinstance(edge_or_node, tuple) and len(edge_or_node) == 2 and all(isinstance(x, Node) for x in edge_or_node)
elif isinstance(edge_or_node, Tuple[Node, Node]): # type: ignore[arg-type]
src, dest = edge_or_node # type: ignore[misc]
return (
original_to_replacement_node.get(src, src),
original_to_replacement_node.get(dest, dest),

View File

@ -58,13 +58,13 @@ def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
for conv_op, add_op, relu_op in conv_add_relu_options:
if add_op is None:
# Append Conv ReLU
supported_operators["conv2d"].append([conv_op, relu_op])
supported_operators["conv2d"].append([conv_op, relu_op]) # type: ignore[list-item]
elif relu_op is None:
# Append Conv Add
supported_operators["conv2d"].append([conv_op, add_op])
supported_operators["conv2d"].append([conv_op, add_op]) # type: ignore[list-item]
else:
# Append Conv Add ReLU
supported_operators["conv2d"].append([conv_op, add_op, relu_op])
supported_operators["conv2d"].append([conv_op, add_op, relu_op]) # type: ignore[list-item]
return copy.deepcopy(supported_operators)
@ -222,17 +222,17 @@ class X86InductorQuantizer(Quantizer):
"""
conv_gemm_node_idx = None
extra_input_node_idx = None
if (binary_node.args[0].op == "call_function") and (
if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr]
binary_node.args[0] == conv_gemm_node
):
conv_gemm_node_idx = 0
extra_input_node_idx = 1
elif (binary_node.args[1].op == "call_function") and (
elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr]
binary_node.args[1] == conv_gemm_node
):
conv_gemm_node_idx = 1
extra_input_node_idx = 0
extra_input_node = binary_node.args[extra_input_node_idx]
extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index]
assert isinstance(extra_input_node, Node)
return conv_gemm_node_idx, extra_input_node_idx

View File

@ -119,9 +119,9 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
_remove_tensor_overload_for_qdq_ops(model)
for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
pattern = _get_aten_graph_module(pattern, example_inputs)
_remove_tensor_overload_for_qdq_ops(pattern)
replacement = _get_aten_graph_module(replacement, example_inputs)
_remove_tensor_overload_for_qdq_ops(replacement)
pattern = _get_aten_graph_module(pattern, example_inputs) # type: ignore[arg-type, assignment]
_remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type]
replacement = _get_aten_graph_module(replacement, example_inputs) # type: ignore[arg-type, assignment]
_remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type]
matches = replace_pattern(model, pattern, replacement)
return model

View File

@ -44,7 +44,7 @@ class APoTQuantizer():
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
result = TensorAPoT(self, tensor2quantize)
result = TensorAPoT(self, tensor2quantize) # type: ignore[assignment]
return result
@ -83,7 +83,7 @@ class APoTQuantizer():
def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
levels_lst = list(self.quantization_levels)
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst)) # type: ignore[call-arg]
return result

View File

@ -85,9 +85,9 @@ def _find_matches(
modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler],
root_node_getter_mapping: Dict[Pattern, Callable],
standalone_module_names: List[str] = None,
standalone_module_classes: List[Type] = None,
custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
standalone_module_names: Optional[List[str]] = None,
standalone_module_classes: Optional[List[Type]] = None,
custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
"""
Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps.

View File

@ -18,7 +18,7 @@ from torch.ao.quantization.utils import (
)
from abc import ABC
from typing import Callable, Dict, List, Type
from typing import Callable, Dict, List, Type, Optional
__all__ = [
"QuantizeHandler",
@ -52,7 +52,7 @@ class QuantizeHandler(ABC):
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
root_node_getter: Callable = None,
root_node_getter: Optional[Callable] = None,
is_custom_module=False,
is_standalone_module=False):
""" Records pattern information in __init__, which will be used
@ -113,7 +113,7 @@ def _get_quantize_handler_cls(
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
root_node_getter: Callable = None):
root_node_getter: Optional[Callable] = None):
super().__init__(node_pattern, modules, root_node_getter)
if num_tensor_args_to_observation_type:
assert self.num_tensor_args in num_tensor_args_to_observation_type, \

View File

@ -455,7 +455,7 @@ def _profile_to_snapshot(profile):
device = to_device(tensor_key.device)
addr = tensor_key.storage.ptr
seg = snapshot['segments'][device]
seg = snapshot['segments'][device] # type: ignore[index]
if seg['address'] is None or seg['address'] > addr:
seg['address'] = addr
seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later
@ -465,12 +465,12 @@ def _profile_to_snapshot(profile):
stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
if during_trace:
snapshot['device_traces'][device].append(r)
snapshot['device_traces'][device].append(r) # type: ignore[index]
return r
def free(alloc, device):
for e in ('free_requested', 'free_completed'):
snapshot['device_traces'][device].append({'action': e,
snapshot['device_traces'][device].append({'action': e, # type: ignore[index]
'addr': alloc['addr'],
'size': alloc['size'],
'stream': 0,
@ -499,7 +499,7 @@ def _profile_to_snapshot(profile):
blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
for (tensor_key, version), event in kv_to_elem.items()]
for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]):
seg = snapshot['segments'][device]
seg = snapshot['segments'][device] # type: ignore[index]
last_addr = seg['address']
for _, addr, size, frames in blocks:
if last_addr < addr:
@ -510,8 +510,8 @@ def _profile_to_snapshot(profile):
if last_addr < seg['total_size']:
seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})
snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]
for seg in snapshot['segments']:
snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined]
for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef]
seg['total_size'] -= seg['address']
if not seg['blocks']:
seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})

View File

@ -1,5 +1,5 @@
import torch
from typing import cast, Iterable, List, Union
from typing import Iterable, List, Union
from . import _lazy_init, _lazy_call, device_count, current_device
from .. import Tensor
@ -56,7 +56,7 @@ def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cu
device = torch.device('cuda', device)
def cb():
idx = cast(torch.device, device).index
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.cuda.default_generators[idx]

View File

@ -732,7 +732,7 @@ class ShardedTensor(ShardedTensorBase):
local_tensor: torch.Tensor,
sharding_spec: shard_spec.ShardingSpec,
*global_size: Sequence[int],
process_group: dist.ProcessGroup = None,
process_group: Optional[dist.ProcessGroup] = None,
init_rrefs=False,
) -> "ShardedTensor":
"""

View File

@ -389,7 +389,7 @@ def _compile(
# can trace operations applied to them.
def stateless_func(func, params, buffers, named_states, args, kwargs):
with stateless._reparametrize_module(
cast(nn.Module, mod), {**params, **buffers}
mod, {**params, **buffers}
), _rematerialize_optimizer(
opt, named_states, params
) if opt else nullcontext():

View File

@ -133,7 +133,7 @@ def gen_model_param_in_submesh(model: nn.Module, sub_mesh: DeviceMesh) -> nn.Mod
)
def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module:
def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module: # type: ignore[empty-body]
"""
checkpoint save/load models with DTensor parameters
"""

View File

@ -40,7 +40,7 @@ def register_op_strategy(op):
def as_list(
x: Union[List[object], object]
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]:
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
# which is an object but treated as a list by the tracer. Therefore, keep
# `immutable_list` intact here as well.

View File

@ -160,7 +160,7 @@ def all_gather(
ret_list.append(sp.payload)
if len(exception_list) > 0:
raise RuntimeError(
raise RuntimeError( # type: ignore[misc]
error_msg, exception_list) from exception_list[0]
return ret_list
else:
@ -168,7 +168,7 @@ def all_gather(
raise RuntimeError(
f"all_gather failed with exception {sync_obj.exception}",
) from sync_obj.exception
return [sync_obj.payload]
return [sync_obj.payload] # type: ignore[list-item]
# Note: use Any for typing for now so users can pass in

View File

@ -1433,7 +1433,7 @@ def _register_post_backward_reshard_only_hooks(
hook_handle = register_multi_grad_hook(
inp_tensors, functools.partial(_post_backward_reshard, state, handle)
)
handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined]
handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment]
@no_type_check
@ -1468,11 +1468,11 @@ def _wait_for_computation_stream(
For example, this should be called in the FSDP root's pre-forward to
respect optimizer step computation.
"""
unshard_stream.wait_stream(computation_stream)
unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
# Having the pre-all-gather stream wait for the current stream even if we
# do not leverage the pre-all-gather stream is tolerable since this only
# runs once per iteration
pre_unshard_stream.wait_stream(computation_stream)
pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
def _reset_flat_param_grad_info_if_needed(

View File

@ -167,7 +167,7 @@ class _ExecOrderTracer:
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None,
proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
) -> torch.fx.Proxy:
"""
Overrides ``create_proxy`` to save execution information to

View File

@ -96,7 +96,7 @@ def _auto_wrap(
)
recursive_wrap_kwargs["auto_wrap_policy"] = policy
_warn_on_overridden_mixed_precision(overridden_module_classes)
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs)
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs) # type: ignore[arg-type]
def _check_nested_wrapping(root_module: nn.Module):

View File

@ -358,7 +358,7 @@ class _RemoteModule(nn.Module):
def bfloat16(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.bfloat16.__name__)
def to(self, *args, **kwargs) -> T: # type: ignore[return]
def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var]
_raise_not_supported(self.to.__name__)
def register_backward_hook( # type: ignore[return]
@ -377,7 +377,7 @@ class _RemoteModule(nn.Module):
) -> RemovableHandle:
_raise_not_supported(self.register_forward_pre_hook.__name__)
def register_forward_hook( # type: ignore[return]
def register_forward_hook( # type: ignore[return, override]
self,
hook: Union[
Callable[[T, Tuple[Any, ...], Any], Optional[Any]],

View File

@ -162,7 +162,7 @@ class _NamedOptimizer(optim.Optimizer):
return self._optimizer.step(closure=closure)
@property
def state(self) -> Mapping[torch.Tensor, Any]:
def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override]
return self._optimizer.state
def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:

View File

@ -49,7 +49,7 @@ def input_reshard(
def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any:
nonlocal cx
cx.__exit__() # type: ignore[name-defined]
cx.__exit__() # type: ignore[name-defined, union-attr]
if input_reshard_dim is None:
return module

View File

@ -23,7 +23,7 @@ class FoldedGraphModule(torch.fx.GraphModule):
root: torch.nn.Module,
graph: torch.fx.Graph,
const_subgraph: Optional[torch.fx.Graph] = None,
fx_const_folded_attrs_name: str = None,
fx_const_folded_attrs_name: Optional[str] = None,
device_for_folded_attrs: str = "cuda",
):
super().__init__(root, graph)

View File

@ -1,4 +1,4 @@
# type: ignore[attr-defined]
# mypy: disable-error-code=attr-defined
from .core import unify, reify # noqa: F403
from .more import unifiable # noqa: F403
from .variable import var, isvar, vars, variables, Var # noqa: F403

View File

@ -202,7 +202,7 @@ def replace_pattern_with_filters(
gm: GraphModule,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
"""
@ -222,7 +222,7 @@ def _replace_pattern(
gm: GraphModule,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
ignore_literals: bool = False,
) -> List[ReplacedPatterns]:

View File

@ -103,7 +103,7 @@ class Sequential(Module):
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx) -> T:
def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
"""Get the idx-th item of the iterator"""
size = len(self)
idx = operator.index(idx)

View File

@ -51,7 +51,7 @@ class _ConvNd(Module):
'out_channels', 'kernel_size']
__annotations__ = {'bias': Optional[torch.Tensor]}
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body]
...
in_channels: int

View File

@ -340,7 +340,7 @@ class ProtobufExportOutputSerializer:
) -> None:
import onnx
if not isinstance(export_output.model_proto, onnx.ModelProto):
if not isinstance(export_output.model_proto, onnx.ModelProto): # type: ignore[attr-defined]
raise ValueError("export_output.ModelProto is not an onnx.ModelProto")
destination.write(export_output.model_proto.SerializeToString())
@ -348,7 +348,7 @@ class ProtobufExportOutputSerializer:
class ExportOutput:
"""An in-memory representation of a PyTorch model that has been exported to ONNX."""
_model_proto: Final[onnx.ModelProto]
_model_proto: Final[onnx.ModelProto] # type: ignore[name-defined]
_input_adapter: Final[io_adapter.InputAdapter]
_output_adapter: Final[io_adapter.OutputAdapter]
_diagnostic_context: Final[infra.DiagnosticContext]
@ -356,7 +356,7 @@ class ExportOutput:
@_beartype.beartype
def __init__(
self,
model_proto: onnx.ModelProto,
model_proto: onnx.ModelProto, # type: ignore[name-defined]
input_adapter: io_adapter.InputAdapter,
output_adapter: io_adapter.OutputAdapter,
diagnostic_context: infra.DiagnosticContext,
@ -367,7 +367,7 @@ class ExportOutput:
self._diagnostic_context = diagnostic_context
@property
def model_proto(self) -> onnx.ModelProto:
def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined]
"""The exported ONNX model as an ``onnx.ModelProto``."""
return self._model_proto

View File

@ -193,7 +193,7 @@ class DynamoExport(exporter.FXGraphExtractor):
wrapped_model,
*model_args,
tracing_mode=fx_mode,
fake_mode=fake_mode,
fake_mode=fake_mode, # type: ignore[arg-type]
**model_kwargs,
)
del graph_guard # Unused

View File

@ -15,7 +15,7 @@ if TYPE_CHECKING:
@_beartype.beartype
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor, name: str, location: str, basepath: str
) -> onnx.TensorProto:
) -> onnx.TensorProto: # type: ignore[name-defined]
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
@ -38,13 +38,13 @@ def _create_tensor_proto_with_external_data(
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
tensor_proto = onnx.TensorProto()
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
tensor_proto.name = name
tensor_proto.data_type = jit_type_utils.JitScalarType.from_dtype(
tensor.dtype
).onnx_type()
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]
# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
@ -86,7 +86,7 @@ def save_model_with_external_data(
model_location: str,
initializer_location: str,
torch_load_paths: Tuple[Union[str, io.BytesIO], ...],
onnx_model: onnx.ModelProto,
onnx_model: onnx.ModelProto, # type: ignore[name-defined]
rename_initializer: bool = False,
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
@ -121,7 +121,7 @@ def save_model_with_external_data(
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
onnx_model_with_initializers = onnx.ModelProto()
onnx_model_with_initializers = onnx.ModelProto() # type: ignore[attr-defined]
onnx_model_with_initializers.CopyFrom(onnx_model)
onnx_input_names = [input.name for input in onnx_model.graph.input]
@ -159,4 +159,4 @@ def save_model_with_external_data(
onnx_model_with_initializers.graph.initializer.append(tensor_proto)
# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) # type: ignore[attr-defined]

View File

@ -62,7 +62,7 @@ def export_as_test_case(
shutil.rmtree(data_set_dir)
os.makedirs(data_set_dir)
proto = onnx.load_model_from_string(model_bytes)
proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]
for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)):
export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))
@ -112,12 +112,12 @@ def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
inputs = {}
input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
for input_file in input_files:
tensor = onnx.load_tensor(input_file)
tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined]
inputs[tensor.name] = numpy_helper.to_array(tensor)
outputs = {}
output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
for output_file in output_files:
tensor = onnx.load_tensor(output_file)
tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined]
outputs[tensor.name] = numpy_helper.to_array(tensor)
return model_bytes, inputs, outputs
@ -227,7 +227,7 @@ def _add_onnxscript_fn(
# size > 2GB, and if it for some reason did not, the model would fail on
# serialization anyway in terms of the protobuf limitation. So we don't
# need to worry about > 2GB model getting here.
model_proto = onnx.load_model_from_string(model_bytes)
model_proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]
# Iterate graph nodes to insert only the included custom
# function_proto into model_proto

View File

@ -206,14 +206,14 @@ def _ort_session(
def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
try:
import onnx
from onnx import reference as onnx_reference
from onnx import reference as onnx_reference # type: ignore[attr-defined]
except ImportError:
raise ImportError("onnx >= 1.13 is required for reference evaluator.")
proto = (
onnx.load(model)
onnx.load(model) # type: ignore[attr-defined]
if isinstance(model, str)
else onnx.load_model_from_string(model.getvalue())
else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined]
)
onnx_session = onnx_reference.ReferenceEvaluator(proto)
return onnx_session

View File

@ -225,9 +225,9 @@ class Optimizer:
"""
OptimizerPreHook: TypeAlias = Callable[
[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]] # type: ignore[valid-type]
[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]] # type: ignore[misc, valid-type]
]
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[valid-type]
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc, valid-type]
_optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: Dict[int, OptimizerPostHook]

View File

@ -603,7 +603,7 @@ def input_dtypes(event: _ProfilerEvent):
def report_all_anti_patterns(prof,
should_benchmark: bool = False,
print_enable: bool = True,
json_report_dir: str = None):
json_report_dir: Optional[str] = None):
report_dict: Dict = {}
anti_patterns = [
ExtraCUDACopyPattern(prof, should_benchmark),

View File

@ -202,7 +202,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
self.compressed_tensor = compressed_tensor
self.transposed = transposed
def __repr__(self) -> str:
def __repr__(self) -> str: # type: ignore[override]
"""Return string representation of SparseSemiStructuredTensor
Returns:

View File

@ -3,7 +3,7 @@ import io
import torch
from ._utils import _type, _cuda, _hpu
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast, Dict as _Dict
from typing import cast, Any, Dict as _Dict, Optional as _Optional, TypeVar, Type, Union
import copy
import collections
from functools import lru_cache
@ -27,51 +27,51 @@ class _StorageBase:
device: torch.device
def __init__(self, *args, **kwargs): ... # noqa: E704
def __len__(self) -> int: ... # noqa: E704
def __len__(self) -> int: ... # type: ignore[empty-body] # noqa: E704
def __getitem__(self, idx): ... # noqa: E704
def __setitem__(self, *args, **kwargs): ... # noqa: E704
def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704
def new(self) -> T: ... # noqa: E704
def nbytes(self) -> int: ... # noqa: E704
def copy_(self, source: T, non_blocking: _Optional[bool] = None) -> T: ... # type: ignore[empty-body] # noqa: E704
def new(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def nbytes(self) -> int: ... # type: ignore[empty-body] # noqa: E704
def size(self) -> int:
return self.nbytes()
def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
def element_size(self) -> int: ... # noqa: E704
def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def element_size(self) -> int: ... # type: ignore[empty-body, type-var] # noqa: E704
def get_device(self) -> int:
return self.device.index
def data_ptr(self) -> int: ... # noqa: E704
def data_ptr(self) -> int: ... # type: ignore[empty-body] # noqa: E704
# Defined in torch/csrc/generic/StorageSharing.cpp
def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
@classmethod
def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
def from_buffer(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
def _new_shared_filename_cpu(cls: Type[T], manager, obj, size, *, device=None, dtype=None) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
def _shared_decref(self) -> T: ... # noqa: E704
def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
def _shared_decref(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _write_file(self, *args, **kwargs): ... # noqa: E704
def resize_(self, size: int): ... # noqa: E704
def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704
def _weak_ref(self, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _set_from_file(self, *args, **kwargs): ... # noqa: E704
def _set_cdata(self, *args, **kwargs): ... # noqa: E704
def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
def is_shared(self) -> bool: ... # noqa: E704
def is_shared(self) -> bool: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
def _shared_incref(self, *args, **kwargs): ... # noqa: E704
@classmethod
def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704
@ -80,9 +80,9 @@ class _StorageBase:
@property
def is_hpu(self): ... # noqa: E704
@classmethod
def from_file(cls, filename, shared, nbytes) -> T: ... # noqa: E704
def from_file(cls, filename, shared, nbytes) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
@classmethod
def _expired(cls, *args, **kwargs) -> T: ... # noqa: E704
def _expired(cls, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _byteswap(self, *args, **kwargs): ... # noqa: E704
def __str__(self):
@ -712,12 +712,12 @@ class TypedStorage:
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
return tmp_tensor[idx_wrapped].item()
def copy_(self, source: T, non_blocking: bool = None):
def copy_(self, source: T, non_blocking: _Optional[bool] = None):
_warn_typed_storage_removal()
if isinstance(source, TypedStorage):
self._untyped_storage.copy_(source._untyped_storage, non_blocking)
self._untyped_storage.copy_(source._untyped_storage, non_blocking) # type: ignore[arg-type]
else:
self._untyped_storage.copy_(source, non_blocking)
self._untyped_storage.copy_(source, non_blocking) # type: ignore[arg-type]
return self
def nbytes(self):
@ -728,7 +728,7 @@ class TypedStorage:
def _nbytes(self):
return self._untyped_storage.nbytes()
def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> Union[T, str]:
_warn_typed_storage_removal()
if dtype is None:
legacy_class = self._get_legacy_storage_class()
@ -741,14 +741,14 @@ class TypedStorage:
else:
return self._untyped_storage.type(dtype, non_blocking)
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var]
_warn_typed_storage_removal()
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
return self._new_wrapped_storage(cuda_storage)
def hpu(self, device=None, non_blocking=False, **kwargs) -> T:
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var]
_warn_typed_storage_removal()
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create HPU storage with quantized dtype")

View File

@ -43,37 +43,37 @@ class Storage:
dtype: torch.dtype
_torch_load_uninitialized: bool
def __deepcopy__(self, memo) -> 'Storage':
def __deepcopy__(self, memo) -> 'Storage': # type: ignore[empty-body]
...
def _new_shared(self, int) -> 'Storage':
def _new_shared(self, int) -> 'Storage': # type: ignore[empty-body]
...
def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None:
...
def element_size(self) -> int:
def element_size(self) -> int: # type: ignore[empty-body]
...
def is_shared(self) -> bool:
def is_shared(self) -> bool: # type: ignore[empty-body]
...
def share_memory_(self) -> 'Storage':
def share_memory_(self) -> 'Storage': # type: ignore[empty-body]
...
def nbytes(self) -> int:
def nbytes(self) -> int: # type: ignore[empty-body]
...
def cpu(self) -> 'Storage':
def cpu(self) -> 'Storage': # type: ignore[empty-body]
...
def data_ptr(self) -> int:
def data_ptr(self) -> int: # type: ignore[empty-body]
...
def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage':
def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage': # type: ignore[empty-body]
...
def _new_with_file(self, f: Any, element_size: int) -> 'Storage':
def _new_with_file(self, f: Any, element_size: int) -> 'Storage': # type: ignore[empty-body]
...
...

View File

@ -100,7 +100,7 @@ def report_compile_source_on_error():
# specifically _PyCode_InitAddressRange, reveals that
# this iterator is initialized from co_linetable and
# co_firstfileno. So copy these we must!
code = code.replace(
code = code.replace( # type: ignore[call-arg]
co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined]
co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined]
)

View File

@ -1066,7 +1066,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue
self._data_queue = self._worker_result_queue # type: ignore[assignment]
# In some rare cases, persistent workers (daemonic processes)
# would be terminated before `__del__` of iterator is invoked

View File

@ -415,7 +415,7 @@ class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
@functional_datapipe('trace_as_dataframe')
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc]
source_datapipe = None
# TODO(VitalyFedyunin): Must implement all special functions of datapipes

View File

@ -376,7 +376,7 @@ class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataP
self._datapipe_iter = iter(self._datapipe)
return self
def __next__(self) -> T_co:
def __next__(self) -> T_co: # type: ignore[type-var]
assert self._datapipe_iter is not None
return next(self._datapipe_iter)

View File

@ -421,7 +421,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
self.main_datapipe_exhausted = False
self._child_stop: List[bool] = [True for _ in range(num_instances)]
def _find_next(self, instance_id: int) -> T_co:
def _find_next(self, instance_id: int) -> T_co: # type: ignore[type-var]
while True:
if self.main_datapipe_exhausted or self._child_stop[instance_id]:
raise StopIteration

View File

@ -40,7 +40,7 @@ class ConcaterMapDataPipe(MapDataPipe):
raise TypeError("Expected all inputs to be `Sized`")
self.datapipes = datapipes # type: ignore[assignment]
def __getitem__(self, index) -> T_co:
def __getitem__(self, index) -> T_co: # type: ignore[type-var]
offset = 0
for dp in self.datapipes:
if index - offset < len(dp):

View File

@ -418,5 +418,5 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[call-overload]
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload]
return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]

View File

@ -236,7 +236,7 @@ class FlopCounterMode(TorchDispatchMode):
mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
depth: int = 2,
display: bool = True,
custom_mapping: Dict[Any, Any] = None):
custom_mapping: Optional[Dict[Any, Any]] = None):
self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
self.depth = depth
self.parents = ["Global"]

View File

@ -1,6 +1,6 @@
import gc
import sys
from typing import NamedTuple, Tuple, List, Optional
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import types
import weakref
import json
@ -100,7 +100,7 @@ def annotated_references(obj):
need for a list. Descriptions are currently strings.
"""
references = {}
references: Dict[int, List[str]] = {}
def add_reference(name, obj):
references.setdefault(id(obj), []).append(name)
@ -272,7 +272,7 @@ def create_graph(objects, *, context=None, filter=None):
filter = is_cuda_tensor
nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects]
node_referrers = [[] for obj in objects]
node_referrers: List[List[int]] = [[] for obj in objects]
id_to_node = {id(obj): i for i, obj in enumerate(objects)}
for obj in objects:
@ -299,8 +299,8 @@ def create_graph(objects, *, context=None, filter=None):
to_keep.add(idx)
referrers = node_referrers[idx]
to_search.extend(referrers)
id_to_filtered_id = {}
filtered = []
id_to_filtered_id: Dict[int, int] = {}
filtered: List[Any] = []
for i, n in enumerate(nodes):
if i in to_keep:
id_to_filtered_id[i] = len(id_to_filtered_id)

View File

@ -50,7 +50,7 @@ class WeakIdRef(weakref.ref):
# cache the id of the key as we know this is definitely the hash
# method
self._id = id(key)
super().__init__(key, callback)
super().__init__(key, callback) # type: ignore[call-arg]
def __call__(self):
r = super().__call__()

View File

@ -2,12 +2,12 @@
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
try:
from yaml import CSafeDumper as Dumper
except ImportError:
from yaml import SafeDumper as Dumper # type: ignore[misc]
from yaml import SafeDumper as Dumper # type: ignore[assignment, misc]
YamlDumper = Dumper