Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP 484 violations (i.e. cases where a default argument is set to None but the type is not annotated as Optional), plus a few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export.deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- TODO (in a followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
Committed by: PyTorch MergeBot
Parent: f73757d551
Commit: 634659e262
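The bulk of the diff below applies two mechanical fixes that mypy 1.4.1 now reports: `None` defaults get an explicit `Optional[...]` annotation (the PEP 484 rule mentioned above), and existing `# type: ignore[...]` comments gain extra error codes such as `assignment`. A minimal before/after sketch; the function and parameter names here are illustrative, not taken from the diff:

    from typing import Callable, Optional, Sequence

    # Rejected by the PEP 484 rule: a None default without an Optional annotation.
    # def prepare(args: Sequence[str] = None, hook: Callable = None) -> None: ...

    # The spelling used throughout this commit instead:
    def prepare(args: Optional[Sequence[str]] = None,
                hook: Optional[Callable] = None) -> None:
        # Hypothetical helper: only the annotations matter for this example.
        if args is None:
            args = []
        if hook is not None:
            hook()

    # The recurring yaml-loader fallback: the newer mypy flags the re-import as an
    # [assignment] error in addition to [misc], so the ignore comment gains a code.
    try:
        from yaml import CSafeLoader as Loader
    except ImportError:
        from yaml import SafeLoader as Loader  # type: ignore[assignment, misc]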
@@ -75,10 +75,10 @@ librosa>=0.6.2 ; python_version < "3.11"
 #Pinned versions:
 #test that import:

-mypy==0.960
+mypy==1.4.1
 # Pin MyPy version because new errors are likely to appear with each release
 #Description: linter
-#Pinned versions: 0.960
+#Pinned versions: 1.4.1
 #test that import: test_typing.py, test_type_hints.py

 networkx==2.8.8

@@ -152,7 +152,7 @@ init_command = [
 '--dry-run={{DRYRUN}}',
 'numpy==1.24.3',
 'expecttest==0.1.3',
-'mypy==0.960',
+'mypy==1.4.1',
 'types-requests==2.27.25',
 'types-PyYAML==6.0.7',
 'types-tabulate==0.8.8',

@@ -49,7 +49,7 @@ try:
 # use faster C loader if available
 from yaml import CSafeLoader as Loader
 except ImportError:
-from yaml import SafeLoader as Loader # type: ignore[misc]
+from yaml import SafeLoader as Loader # type: ignore[assignment, misc]


 def write(filename, s):

@@ -106,7 +106,7 @@ def graph_def_to_event(step, graph_def):
 wall_time=step, step=step, graph_def=graph_def.SerializeToString())


-@cli.command("tensorboard-graphs")
+@cli.command("tensorboard-graphs") # type: ignore[arg-type, attr-defined]
 @click.option("--c2-netdef", type=click.Path(exists=True, dir_okay=False),
 multiple=True)
 @click.option("--tf-dir", type=click.Path(exists=True))

@@ -129,7 +129,7 @@ def tensorboard_graphs(c2_netdef, tf_dir):
 log.info("Wrote %s graphs to logdir %s", len(events), tf_dir)


-@cli.command("tensorboard-events")
+@cli.command("tensorboard-events") # type: ignore[arg-type, attr-defined]
 @click.option("--c2-dir", type=click.Path(exists=True, file_okay=False),
 help="Root directory of the Caffe2 run")
 @click.option("--tf-dir", type=click.Path(writable=True),

@@ -209,4 +209,4 @@ def tensorboard_events(c2_dir, tf_dir):


 if __name__ == "__main__":
-cli()
+cli() # type: ignore[misc]
@@ -21,7 +21,7 @@ class TestFuture(TestCase):
 error_msg = "Intentional Value Error"
 value_error = ValueError(error_msg)

-f = Future[T]()
+f = Future[T]() # type: ignore[valid-type]
 # Set exception
 f.set_exception(value_error)
 # Exception should throw on wait

@@ -29,7 +29,7 @@ class TestFuture(TestCase):
 f.wait()

 # Exception should also throw on value
-f = Future[T]()
+f = Future[T]() # type: ignore[valid-type]
 f.set_exception(value_error)
 with self.assertRaisesRegex(ValueError, "Intentional"):
 f.value()

@@ -37,7 +37,7 @@ class TestFuture(TestCase):
 def cb(fut):
 fut.value()

-f = Future[T]()
+f = Future[T]() # type: ignore[valid-type]
 f.set_exception(value_error)

 with self.assertRaisesRegex(RuntimeError, "Got the following error"):

@@ -54,7 +54,7 @@ class TestFuture(TestCase):
 with self.assertRaisesRegex(ValueError, "Intentional"):
 f.wait()

-f = Future[T]()
+f = Future[T]() # type: ignore[valid-type]
 t = threading.Thread(target=wait_future, args=(f, ))
 t.start()
 f.set_exception(value_error)

@@ -68,7 +68,7 @@ class TestFuture(TestCase):
 with self.assertRaisesRegex(RuntimeError, "Got the following error"):
 fut.wait()

-f = Future[T]()
+f = Future[T]() # type: ignore[valid-type]
 t = threading.Thread(target=then_future, args=(f, ))
 t.start()
 f.set_exception(value_error)
@@ -29,7 +29,7 @@ from itertools import product, combinations, permutations
 from functools import partial
 from torch import multiprocessing as mp
 from torch.testing import make_tensor
-from torch.testing._internal.common_utils import (
+from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
 TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
 IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
 IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest,

@@ -8469,7 +8469,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
 except RuntimeError:
 pass
 with mp.Pool(1) as pool:
-out: list = pool.map(method, [arg])
+out = pool.map(method, [arg])
 self.assertTrue(out[0])

 def _test_multinomial_invalid_probs(probs):
@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 import time

-from package.oss.cov_json import get_json_report
-from package.oss.init import initialization
-from package.tool.summarize_jsons import summarize_jsons
-from package.util.setting import TestPlatform
-from package.util.utils import print_time
+from package.oss.cov_json import get_json_report # type: ignore[import]
+from package.oss.init import initialization # type: ignore[import]
+from package.tool.summarize_jsons import summarize_jsons # type: ignore[import]
+from package.util.setting import TestPlatform # type: ignore[import]
+from package.util.utils import print_time # type: ignore[import]


 def report_coverage() -> None:

@@ -45,7 +45,7 @@ def transform_file_name(
 return file_path[file_path.find(folder) :]
 # remove pytorch base folder path
 if platform == TestPlatform.OSS:
-from package.oss.utils import get_pytorch_folder
+from package.oss.utils import get_pytorch_folder # type: ignore[import]

 pytorch_foler = get_pytorch_folder()
 assert file_path.startswith(pytorch_foler)

@@ -89,7 +89,9 @@ def get_raw_profiles_folder() -> str:

 def detect_compiler_type(platform: TestPlatform) -> CompilerType:
 if platform == TestPlatform.OSS:
-from package.oss.utils import detect_compiler_type # type: ignore[misc]
+from package.oss.utils import ( # type: ignore[assignment, import, misc]
+detect_compiler_type,
+)

 cov_type = detect_compiler_type() # type: ignore[call-arg]
 else:

@@ -100,7 +102,7 @@ def detect_compiler_type(platform: TestPlatform) -> CompilerType:
 cov_type = detect_compiler_type()

 check_compiler_type(cov_type)
-return cov_type
+return cov_type # type: ignore[no-any-return]


 def get_test_name_from_whole_path(path: str) -> str:
@@ -20,7 +20,7 @@ from yaml.nodes import MappingNode
 try:
 from yaml import CLoader as Loader
 except ImportError:
-from yaml import Loader # type: ignore[misc]
+from yaml import Loader # type: ignore[assignment, misc]

 H_NAME = "spv.h"
 CPP_NAME = "spv.cpp"

@@ -16,7 +16,7 @@ from yaml import dump, load
 try:
 from yaml import CSafeLoader as Loader
 except ImportError:
-from yaml import SafeLoader as Loader # type: ignore[misc]
+from yaml import SafeLoader as Loader # type: ignore[assignment, misc]


 class LintSeverity(str, Enum):

@@ -11,7 +11,7 @@ from torchgen.selective_build.selector import SelectiveBuilder
 try:
 from yaml import CSafeLoader as Loader
 except ImportError:
-from yaml import SafeLoader as Loader # type: ignore[misc]
+from yaml import SafeLoader as Loader # type: ignore[assignment, misc]


 if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {

@@ -10,7 +10,7 @@ try:
 # use faster C loader if available
 from yaml import CSafeLoader as YamlLoader
 except ImportError:
-from yaml import SafeLoader as YamlLoader # type: ignore[misc]
+from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc]

 NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
 TAGS_PATH = "aten/src/ATen/native/tags.yaml"
@@ -429,6 +429,7 @@ def _jit_is_script_object(obj: Any) -> _bool: ...
 def _last_executed_optimized_graph() -> Graph: ...
 def parse_type_comment(comment: str) -> Decl: ...
 def _get_upgraders_map_size() -> _int: ...
+def _get_upgraders_entry_map() -> Dict[str, str]: ...
 def _dump_upgraders_map() -> Dict[str, str]: ...
 def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
 def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...

@@ -417,7 +417,7 @@ def sym_int(a):
 if isinstance(a, SymInt):
 return a
 elif isinstance(a, SymFloat):
-return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type]
+return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type, call-overload]
 return py_int(a) # type: ignore[operator]

 def sym_max(a, b):

@@ -1316,7 +1316,7 @@ if TYPE_CHECKING:
 # Some type signatures pulled in from _VariableFunctions here clash with
 # signatures already imported. For now these clashes are ignored; see
 # PR #43339 for details.
-from torch._C._VariableFunctions import * # type: ignore[misc] # noqa: F403
+from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
 # Fixup segment_reduce visibility
 _segment_reduce = segment_reduce
 del segment_reduce
@@ -10,20 +10,20 @@ class _Union:
 @classmethod
 def create(cls, **kwargs):
 assert len(kwargs) == 1
-return cls(**{**{f.name: None for f in fields(cls)}, **kwargs})
+return cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type]

 def __post_init__(self):
-assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1
+assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1 # type: ignore[arg-type, misc]

 @property
 def value(self):
-val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None)
+val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
 assert val is not None
 return val

 @property
 def type(self):
-val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None)
+val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
 assert val_type is not None
 return val_type

@@ -9,7 +9,7 @@ import typing
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

 import sympy


@@ -749,7 +749,7 @@ class GraphModuleDeserializer:
 self.module = torch.nn.Module()

 @contextmanager
-def save_graph_module(self) -> None:
+def save_graph_module(self) -> Iterator[None]:
 saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta
 self.graph = torch.fx.Graph()
 self.module = torch.nn.Module()

@@ -773,7 +773,7 @@ class GraphModuleDeserializer:

 if vr := self.symbol_name_to_range.get(val.expr_str):
 symbolic_shapes._constrain_symbol_range(
-self.shape_env, sym, vr.lower, vr.upper
+self.shape_env, sym, vr.lower, vr.upper # type: ignore[arg-type]
 )

 return self.shape_env.create_symintnode(sym, hint=val.hint)

@@ -855,6 +855,7 @@ class GraphModuleDeserializer:
 output_node.meta["val"] = tuple(
 arg.meta["val"] for arg in output_node.args[0]
 )
+return output_node

 def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
 if target.__module__ == "_operator":  # TODO(zhxchen17) Follow up on this.

@@ -1050,7 +1051,7 @@ class GraphModuleDeserializer:
 self.serialized_name_to_node[fx_node.name] = fx_node

 def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
-ret = {}
+ret: Dict[str, Any] = {}
 if stack_trace := metadata.get("stack_trace"):
 ret["stack_trace"] = stack_trace


@@ -33,7 +33,7 @@ def get_upgraders() -> Dict[str, Tuple[str, str]]:
 """Getting upgraders entry map and operator version map and merge them into one dict."""
 upgraders = torch._C._get_upgraders_entry_map()
 op_version_map = torch._C._get_operator_version_map()
-output = defaultdict(tuple)
+output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type]
 for opname, entry_list in op_version_map.items():
 if not entry_list:
 raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
@@ -27,7 +27,7 @@ def load_tensor_reader(loc):

 def register_debug_prims():
 @custom_op("debugprims::load_tensor")
-def load_tensor(
+def load_tensor( # type: ignore[empty-body]
 name: str,
 size: Sequence[int],
 stride: Sequence[int],

@@ -10,7 +10,7 @@ from torch._prims_common import (
 import torch._prims_common as utils
 from torch.utils._pytree import tree_flatten, tree_unflatten

-from typing import Callable, Sequence, Tuple, NamedTuple, overload
+from typing import Callable, Sequence, Tuple, NamedTuple, Optional, overload
 import inspect
 from functools import wraps
 import warnings

@@ -97,7 +97,7 @@ class elementwise_type_promotion_wrapper:
 self,
 *,
 type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
-type_promoting_args: Sequence[str] = None,
+type_promoting_args: Optional[Sequence[str]] = None,
 ):
 self.type_promoting_arg_names = type_promoting_args
 self.type_promotion_kind = type_promotion_kind
@@ -856,7 +856,7 @@ def prepare_n_shadows_model(
 create_n_transformed_and_logged_copies_of_subgraph(
 mt, subgraph_idx, match_name, nodes_in_this_subgraph,
 qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig,
-custom_prepare_fn, custom_prepare_kwargs
+custom_prepare_fn, custom_prepare_kwargs # type: ignore[arg-type]
 )

 return mt

@@ -443,7 +443,7 @@ def create_one_transformed_and_logged_copy_of_subgraph(
 example_inputs: Any,
 last_added_shadow_node_list: List[Optional[Node]],
 custom_prepare_fn: Optional[Callable] = None,
-custom_prepare_kwargs: Dict[str, Any] = None,
+custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
 ) -> None:
 """
 Given a subgraph in `mt` and a subgraph candidate idx, inserts the

@@ -575,7 +575,7 @@ def create_n_transformed_and_logged_copies_of_subgraph(
 qconfig_mappings: List[QConfigMapping],
 list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
 custom_prepare_fn: Optional[Callable] = None,
-custom_prepare_kwargs: Dict[str, Any] = None,
+custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
 ) -> None:
 """
 Given a model `mt` and a subgraph_idx, creates the needed copies

@@ -756,7 +756,7 @@ def create_add_loggers_graph(
 create_n_transformed_and_logged_copies_of_subgraph(
 model, cur_subgraph_idx, match_name, maybe_subgraph,
 [qconfig_mapping], [node_name_to_qconfig],
-None, None
+None, None # type: ignore[arg-type]
 )
 # find the created shadow module and record it so we
 # can find it easily in step 2

@@ -78,7 +78,7 @@ def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
 res.append(param_value)
 return res
 else:
-assert isinstance(mod, nnqd.LSTM), f"type {type(res)} not handled yet"
+assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
 res = []
 for weight_value in mod._all_weight_values:
 res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])

@@ -3,7 +3,7 @@ import torch
 from dlrm_s_pytorch import unpack_batch # type: ignore[import]
 import numpy as np # type: ignore[import]
 import sklearn # type: ignore[import]
-from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model
+from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model # type: ignore[import]
 import pandas as pd # type: ignore[import]
 import argparse

@@ -458,8 +458,10 @@ def _update_special_qspecs_after_replacement(
 if isinstance(edge_or_node, Node):
 _node = edge_or_node
 return original_to_replacement_node.get(_node, _node)
-elif isinstance(edge_or_node, Tuple[Node, Node]):
-src, dest = edge_or_node
+# TODO: It's really should be
+# isinstance(edge_or_node, tuple) and len(edge_or_node) == 2 and all(isinstance(x, Node) for x in edge_or_node)
+elif isinstance(edge_or_node, Tuple[Node, Node]): # type: ignore[arg-type]
+src, dest = edge_or_node # type: ignore[misc]
 return (
 original_to_replacement_node.get(src, src),
 original_to_replacement_node.get(dest, dest),

@@ -58,13 +58,13 @@ def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
 for conv_op, add_op, relu_op in conv_add_relu_options:
 if add_op is None:
 # Append Conv ReLU
-supported_operators["conv2d"].append([conv_op, relu_op])
+supported_operators["conv2d"].append([conv_op, relu_op]) # type: ignore[list-item]
 elif relu_op is None:
 # Append Conv Add
-supported_operators["conv2d"].append([conv_op, add_op])
+supported_operators["conv2d"].append([conv_op, add_op]) # type: ignore[list-item]
 else:
 # Append Conv Add ReLU
-supported_operators["conv2d"].append([conv_op, add_op, relu_op])
+supported_operators["conv2d"].append([conv_op, add_op, relu_op]) # type: ignore[list-item]

 return copy.deepcopy(supported_operators)


@@ -222,17 +222,17 @@ class X86InductorQuantizer(Quantizer):
 """
 conv_gemm_node_idx = None
 extra_input_node_idx = None
-if (binary_node.args[0].op == "call_function") and (
+if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr]
 binary_node.args[0] == conv_gemm_node
 ):
 conv_gemm_node_idx = 0
 extra_input_node_idx = 1
-elif (binary_node.args[1].op == "call_function") and (
+elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr]
 binary_node.args[1] == conv_gemm_node
 ):
 conv_gemm_node_idx = 1
 extra_input_node_idx = 0
-extra_input_node = binary_node.args[extra_input_node_idx]
+extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index]
 assert isinstance(extra_input_node, Node)
 return conv_gemm_node_idx, extra_input_node_idx


@@ -119,9 +119,9 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
 def reference_representation_rewrite(model: GraphModule) -> GraphModule:
 _remove_tensor_overload_for_qdq_ops(model)
 for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
-pattern = _get_aten_graph_module(pattern, example_inputs)
-_remove_tensor_overload_for_qdq_ops(pattern)
-replacement = _get_aten_graph_module(replacement, example_inputs)
-_remove_tensor_overload_for_qdq_ops(replacement)
+pattern = _get_aten_graph_module(pattern, example_inputs) # type: ignore[arg-type, assignment]
+_remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type]
+replacement = _get_aten_graph_module(replacement, example_inputs) # type: ignore[arg-type, assignment]
+_remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type]
 matches = replace_pattern(model, pattern, replacement)
 return model
@@ -44,7 +44,7 @@ class APoTQuantizer():

 from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT

-result = TensorAPoT(self, tensor2quantize)
+result = TensorAPoT(self, tensor2quantize) # type: ignore[assignment]

 return result


@@ -83,7 +83,7 @@ class APoTQuantizer():
 def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
 levels_lst = list(self.quantization_levels)

-result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
+result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst)) # type: ignore[call-arg]

 return result


@@ -85,9 +85,9 @@ def _find_matches(
 modules: Dict[str, torch.nn.Module],
 patterns: Dict[Pattern, QuantizeHandler],
 root_node_getter_mapping: Dict[Pattern, Callable],
-standalone_module_names: List[str] = None,
-standalone_module_classes: List[Type] = None,
-custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
+standalone_module_names: Optional[List[str]] = None,
+standalone_module_classes: Optional[List[Type]] = None,
+custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
 """
 Matches the nodes in the input graph to quantization patterns, and
 outputs the information needed to quantize them in future steps.

@@ -18,7 +18,7 @@ from torch.ao.quantization.utils import (
 )

 from abc import ABC
-from typing import Callable, Dict, List, Type
+from typing import Callable, Dict, List, Type, Optional

 __all__ = [
 "QuantizeHandler",

@@ -52,7 +52,7 @@ class QuantizeHandler(ABC):
 self,
 node_pattern: NodePattern,
 modules: Dict[str, torch.nn.Module],
-root_node_getter: Callable = None,
+root_node_getter: Optional[Callable] = None,
 is_custom_module=False,
 is_standalone_module=False):
 """ Records pattern information in __init__, which will be used

@@ -113,7 +113,7 @@ def _get_quantize_handler_cls(
 self,
 node_pattern: NodePattern,
 modules: Dict[str, torch.nn.Module],
-root_node_getter: Callable = None):
+root_node_getter: Optional[Callable] = None):
 super().__init__(node_pattern, modules, root_node_getter)
 if num_tensor_args_to_observation_type:
 assert self.num_tensor_args in num_tensor_args_to_observation_type, \
@@ -455,7 +455,7 @@ def _profile_to_snapshot(profile):
 device = to_device(tensor_key.device)
 addr = tensor_key.storage.ptr

-seg = snapshot['segments'][device]
+seg = snapshot['segments'][device] # type: ignore[index]
 if seg['address'] is None or seg['address'] > addr:
 seg['address'] = addr
 seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later

@@ -465,12 +465,12 @@ def _profile_to_snapshot(profile):
 stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
 r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
 if during_trace:
-snapshot['device_traces'][device].append(r)
+snapshot['device_traces'][device].append(r) # type: ignore[index]
 return r

 def free(alloc, device):
 for e in ('free_requested', 'free_completed'):
-snapshot['device_traces'][device].append({'action': e,
+snapshot['device_traces'][device].append({'action': e, # type: ignore[index]
 'addr': alloc['addr'],
 'size': alloc['size'],
 'stream': 0,

@@ -499,7 +499,7 @@ def _profile_to_snapshot(profile):
 blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
 for (tensor_key, version), event in kv_to_elem.items()]
 for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]):
-seg = snapshot['segments'][device]
+seg = snapshot['segments'][device] # type: ignore[index]
 last_addr = seg['address']
 for _, addr, size, frames in blocks:
 if last_addr < addr:

@@ -510,8 +510,8 @@ def _profile_to_snapshot(profile):
 if last_addr < seg['total_size']:
 seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})

-snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]
-for seg in snapshot['segments']:
+snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined]
+for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef]
 seg['total_size'] -= seg['address']
 if not seg['blocks']:
 seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})
@@ -1,5 +1,5 @@
 import torch
-from typing import cast, Iterable, List, Union
+from typing import Iterable, List, Union
 from . import _lazy_init, _lazy_call, device_count, current_device
 from .. import Tensor


@@ -56,7 +56,7 @@ def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cu
 device = torch.device('cuda', device)

 def cb():
-idx = cast(torch.device, device).index
+idx = device.index
 if idx is None:
 idx = current_device()
 default_generator = torch.cuda.default_generators[idx]

@@ -732,7 +732,7 @@ class ShardedTensor(ShardedTensorBase):
 local_tensor: torch.Tensor,
 sharding_spec: shard_spec.ShardingSpec,
 *global_size: Sequence[int],
-process_group: dist.ProcessGroup = None,
+process_group: Optional[dist.ProcessGroup] = None,
 init_rrefs=False,
 ) -> "ShardedTensor":
 """
@@ -389,7 +389,7 @@ def _compile(
 # can trace operations applied to them.
 def stateless_func(func, params, buffers, named_states, args, kwargs):
 with stateless._reparametrize_module(
-cast(nn.Module, mod), {**params, **buffers}
+mod, {**params, **buffers}
 ), _rematerialize_optimizer(
 opt, named_states, params
 ) if opt else nullcontext():

@@ -133,7 +133,7 @@ def gen_model_param_in_submesh(model: nn.Module, sub_mesh: DeviceMesh) -> nn.Mod
 )


-def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module:
+def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module: # type: ignore[empty-body]
 """
 checkpoint save/load models with DTensor parameters
 """

@@ -40,7 +40,7 @@ def register_op_strategy(op):
 def as_list(
 x: Union[List[object], object]
 # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
-) -> Union[List[object], torch.fx.immutable_collections.immutable_list]:
+) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
 # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
 # which is an object but treated as a list by the tracer. Therefore, keep
 # `immutable_list` intact here as well.

@@ -160,7 +160,7 @@ def all_gather(
 ret_list.append(sp.payload)

 if len(exception_list) > 0:
-raise RuntimeError(
+raise RuntimeError( # type: ignore[misc]
 error_msg, exception_list) from exception_list[0]
 return ret_list
 else:

@@ -168,7 +168,7 @@ def all_gather(
 raise RuntimeError(
 f"all_gather failed with exception {sync_obj.exception}",
 ) from sync_obj.exception
-return [sync_obj.payload]
+return [sync_obj.payload] # type: ignore[list-item]


 # Note: use Any for typing for now so users can pass in
@@ -1433,7 +1433,7 @@ def _register_post_backward_reshard_only_hooks(
 hook_handle = register_multi_grad_hook(
 inp_tensors, functools.partial(_post_backward_reshard, state, handle)
 )
-handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined]
+handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment]


 @no_type_check

@@ -1468,11 +1468,11 @@ def _wait_for_computation_stream(
 For example, this should be called in the FSDP root's pre-forward to
 respect optimizer step computation.
 """
-unshard_stream.wait_stream(computation_stream)
+unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
 # Having the pre-all-gather stream wait for the current stream even if we
 # do not leverage the pre-all-gather stream is tolerable since this only
 # runs once per iteration
-pre_unshard_stream.wait_stream(computation_stream)
+pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]


 def _reset_flat_param_grad_info_if_needed(

@@ -167,7 +167,7 @@ class _ExecOrderTracer:
 kwargs: Dict[str, Any],
 name: Optional[str] = None,
 type_expr: Optional[Any] = None,
-proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None,
+proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
 ) -> torch.fx.Proxy:
 """
 Overrides ``create_proxy`` to save execution information to

@@ -96,7 +96,7 @@ def _auto_wrap(
 )
 recursive_wrap_kwargs["auto_wrap_policy"] = policy
 _warn_on_overridden_mixed_precision(overridden_module_classes)
-_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs)
+_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs) # type: ignore[arg-type]


 def _check_nested_wrapping(root_module: nn.Module):
@@ -358,7 +358,7 @@ class _RemoteModule(nn.Module):
 def bfloat16(self: T) -> T: # type: ignore[return]
 _raise_not_supported(self.bfloat16.__name__)

-def to(self, *args, **kwargs) -> T: # type: ignore[return]
+def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var]
 _raise_not_supported(self.to.__name__)

 def register_backward_hook( # type: ignore[return]

@@ -377,7 +377,7 @@ class _RemoteModule(nn.Module):
 ) -> RemovableHandle:
 _raise_not_supported(self.register_forward_pre_hook.__name__)

-def register_forward_hook( # type: ignore[return]
+def register_forward_hook( # type: ignore[return, override]
 self,
 hook: Union[
 Callable[[T, Tuple[Any, ...], Any], Optional[Any]],

@@ -162,7 +162,7 @@ class _NamedOptimizer(optim.Optimizer):
 return self._optimizer.step(closure=closure)

 @property
-def state(self) -> Mapping[torch.Tensor, Any]:
+def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override]
 return self._optimizer.state

 def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:

@@ -49,7 +49,7 @@ def input_reshard(

 def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any:
 nonlocal cx
-cx.__exit__() # type: ignore[name-defined]
+cx.__exit__() # type: ignore[name-defined, union-attr]

 if input_reshard_dim is None:
 return module
@@ -23,7 +23,7 @@ class FoldedGraphModule(torch.fx.GraphModule):
 root: torch.nn.Module,
 graph: torch.fx.Graph,
 const_subgraph: Optional[torch.fx.Graph] = None,
-fx_const_folded_attrs_name: str = None,
+fx_const_folded_attrs_name: Optional[str] = None,
 device_for_folded_attrs: str = "cuda",
 ):
 super().__init__(root, graph)

@@ -1,4 +1,4 @@
-# type: ignore[attr-defined]
+# mypy: disable-error-code=attr-defined
 from .core import unify, reify  # noqa: F403
 from .more import unifiable  # noqa: F403
 from .variable import var, isvar, vars, variables, Var  # noqa: F403

@@ -202,7 +202,7 @@ def replace_pattern_with_filters(
 gm: GraphModule,
 pattern: Union[Callable, GraphModule],
 replacement: Union[Callable, GraphModule],
-match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
+match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
 ignore_literals: bool = False,
 ) -> List[ReplacedPatterns]:
 """

@@ -222,7 +222,7 @@ def _replace_pattern(
 gm: GraphModule,
 pattern: Union[Callable, GraphModule],
 replacement: Union[Callable, GraphModule],
-match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
+match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
 ignore_literals: bool = False,
 ) -> List[ReplacedPatterns]:

@@ -103,7 +103,7 @@ class Sequential(Module):
 for idx, module in enumerate(args):
 self.add_module(str(idx), module)

-def _get_item_by_idx(self, iterator, idx) -> T:
+def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
 """Get the idx-th item of the iterator"""
 size = len(self)
 idx = operator.index(idx)

@@ -51,7 +51,7 @@ class _ConvNd(Module):
 'out_channels', 'kernel_size']
 __annotations__ = {'bias': Optional[torch.Tensor]}

-def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
+def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body]
 ...

 in_channels: int
@@ -340,7 +340,7 @@ class ProtobufExportOutputSerializer:
 ) -> None:
 import onnx

-if not isinstance(export_output.model_proto, onnx.ModelProto):
+if not isinstance(export_output.model_proto, onnx.ModelProto): # type: ignore[attr-defined]
 raise ValueError("export_output.ModelProto is not an onnx.ModelProto")
 destination.write(export_output.model_proto.SerializeToString())


@@ -348,7 +348,7 @@ class ProtobufExportOutputSerializer:
 class ExportOutput:
 """An in-memory representation of a PyTorch model that has been exported to ONNX."""

-_model_proto: Final[onnx.ModelProto]
+_model_proto: Final[onnx.ModelProto] # type: ignore[name-defined]
 _input_adapter: Final[io_adapter.InputAdapter]
 _output_adapter: Final[io_adapter.OutputAdapter]
 _diagnostic_context: Final[infra.DiagnosticContext]

@@ -356,7 +356,7 @@ class ExportOutput:
 @_beartype.beartype
 def __init__(
 self,
-model_proto: onnx.ModelProto,
+model_proto: onnx.ModelProto, # type: ignore[name-defined]
 input_adapter: io_adapter.InputAdapter,
 output_adapter: io_adapter.OutputAdapter,
 diagnostic_context: infra.DiagnosticContext,

@@ -367,7 +367,7 @@ class ExportOutput:
 self._diagnostic_context = diagnostic_context

 @property
-def model_proto(self) -> onnx.ModelProto:
+def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined]
 """The exported ONNX model as an ``onnx.ModelProto``."""

 return self._model_proto

@@ -193,7 +193,7 @@ class DynamoExport(exporter.FXGraphExtractor):
 wrapped_model,
 *model_args,
 tracing_mode=fx_mode,
-fake_mode=fake_mode,
+fake_mode=fake_mode, # type: ignore[arg-type]
 **model_kwargs,
 )
 del graph_guard  # Unused

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 @_beartype.beartype
 def _create_tensor_proto_with_external_data(
 tensor: torch.Tensor, name: str, location: str, basepath: str
-) -> onnx.TensorProto:
+) -> onnx.TensorProto: # type: ignore[name-defined]
 """Create a TensorProto with external data from a PyTorch tensor.
 The external data is saved to os.path.join(basepath, location).


@@ -38,13 +38,13 @@ def _create_tensor_proto_with_external_data(
 # FIXME: Avoid importing onnx into torch.onnx.
 import onnx

-tensor_proto = onnx.TensorProto()
+tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
 tensor_proto.name = name
 tensor_proto.data_type = jit_type_utils.JitScalarType.from_dtype(
 tensor.dtype
 ).onnx_type()
 tensor_proto.dims.extend(tensor.shape)
-tensor_proto.data_location = onnx.TensorProto.EXTERNAL
+tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]

 # Settings for saving one tensor per file.
 # Offset is zero because there is no other tensor in the same file.

@@ -86,7 +86,7 @@ def save_model_with_external_data(
 model_location: str,
 initializer_location: str,
 torch_load_paths: Tuple[Union[str, io.BytesIO], ...],
-onnx_model: onnx.ModelProto,
+onnx_model: onnx.ModelProto, # type: ignore[name-defined]
 rename_initializer: bool = False,
 ) -> None:
 """Load PyTorch tensors from files and add to "onnx_model" as external initializers.

@@ -121,7 +121,7 @@ def save_model_with_external_data(
 # FIXME: Avoid importing onnx into torch.onnx.
 import onnx

-onnx_model_with_initializers = onnx.ModelProto()
+onnx_model_with_initializers = onnx.ModelProto() # type: ignore[attr-defined]
 onnx_model_with_initializers.CopyFrom(onnx_model)
 onnx_input_names = [input.name for input in onnx_model.graph.input]


@@ -159,4 +159,4 @@ def save_model_with_external_data(
 onnx_model_with_initializers.graph.initializer.append(tensor_proto)

 # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
-onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
+onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) # type: ignore[attr-defined]
@@ -62,7 +62,7 @@ def export_as_test_case(
 shutil.rmtree(data_set_dir)
 os.makedirs(data_set_dir)

-proto = onnx.load_model_from_string(model_bytes)
+proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]

 for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)):
 export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))

@@ -112,12 +112,12 @@ def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
 inputs = {}
 input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
 for input_file in input_files:
-tensor = onnx.load_tensor(input_file)
+tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined]
 inputs[tensor.name] = numpy_helper.to_array(tensor)
 outputs = {}
 output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
 for output_file in output_files:
-tensor = onnx.load_tensor(output_file)
+tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined]
 outputs[tensor.name] = numpy_helper.to_array(tensor)

 return model_bytes, inputs, outputs

@@ -227,7 +227,7 @@ def _add_onnxscript_fn(
 # size > 2GB, and if it for some reason did not, the model would fail on
 # serialization anyway in terms of the protobuf limitation. So we don't
 # need to worry about > 2GB model getting here.
-model_proto = onnx.load_model_from_string(model_bytes)
+model_proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]

 # Iterate graph nodes to insert only the included custom
 # function_proto into model_proto

@@ -206,14 +206,14 @@ def _ort_session(
 def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
 try:
 import onnx
-from onnx import reference as onnx_reference
+from onnx import reference as onnx_reference # type: ignore[attr-defined]
 except ImportError:
 raise ImportError("onnx >= 1.13 is required for reference evaluator.")

 proto = (
-onnx.load(model)
+onnx.load(model) # type: ignore[attr-defined]
 if isinstance(model, str)
-else onnx.load_model_from_string(model.getvalue())
+else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined]
 )
 onnx_session = onnx_reference.ReferenceEvaluator(proto)
 return onnx_session
@@ -225,9 +225,9 @@ class Optimizer:
 """

 OptimizerPreHook: TypeAlias = Callable[
-[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]  # type: ignore[valid-type]
+[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]  # type: ignore[misc, valid-type]
 ]
-OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None]  # type: ignore[valid-type]
+OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None]  # type: ignore[misc, valid-type]

 _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
 _optimizer_step_post_hooks: Dict[int, OptimizerPostHook]

@@ -603,7 +603,7 @@ def input_dtypes(event: _ProfilerEvent):
 def report_all_anti_patterns(prof,
 should_benchmark: bool = False,
 print_enable: bool = True,
-json_report_dir: str = None):
+json_report_dir: Optional[str] = None):
 report_dict: Dict = {}
 anti_patterns = [
 ExtraCUDACopyPattern(prof, should_benchmark),

@@ -202,7 +202,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
 self.compressed_tensor = compressed_tensor
 self.transposed = transposed

-def __repr__(self) -> str:
+def __repr__(self) -> str: # type: ignore[override]
 """Return string representation of SparseSemiStructuredTensor

 Returns:
@@ -3,7 +3,7 @@ import io
 import torch
 from ._utils import _type, _cuda, _hpu
 from torch.types import Storage
-from typing import Any, TypeVar, Type, Union, cast, Dict as _Dict
+from typing import cast, Any, Dict as _Dict, Optional as _Optional, TypeVar, Type, Union
 import copy
 import collections
 from functools import lru_cache

@@ -27,51 +27,51 @@ class _StorageBase:
 device: torch.device

 def __init__(self, *args, **kwargs): ... # noqa: E704
-def __len__(self) -> int: ... # noqa: E704
+def __len__(self) -> int: ... # type: ignore[empty-body] # noqa: E704
 def __getitem__(self, idx): ... # noqa: E704
 def __setitem__(self, *args, **kwargs): ... # noqa: E704
-def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704
-def new(self) -> T: ... # noqa: E704
-def nbytes(self) -> int: ... # noqa: E704
+def copy_(self, source: T, non_blocking: _Optional[bool] = None) -> T: ... # type: ignore[empty-body] # noqa: E704
+def new(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
+def nbytes(self) -> int: ... # type: ignore[empty-body] # noqa: E704

 def size(self) -> int:
 return self.nbytes()

-def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704
-def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
-def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
-def element_size(self) -> int: ... # noqa: E704
+def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
+def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
+def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
+def element_size(self) -> int: ... # type: ignore[empty-body, type-var] # noqa: E704

 def get_device(self) -> int:
 return self.device.index

-def data_ptr(self) -> int: ... # noqa: E704
+def data_ptr(self) -> int: ... # type: ignore[empty-body] # noqa: E704

 # Defined in torch/csrc/generic/StorageSharing.cpp
 def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
 def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
 @classmethod
-def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
+def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704
 @classmethod
-def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
+def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704
 @classmethod
-def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
+def from_buffer(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
 @classmethod
-def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
+def _new_shared_filename_cpu(cls: Type[T], manager, obj, size, *, device=None, dtype=None) -> T: ... # type: ignore[empty-body] # noqa: E704
 @classmethod
-def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
+def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
 @classmethod
-def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
-def _shared_decref(self) -> T: ... # noqa: E704
+def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
+def _shared_decref(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
 def _write_file(self, *args, **kwargs): ... # noqa: E704
 def resize_(self, size: int): ... # noqa: E704
-def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704
+def _weak_ref(self, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
 def _set_from_file(self, *args, **kwargs): ... # noqa: E704
 def _set_cdata(self, *args, **kwargs): ... # noqa: E704
 def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
-def is_shared(self) -> bool: ... # noqa: E704
+def is_shared(self) -> bool: ... # type: ignore[empty-body] # noqa: E704
 @classmethod
-def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
+def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
 def _shared_incref(self, *args, **kwargs): ... # noqa: E704
 @classmethod
 def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704
@@ -80,9 +80,9 @@ class _StorageBase:
 @property
 def is_hpu(self): ... # noqa: E704
 @classmethod
-def from_file(cls, filename, shared, nbytes) -> T: ... # noqa: E704
+def from_file(cls, filename, shared, nbytes) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
 @classmethod
-def _expired(cls, *args, **kwargs) -> T: ... # noqa: E704
+def _expired(cls, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
 def _byteswap(self, *args, **kwargs): ... # noqa: E704

 def __str__(self):

@@ -712,12 +712,12 @@ class TypedStorage:
 tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
 return tmp_tensor[idx_wrapped].item()

-def copy_(self, source: T, non_blocking: bool = None):
+def copy_(self, source: T, non_blocking: _Optional[bool] = None):
 _warn_typed_storage_removal()
 if isinstance(source, TypedStorage):
-self._untyped_storage.copy_(source._untyped_storage, non_blocking)
+self._untyped_storage.copy_(source._untyped_storage, non_blocking) # type: ignore[arg-type]
 else:
-self._untyped_storage.copy_(source, non_blocking)
+self._untyped_storage.copy_(source, non_blocking) # type: ignore[arg-type]
 return self

 def nbytes(self):

@@ -728,7 +728,7 @@ class TypedStorage:
 def _nbytes(self):
 return self._untyped_storage.nbytes()

-def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
+def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> Union[T, str]:
 _warn_typed_storage_removal()
 if dtype is None:
 legacy_class = self._get_legacy_storage_class()

@@ -741,14 +741,14 @@ class TypedStorage:
 else:
 return self._untyped_storage.type(dtype, non_blocking)

-def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
+def cuda(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var]
 _warn_typed_storage_removal()
 if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
 raise RuntimeError("Cannot create CUDA storage with quantized dtype")
 cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
 return self._new_wrapped_storage(cuda_storage)

-def hpu(self, device=None, non_blocking=False, **kwargs) -> T:
+def hpu(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var]
 _warn_typed_storage_removal()
 if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
 raise RuntimeError("Cannot create HPU storage with quantized dtype")
@@ -43,37 +43,37 @@ class Storage:
 dtype: torch.dtype
 _torch_load_uninitialized: bool

-def __deepcopy__(self, memo) -> 'Storage':
+def __deepcopy__(self, memo) -> 'Storage': # type: ignore[empty-body]
 ...

-def _new_shared(self, int) -> 'Storage':
+def _new_shared(self, int) -> 'Storage': # type: ignore[empty-body]
 ...

 def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None:
 ...

-def element_size(self) -> int:
+def element_size(self) -> int: # type: ignore[empty-body]
 ...

-def is_shared(self) -> bool:
+def is_shared(self) -> bool: # type: ignore[empty-body]
 ...

-def share_memory_(self) -> 'Storage':
+def share_memory_(self) -> 'Storage': # type: ignore[empty-body]
 ...

-def nbytes(self) -> int:
+def nbytes(self) -> int: # type: ignore[empty-body]
 ...

-def cpu(self) -> 'Storage':
+def cpu(self) -> 'Storage': # type: ignore[empty-body]
 ...

-def data_ptr(self) -> int:
+def data_ptr(self) -> int: # type: ignore[empty-body]
 ...

-def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage':
+def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage': # type: ignore[empty-body]
 ...

-def _new_with_file(self, f: Any, element_size: int) -> 'Storage':
+def _new_with_file(self, f: Any, element_size: int) -> 'Storage': # type: ignore[empty-body]
 ...

 ...
@@ -100,7 +100,7 @@ def report_compile_source_on_error():
 # specifically _PyCode_InitAddressRange, reveals that
 # this iterator is initialized from co_linetable and
 # co_firstfileno. So copy these we must!
-code = code.replace(
+code = code.replace( # type: ignore[call-arg]
 co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined]
 co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined]
 )

@@ -1066,7 +1066,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
 # pin_memory_thread once it is started.
 self._pin_memory_thread = pin_memory_thread
 else:
-self._data_queue = self._worker_result_queue
+self._data_queue = self._worker_result_queue # type: ignore[assignment]

 # In some rare cases, persistent workers (daemonic processes)
 # would be terminated before `__del__` of iterator is invoked

@@ -415,7 +415,7 @@ class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):


 @functional_datapipe('trace_as_dataframe')
-class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):
+class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc]
 source_datapipe = None

 # TODO(VitalyFedyunin): Must implement all special functions of datapipes

@@ -376,7 +376,7 @@ class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataP
 self._datapipe_iter = iter(self._datapipe)
 return self

-def __next__(self) -> T_co:
+def __next__(self) -> T_co: # type: ignore[type-var]
 assert self._datapipe_iter is not None
 return next(self._datapipe_iter)


@@ -421,7 +421,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
 self.main_datapipe_exhausted = False
 self._child_stop: List[bool] = [True for _ in range(num_instances)]

-def _find_next(self, instance_id: int) -> T_co:
+def _find_next(self, instance_id: int) -> T_co: # type: ignore[type-var]
 while True:
 if self.main_datapipe_exhausted or self._child_stop[instance_id]:
 raise StopIteration

@@ -40,7 +40,7 @@ class ConcaterMapDataPipe(MapDataPipe):
 raise TypeError("Expected all inputs to be `Sized`")
 self.datapipes = datapipes # type: ignore[assignment]

-def __getitem__(self, index) -> T_co:
+def __getitem__(self, index) -> T_co: # type: ignore[type-var]
 offset = 0
 for dp in self.datapipes:
 if index - offset < len(dp):

@@ -418,5 +418,5 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
 if sum(lengths) != len(dataset):    # type: ignore[arg-type]
 raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

-indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
+indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
 return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
@@ -236,7 +236,7 @@ class FlopCounterMode(TorchDispatchMode):
 mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
 depth: int = 2,
 display: bool = True,
-custom_mapping: Dict[Any, Any] = None):
+custom_mapping: Optional[Dict[Any, Any]] = None):
 self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
 self.depth = depth
 self.parents = ["Global"]

@@ -1,6 +1,6 @@
 import gc
 import sys
-from typing import NamedTuple, Tuple, List, Optional
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple
 import types
 import weakref
 import json

@@ -100,7 +100,7 @@ def annotated_references(obj):
 need for a list.  Descriptions are currently strings.

 """
-references = {}
+references: Dict[int, List[str]] = {}

 def add_reference(name, obj):
 references.setdefault(id(obj), []).append(name)

@@ -272,7 +272,7 @@ def create_graph(objects, *, context=None, filter=None):
 filter = is_cuda_tensor

 nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects]
-node_referrers = [[] for obj in objects]
+node_referrers: List[List[int]] = [[] for obj in objects]

 id_to_node = {id(obj): i for i, obj in enumerate(objects)}
 for obj in objects:

@@ -299,8 +299,8 @@ def create_graph(objects, *, context=None, filter=None):
 to_keep.add(idx)
 referrers = node_referrers[idx]
 to_search.extend(referrers)
-id_to_filtered_id = {}
-filtered = []
+id_to_filtered_id: Dict[int, int] = {}
+filtered: List[Any] = []
 for i, n in enumerate(nodes):
 if i in to_keep:
 id_to_filtered_id[i] = len(id_to_filtered_id)

@@ -50,7 +50,7 @@ class WeakIdRef(weakref.ref):
 # cache the id of the key as we know this is definitely the hash
 # method
 self._id = id(key)
-super().__init__(key, callback)
+super().__init__(key, callback) # type: ignore[call-arg]

 def __call__(self):
 r = super().__call__()

@@ -2,12 +2,12 @@
 try:
 from yaml import CSafeLoader as Loader
 except ImportError:
-from yaml import SafeLoader as Loader # type: ignore[misc]
+from yaml import SafeLoader as Loader # type: ignore[assignment, misc]

 try:
 from yaml import CSafeDumper as Dumper
 except ImportError:
-from yaml import SafeDumper as Dumper # type: ignore[misc]
+from yaml import SafeDumper as Dumper # type: ignore[assignment, misc]
 YamlDumper = Dumper