[Reland] Update mypy to 1.4.1 (#105227)

This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

Both were reverted due to a conflict with the internal source repo.

Most of the changes are fixes for PEP 484 violations (i.e., cases where a default argument is set to None but the type is not annotated as Optional).
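
For illustration, a minimal before/after sketch of that pattern (the function and parameter names below are hypothetical; the real changes are spread across the diff):

```python
from typing import Any, Dict, Optional

# Before: flagged by mypy (implicit Optional / PEP 484 violation), since the
# default is None but the annotation does not say Optional.
def run_before(args: Any, kwargs: Dict[str, Any] = None) -> None:
    ...

# After: spell out Optional and handle the None default explicitly.
def run_after(args: Any, kwargs: Optional[Dict[str, Any]] = None) -> None:
    if kwargs is None:
        kwargs = {}
    ...
```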
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py` (see the sketch below)
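
For context, a minimal sketch of why that `isinstance` check is erroneous and what the intended check looks like (the `Node` class below is an illustrative stand-in for `torch.fx.Node`; the intended form mirrors the TODO comment this PR adds for the same pattern in the diff):

```python
class Node:  # illustrative stand-in for torch.fx.Node
    pass

edge = (Node(), Node())

# Erroneous: isinstance() with a subscripted generic such as Tuple[Node, Node]
# raises "TypeError: Subscripted generics cannot be used with class and
# instance checks" at runtime, so the branch can never work as intended.

# Intended check, per the TODO comment left in the diff below:
is_node_pair = (
    isinstance(edge, tuple)
    and len(edge) == 2
    and all(isinstance(x, Node) for x in edge)
)
assert is_node_pair
```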

Unrelated: to bypass CI failures caused by the gcc-9 dependency update on Ubuntu 18.04:
- Add a hack to `.ci/docker/install_conda.sh` that removes the older libstdc++ shipped in the conda environment in favor of the one from the OS
- Update Bazel CUDA builds to Focal, since with libstdc++ 6.0.32 the Bazel builds lose the ability to catch exceptions (probably because they link cupti statically, though I could not find where that happens)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
Author: Nikita Shulga
Date: 2023-07-15 20:30:20 +00:00
Committed by: PyTorch MergeBot
Parent: 5cd861fcf7
Commit: 5837e95d30
101 changed files with 289 additions and 258 deletions

View File

@ -160,6 +160,20 @@ case "$image" in
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9)
CUDA_VERSION=12.1.0
CUDNN_VERSION=8
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
PROTOBUF=yes
DB=yes
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
CONDA_CMAKE=yes
TRITON=yes
;;
pytorch-linux-focal-py3-clang7-asan)
ANACONDA_PYTHON_VERSION=3.9
CLANG_VERSION=7

View File

@ -105,5 +105,13 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
pip_install -r /opt/conda/requirements-docs.txt
fi
# HACK HACK HACK
# gcc-9 for ubuntu-18.04 from http://ppa.launchpad.net/ubuntu-toolchain-r/test/ubuntu
# Pulls llibstdc++6 13.1.0-8ubuntu1~18.04 which is too new for conda
# So remove libstdc++6.so.3.29 installed by https://anaconda.org/anaconda/libstdcxx-ng/files?version=11.2.0
if grep 18.04.6 /etc/issue >/dev/null; then
rm /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/lib/libstdc++.so.6
fi
popd
fi

View File

@ -75,10 +75,10 @@ librosa>=0.6.2 ; python_version < "3.11"
#Pinned versions:
#test that import:
mypy==0.960
mypy==1.4.1
# Pin MyPy version because new errors are likely to appear with each release
#Description: linter
#Pinned versions: 0.960
#Pinned versions: 1.4.1
#test that import: test_typing.py, test_type_hints.py
networkx==2.8.8

View File

@ -40,6 +40,7 @@ jobs:
- docker-image-name: pytorch-linux-bionic-cuda11.8-cudnn8-py3-gcc7-inductor-benchmarks
- docker-image-name: pytorch-linux-bionic-py3.8-clang9
- docker-image-name: pytorch-linux-bionic-py3.11-clang9
- docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9
- docker-image-name: pytorch-linux-focal-rocm-n-1-py3
- docker-image-name: pytorch-linux-focal-rocm-n-py3
- docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12

View File

@ -305,12 +305,12 @@ jobs:
{ config: "default", shard: 1, num_shards: 1, runner: "linux.4xlarge" },
]}
linux-bionic-cuda12_1-py3_10-gcc9-bazel-test:
name: linux-bionic-cuda12.1-py3.10-gcc9-bazel-test
linux-focal-cuda12_1-py3_10-gcc9-bazel-test:
name: linux-focal-cuda12.1-py3.10-gcc9-bazel-test
uses: ./.github/workflows/_bazel-build-test.yml
with:
build-environment: linux-bionic-cuda12.1-py3.10-gcc9-bazel-test
docker-image-name: pytorch-linux-bionic-cuda12.1-cudnn8-py3-gcc9
build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9
cuda-version: "12.1"
test-matrix: |
{ include: [

View File

@ -152,7 +152,7 @@ init_command = [
'--dry-run={{DRYRUN}}',
'numpy==1.24.3',
'expecttest==0.1.3',
'mypy==0.960',
'mypy==1.4.1',
'types-requests==2.27.25',
'types-PyYAML==6.0.7',
'types-tabulate==0.8.8',

View File

@ -49,7 +49,7 @@ try:
# use faster C loader if available
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
def write(filename, s):

View File

@ -106,7 +106,7 @@ def graph_def_to_event(step, graph_def):
wall_time=step, step=step, graph_def=graph_def.SerializeToString())
@cli.command("tensorboard-graphs")
@cli.command("tensorboard-graphs") # type: ignore[arg-type, attr-defined]
@click.option("--c2-netdef", type=click.Path(exists=True, dir_okay=False),
multiple=True)
@click.option("--tf-dir", type=click.Path(exists=True))
@ -129,7 +129,7 @@ def tensorboard_graphs(c2_netdef, tf_dir):
log.info("Wrote %s graphs to logdir %s", len(events), tf_dir)
@cli.command("tensorboard-events")
@cli.command("tensorboard-events") # type: ignore[arg-type, attr-defined]
@click.option("--c2-dir", type=click.Path(exists=True, file_okay=False),
help="Root directory of the Caffe2 run")
@click.option("--tf-dir", type=click.Path(writable=True),
@ -209,4 +209,4 @@ def tensorboard_events(c2_dir, tf_dir):
if __name__ == "__main__":
cli()
cli() # type: ignore[misc]

View File

@ -21,7 +21,7 @@ class TestFuture(TestCase):
error_msg = "Intentional Value Error"
value_error = ValueError(error_msg)
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
# Set exception
f.set_exception(value_error)
# Exception should throw on wait
@ -29,7 +29,7 @@ class TestFuture(TestCase):
f.wait()
# Exception should also throw on value
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
f.set_exception(value_error)
with self.assertRaisesRegex(ValueError, "Intentional"):
f.value()
@ -37,7 +37,7 @@ class TestFuture(TestCase):
def cb(fut):
fut.value()
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
f.set_exception(value_error)
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
@ -54,7 +54,7 @@ class TestFuture(TestCase):
with self.assertRaisesRegex(ValueError, "Intentional"):
f.wait()
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
t = threading.Thread(target=wait_future, args=(f, ))
t.start()
f.set_exception(value_error)
@ -68,7 +68,7 @@ class TestFuture(TestCase):
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
fut.wait()
f = Future[T]()
f = Future[T]() # type: ignore[valid-type]
t = threading.Thread(target=then_future, args=(f, ))
t.start()
f.set_exception(value_error)

View File

@ -29,7 +29,7 @@ from itertools import product, combinations, permutations
from functools import partial
from torch import multiprocessing as mp
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest,
@ -8490,7 +8490,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
except RuntimeError:
pass
with mp.Pool(1) as pool:
out: list = pool.map(method, [arg])
out = pool.map(method, [arg])
self.assertTrue(out[0])
def _test_multinomial_invalid_probs(probs):

View File

@ -1,11 +1,11 @@
#!/usr/bin/env python3
import time
from package.oss.cov_json import get_json_report
from package.oss.init import initialization
from package.tool.summarize_jsons import summarize_jsons
from package.util.setting import TestPlatform
from package.util.utils import print_time
from package.oss.cov_json import get_json_report # type: ignore[import]
from package.oss.init import initialization # type: ignore[import]
from package.tool.summarize_jsons import summarize_jsons # type: ignore[import]
from package.util.setting import TestPlatform # type: ignore[import]
from package.util.utils import print_time # type: ignore[import]
def report_coverage() -> None:

View File

@ -45,7 +45,7 @@ def transform_file_name(
return file_path[file_path.find(folder) :]
# remove pytorch base folder path
if platform == TestPlatform.OSS:
from package.oss.utils import get_pytorch_folder
from package.oss.utils import get_pytorch_folder # type: ignore[import]
pytorch_foler = get_pytorch_folder()
assert file_path.startswith(pytorch_foler)

View File

@ -89,7 +89,9 @@ def get_raw_profiles_folder() -> str:
def detect_compiler_type(platform: TestPlatform) -> CompilerType:
if platform == TestPlatform.OSS:
from package.oss.utils import detect_compiler_type # type: ignore[misc]
from package.oss.utils import ( # type: ignore[assignment, import, misc]
detect_compiler_type,
)
cov_type = detect_compiler_type() # type: ignore[call-arg]
else:
@ -100,7 +102,7 @@ def detect_compiler_type(platform: TestPlatform) -> CompilerType:
cov_type = detect_compiler_type()
check_compiler_type(cov_type)
return cov_type
return cov_type # type: ignore[no-any-return]
def get_test_name_from_whole_path(path: str) -> str:

View File

@ -20,7 +20,7 @@ from yaml.nodes import MappingNode
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader # type: ignore[misc]
from yaml import Loader # type: ignore[assignment, misc]
H_NAME = "spv.h"
CPP_NAME = "spv.cpp"

View File

@ -16,7 +16,7 @@ from yaml import dump, load
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
class LintSeverity(str, Enum):

View File

@ -11,7 +11,7 @@ from torchgen.selective_build.selector import SelectiveBuilder
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {

View File

@ -10,7 +10,7 @@ try:
# use faster C loader if available
from yaml import CSafeLoader as YamlLoader
except ImportError:
from yaml import SafeLoader as YamlLoader # type: ignore[misc]
from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc]
NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
TAGS_PATH = "aten/src/ATen/native/tags.yaml"

View File

@ -429,6 +429,7 @@ def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
def _get_upgraders_map_size() -> _int: ...
def _get_upgraders_entry_map() -> Dict[str, str]: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...

View File

@ -417,7 +417,7 @@ def sym_int(a):
if isinstance(a, SymInt):
return a
elif isinstance(a, SymFloat):
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type]
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type, call-overload]
return py_int(a) # type: ignore[operator]
def sym_max(a, b):
@ -1320,7 +1320,7 @@ if TYPE_CHECKING:
# Some type signatures pulled in from _VariableFunctions here clash with
# signatures already imported. For now these clashes are ignored; see
# PR #43339 for details.
from torch._C._VariableFunctions import * # type: ignore[misc] # noqa: F403
from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
# Fixup segment_reduce visibility
_segment_reduce = segment_reduce
del segment_reduce

View File

@ -10,20 +10,20 @@ class _Union:
@classmethod
def create(cls, **kwargs):
assert len(kwargs) == 1
return cls(**{**{f.name: None for f in fields(cls)}, **kwargs})
return cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type]
def __post_init__(self):
assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1
assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1 # type: ignore[arg-type, misc]
@property
def value(self):
val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None)
val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
assert val is not None
return val
@property
def type(self):
val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None)
val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
assert val_type is not None
return val_type

View File

@ -9,7 +9,7 @@ import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
import sympy
@ -749,7 +749,7 @@ class GraphModuleDeserializer:
self.module = torch.nn.Module()
@contextmanager
def save_graph_module(self) -> None:
def save_graph_module(self) -> Iterator[None]:
saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta
self.graph = torch.fx.Graph()
self.module = torch.nn.Module()
@ -773,7 +773,7 @@ class GraphModuleDeserializer:
if vr := self.symbol_name_to_range.get(val.expr_str):
symbolic_shapes._constrain_symbol_range(
self.shape_env, sym, vr.lower, vr.upper
self.shape_env, sym, vr.lower, vr.upper # type: ignore[arg-type]
)
return self.shape_env.create_symintnode(sym, hint=val.hint)
@ -855,6 +855,7 @@ class GraphModuleDeserializer:
output_node.meta["val"] = tuple(
arg.meta["val"] for arg in output_node.args[0]
)
return output_node
def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
if target.__module__ == "_operator": # TODO(zhxchen17) Follow up on this.
@ -1050,7 +1051,7 @@ class GraphModuleDeserializer:
self.serialized_name_to_node[fx_node.name] = fx_node
def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
ret = {}
ret: Dict[str, Any] = {}
if stack_trace := metadata.get("stack_trace"):
ret["stack_trace"] = stack_trace

View File

@ -33,7 +33,7 @@ def get_upgraders() -> Dict[str, Tuple[str, str]]:
"""Getting upgraders entry map and operator version map and merge them into one dict."""
upgraders = torch._C._get_upgraders_entry_map()
op_version_map = torch._C._get_operator_version_map()
output = defaultdict(tuple)
output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type]
for opname, entry_list in op_version_map.items():
if not entry_list:
raise RuntimeError(f"Op version map has an empty entry for opname {opname}")

View File

@ -1,5 +1,5 @@
from collections import Counter
from typing import Any, Dict, List, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
@ -12,7 +12,7 @@ def functional_call(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,

View File

@ -1004,7 +1004,7 @@ def _linalg_svd_meta(
A: Tensor,
full_matrices: bool = False,
compute_uv: bool = True,
driver: str = None,
driver: Optional[str] = None,
):
checkIsMatrix(A, "linalg.svd")
checkFloatingOrComplex(A, "linalg.svd")
@ -1147,7 +1147,7 @@ def linalg_solve_triangular_meta(
upper: bool,
left: bool = True,
unitriangular: bool = False,
out: Tensor = None,
out: Optional[Tensor] = None,
) -> Tensor:
if out is None:
out = A.new_empty([0])
@ -4695,8 +4695,8 @@ def upsample_nearest2d_backward(
grad_output: Tensor,
output_size: Sequence[Union[int, torch.types.SymInt]],
input_size: Sequence[Union[int, torch.types.SymInt]],
scales_h: float = None,
scales_w: float = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
):
full_output_size = upsample_common_check(
input_size, output_size, num_spatial_dims=2

View File

@ -331,7 +331,7 @@ class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
def _elementwise_meta(
*args,
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
args_with_fixed_dtypes: Tuple[TensorLikeType, ...] = None,
args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
) -> FakeTensor:
"""
Meta function for elementwise operations that produce outputs in the same dtype

View File

@ -1,6 +1,6 @@
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Sequence
from typing import Any, Callable, Dict, Optional, Sequence
from warnings import warn
import torch
@ -111,7 +111,7 @@ class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Dict = None,
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
@ -161,7 +161,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Dict = None,
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
@ -374,7 +374,7 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Dict = None,
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}

View File

@ -27,7 +27,7 @@ def load_tensor_reader(loc):
def register_debug_prims():
@custom_op("debugprims::load_tensor")
def load_tensor(
def load_tensor( # type: ignore[empty-body]
name: str,
size: Sequence[int],
stride: Sequence[int],

View File

@ -10,7 +10,7 @@ from torch._prims_common import (
import torch._prims_common as utils
from torch.utils._pytree import tree_flatten, tree_unflatten
from typing import Callable, Sequence, Tuple, NamedTuple, overload
from typing import Callable, Sequence, Tuple, NamedTuple, Optional, overload
import inspect
from functools import wraps
import warnings
@ -97,7 +97,7 @@ class elementwise_type_promotion_wrapper:
self,
*,
type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
type_promoting_args: Sequence[str] = None,
type_promoting_args: Optional[Sequence[str]] = None,
):
self.type_promoting_arg_names = type_promoting_args
self.type_promotion_kind = type_promotion_kind

View File

@ -1832,7 +1832,7 @@ def clamp(
@out_wrapper()
def clamp_min(
self: TensorLikeType,
min: TensorOrNumberLikeType = None,
min: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
return torch.clamp(self, min=min) # type: ignore[arg-type]
@ -1841,7 +1841,7 @@ def clamp_min(
@out_wrapper()
def clamp_max(
self: TensorLikeType,
max: TensorOrNumberLikeType = None,
max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
return torch.clamp(self, max=max) # type: ignore[arg-type]
@ -4654,7 +4654,7 @@ def logspace(
ret = torch.linspace(
start,
end,
steps,
steps, # type: ignore[arg-type]
dtype=torch.float64,
layout=layout,
device=device,

View File

@ -63,7 +63,7 @@ class MultiheadAttention(nn.MultiheadAttention):
def __init__(self, embed_dim: int, num_heads: int,
dropout: float = 0., bias: bool = True,
add_bias_kv: bool = False, add_zero_attn: bool = False,
kdim: int = None, vdim: int = None, batch_first: bool = False,
kdim: Optional[int] = None, vdim: Optional[int] = None, batch_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(embed_dim, num_heads, dropout,

View File

@ -856,7 +856,7 @@ def prepare_n_shadows_model(
create_n_transformed_and_logged_copies_of_subgraph(
mt, subgraph_idx, match_name, nodes_in_this_subgraph,
qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig,
custom_prepare_fn, custom_prepare_kwargs
custom_prepare_fn, custom_prepare_kwargs # type: ignore[arg-type]
)
return mt

View File

@ -443,7 +443,7 @@ def create_one_transformed_and_logged_copy_of_subgraph(
example_inputs: Any,
last_added_shadow_node_list: List[Optional[Node]],
custom_prepare_fn: Optional[Callable] = None,
custom_prepare_kwargs: Dict[str, Any] = None,
custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Given a subgraph in `mt` and a subgraph candidate idx, inserts the
@ -575,7 +575,7 @@ def create_n_transformed_and_logged_copies_of_subgraph(
qconfig_mappings: List[QConfigMapping],
list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
custom_prepare_fn: Optional[Callable] = None,
custom_prepare_kwargs: Dict[str, Any] = None,
custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Given a model `mt` and a subgraph_idx, creates the needed copies
@ -756,7 +756,7 @@ def create_add_loggers_graph(
create_n_transformed_and_logged_copies_of_subgraph(
model, cur_subgraph_idx, match_name, maybe_subgraph,
[qconfig_mapping], [node_name_to_qconfig],
None, None
None, None # type: ignore[arg-type]
)
# find the created shadow module and record it so we
# can find it easily in step 2

View File

@ -78,7 +78,7 @@ def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
res.append(param_value)
return res
else:
assert isinstance(mod, nnqd.LSTM), f"type {type(res)} not handled yet"
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
res = []
for weight_value in mod._all_weight_values:
res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])

View File

@ -1,4 +1,4 @@
from typing import Dict, Any, List
from typing import Any, Dict, List, Optional
import torch
from collections import defaultdict
from torch import nn
@ -205,7 +205,7 @@ class ActivationSparsifier:
# or sparsify_hook()
self.data_groups[name]['hook_state'] = "aggregate" # aggregate hook is attached
def get_mask(self, name: str = None, layer: nn.Module = None):
def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None):
"""
Returns mask associated to the layer.

View File

@ -3,7 +3,7 @@ import torch
from dlrm_s_pytorch import unpack_batch # type: ignore[import]
import numpy as np # type: ignore[import]
import sklearn # type: ignore[import]
from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model
from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model # type: ignore[import]
import pandas as pd # type: ignore[import]
import argparse

View File

@ -1,7 +1,7 @@
import torch
from torch.nn import functional as F
from functools import reduce
from typing import Tuple, Any, List
from typing import Any, List, Optional, Tuple
from .base_data_sparsifier import BaseDataSparsifier
@ -31,9 +31,9 @@ class DataNormSparsifier(BaseDataSparsifier):
arguments and could be overriden by the configuration provided in the
`add_data` step.
"""
def __init__(self, data_list: List[Tuple[str, Any]] = None, sparsity_level: float = 0.5,
def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, sparsity_level: float = 0.5,
sparse_block_shape: Tuple[int, int] = (1, 4),
zeros_per_block: int = None, norm: str = 'L1'):
zeros_per_block: Optional[int] = None, norm: str = 'L1'):
if zeros_per_block is None:
zeros_per_block = reduce((lambda x, y: x * y), sparse_block_shape)

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module
from typing import Dict, List
from typing import Dict, List, Optional
SUPPORTED_MODULES = {
nn.Embedding,
@ -28,7 +28,7 @@ def _fetch_all_embeddings(model):
def post_training_sparse_quantize(model,
data_sparsifier_class,
sparsify_first=True,
select_embeddings: List[nn.Module] = None,
select_embeddings: Optional[List[nn.Module]] = None,
**sparse_config):
"""Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags.
The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.

View File

@ -44,7 +44,7 @@ class APoTQuantizer():
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
result = TensorAPoT(self, tensor2quantize)
result = TensorAPoT(self, tensor2quantize) # type: ignore[assignment]
return result
@ -83,7 +83,7 @@ class APoTQuantizer():
def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
levels_lst = list(self.quantization_levels)
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst)) # type: ignore[call-arg]
return result

View File

@ -85,9 +85,9 @@ def _find_matches(
modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler],
root_node_getter_mapping: Dict[Pattern, Callable],
standalone_module_names: List[str] = None,
standalone_module_classes: List[Type] = None,
custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
standalone_module_names: Optional[List[str]] = None,
standalone_module_classes: Optional[List[Type]] = None,
custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
"""
Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps.

View File

@ -18,7 +18,7 @@ from torch.ao.quantization.utils import (
)
from abc import ABC
from typing import Callable, Dict, List, Type
from typing import Callable, Dict, List, Type, Optional
__all__ = [
"QuantizeHandler",
@ -52,7 +52,7 @@ class QuantizeHandler(ABC):
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
root_node_getter: Callable = None,
root_node_getter: Optional[Callable] = None,
is_custom_module=False,
is_standalone_module=False):
""" Records pattern information in __init__, which will be used
@ -113,7 +113,7 @@ def _get_quantize_handler_cls(
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
root_node_getter: Callable = None):
root_node_getter: Optional[Callable] = None):
super().__init__(node_pattern, modules, root_node_getter)
if num_tensor_args_to_observation_type:
assert self.num_tensor_args in num_tensor_args_to_observation_type, \

View File

@ -458,8 +458,10 @@ def _update_special_qspecs_after_replacement(
if isinstance(edge_or_node, Node):
_node = edge_or_node
return original_to_replacement_node.get(_node, _node)
elif isinstance(edge_or_node, Tuple[Node, Node]):
src, dest = edge_or_node
# TODO: It's really should be
# isinstance(edge_or_node, tuple) and len(edge_or_node) == 2 and all(isinstance(x, Node) for x in edge_or_node)
elif isinstance(edge_or_node, Tuple[Node, Node]): # type: ignore[arg-type]
src, dest = edge_or_node # type: ignore[misc]
return (
original_to_replacement_node.get(src, src),
original_to_replacement_node.get(dest, dest),

View File

@ -58,13 +58,13 @@ def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
for conv_op, add_op, relu_op in conv_add_relu_options:
if add_op is None:
# Append Conv ReLU
supported_operators["conv2d"].append([conv_op, relu_op])
supported_operators["conv2d"].append([conv_op, relu_op]) # type: ignore[list-item]
elif relu_op is None:
# Append Conv Add
supported_operators["conv2d"].append([conv_op, add_op])
supported_operators["conv2d"].append([conv_op, add_op]) # type: ignore[list-item]
else:
# Append Conv Add ReLU
supported_operators["conv2d"].append([conv_op, add_op, relu_op])
supported_operators["conv2d"].append([conv_op, add_op, relu_op]) # type: ignore[list-item]
return copy.deepcopy(supported_operators)
@ -222,17 +222,17 @@ class X86InductorQuantizer(Quantizer):
"""
conv_gemm_node_idx = None
extra_input_node_idx = None
if (binary_node.args[0].op == "call_function") and (
if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr]
binary_node.args[0] == conv_gemm_node
):
conv_gemm_node_idx = 0
extra_input_node_idx = 1
elif (binary_node.args[1].op == "call_function") and (
elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr]
binary_node.args[1] == conv_gemm_node
):
conv_gemm_node_idx = 1
extra_input_node_idx = 0
extra_input_node = binary_node.args[extra_input_node_idx]
extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index]
assert isinstance(extra_input_node, Node)
return conv_gemm_node_idx, extra_input_node_idx

View File

@ -119,9 +119,9 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
remove_tensor_overload_for_qdq_ops(model)
for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
pattern = get_aten_graph_module(pattern, example_inputs)
remove_tensor_overload_for_qdq_ops(pattern)
replacement = get_aten_graph_module(replacement, example_inputs)
remove_tensor_overload_for_qdq_ops(replacement)
pattern = get_aten_graph_module(pattern, example_inputs) # type: ignore[arg-type, assignment]
remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type]
replacement = get_aten_graph_module(replacement, example_inputs) # type: ignore[arg-type, assignment]
remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type]
matches = replace_pattern(model, pattern, replacement)
return model

View File

@ -455,7 +455,7 @@ def _profile_to_snapshot(profile):
device = to_device(tensor_key.device)
addr = tensor_key.storage.ptr
seg = snapshot['segments'][device]
seg = snapshot['segments'][device] # type: ignore[index]
if seg['address'] is None or seg['address'] > addr:
seg['address'] = addr
seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later
@ -465,12 +465,12 @@ def _profile_to_snapshot(profile):
stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
if during_trace:
snapshot['device_traces'][device].append(r)
snapshot['device_traces'][device].append(r) # type: ignore[index]
return r
def free(alloc, device):
for e in ('free_requested', 'free_completed'):
snapshot['device_traces'][device].append({'action': e,
snapshot['device_traces'][device].append({'action': e, # type: ignore[index]
'addr': alloc['addr'],
'size': alloc['size'],
'stream': 0,
@ -499,7 +499,7 @@ def _profile_to_snapshot(profile):
blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
for (tensor_key, version), event in kv_to_elem.items()]
for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]):
seg = snapshot['segments'][device]
seg = snapshot['segments'][device] # type: ignore[index]
last_addr = seg['address']
for _, addr, size, frames in blocks:
if last_addr < addr:
@ -510,8 +510,8 @@ def _profile_to_snapshot(profile):
if last_addr < seg['total_size']:
seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})
snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]
for seg in snapshot['segments']:
snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined]
for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef]
seg['total_size'] -= seg['address']
if not seg['blocks']:
seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})

View File

@ -1,5 +1,5 @@
import torch
from typing import cast, Iterable, List, Union
from typing import Iterable, List, Union
from . import _lazy_init, _lazy_call, device_count, current_device
from .. import Tensor
@ -56,7 +56,7 @@ def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cu
device = torch.device('cuda', device)
def cb():
idx = cast(torch.device, device).index
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.cuda.default_generators[idx]

View File

@ -732,7 +732,7 @@ class ShardedTensor(ShardedTensorBase):
local_tensor: torch.Tensor,
sharding_spec: shard_spec.ShardingSpec,
*global_size: Sequence[int],
process_group: dist.ProcessGroup = None,
process_group: Optional[dist.ProcessGroup] = None,
init_rrefs=False,
) -> "ShardedTensor":
"""

View File

@ -389,7 +389,7 @@ def _compile(
# can trace operations applied to them.
def stateless_func(func, params, buffers, named_states, args, kwargs):
with stateless._reparametrize_module(
cast(nn.Module, mod), {**params, **buffers}
mod, {**params, **buffers}
), _rematerialize_optimizer(
opt, named_states, params
) if opt else nullcontext():

View File

@ -210,7 +210,7 @@ def full(
def zeros(
*size,
requires_grad: bool = False,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device_mesh: Optional[DeviceMesh] = None,
placements: Optional[Sequence[Placement]] = None,

View File

@ -133,7 +133,7 @@ def gen_model_param_in_submesh(model: nn.Module, sub_mesh: DeviceMesh) -> nn.Mod
)
def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module:
def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module: # type: ignore[empty-body]
"""
checkpoint save/load models with DTensor parameters
"""

View File

@ -40,7 +40,7 @@ def register_op_strategy(op):
def as_list(
x: Union[List[object], object]
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]:
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
# which is an object but treated as a list by the tracer. Therefore, keep
# `immutable_list` intact here as well.

View File

@ -1,6 +1,7 @@
import functools
import torch
import torch.distributed as dist
from typing import Optional
class DefaultState:
@ -127,7 +128,7 @@ def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch
allreduce_hook(state, grad)
_decompress(state, grad)
def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None):
def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
r"""
This FSDP communication hook implements a simple gradient compression
approach that casts ``grad`` to half-precision floating-point format (``torch.float16``).
@ -144,7 +145,7 @@ def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: tor
fp16_hook = functools.partial(_low_precision_hook, torch.float16)
return fp16_hook(state, grad, output)
def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None):
def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
r"""
This FSDP communication hook implements a simple gradient compression
approach that casts ``grad`` to half-precision floating-point format (``torch.float16``).

View File

@ -20,7 +20,7 @@ def load_state_dict(
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
planner: LoadPlanner = None,
planner: Optional[LoadPlanner] = None,
) -> None:
"""
Loads a distributed ``state_dict`` in SPMD style.

View File

@ -22,7 +22,7 @@ def save_state_dict(
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
planner: SavePlanner = None,
planner: Optional[SavePlanner] = None,
) -> Metadata:
"""
Saves a distributed model in SPMD style.

View File

@ -58,7 +58,7 @@ def broadcast(
raise AssertionError("Data or Function is expected to be None if not successful")
payload: Optional[T] = None
exception : Exception = None
exception : Optional[Exception] = None
# if no pg is passed then execute if rank is 0
if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
# determine if it is an executable function or data payload only
@ -119,7 +119,7 @@ def all_gather(
>> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
"""
payload: Optional[T] = None
exception : Exception = None
exception : Optional[Exception] = None
success = True
# determine if it is an executable function or data payload only
if callable(data_or_fn):
@ -143,7 +143,7 @@ def all_gather(
total_list = [None] * dist.get_world_size(pg)
all_gather_object_enforce_type(pg, total_list, sync_obj)
# Each rank will throw RuntimeError in case of failure on any rank.
stage_name: Optional[str] = cast(SyncPayload[T], total_list[0]).stage_name
stage_name = cast(SyncPayload[T], total_list[0]).stage_name
exception_list: List[Tuple[int, Exception]] = []
ret_list: List[T] = []
error_msg: str = ""
@ -160,7 +160,7 @@ def all_gather(
ret_list.append(sp.payload)
if len(exception_list) > 0:
raise RuntimeError(
raise RuntimeError( # type: ignore[misc]
error_msg, exception_list) from exception_list[0]
return ret_list
else:
@ -168,7 +168,7 @@ def all_gather(
raise RuntimeError(
f"all_gather failed with exception {sync_obj.exception}",
) from sync_obj.exception
return [sync_obj.payload]
return [sync_obj.payload] # type: ignore[list-item]
# Note: use Any for typing for now so users can pass in

View File

@ -68,7 +68,7 @@ _default_metrics_handler: MetricHandler = NullMetricHandler()
# pyre-fixme[9]: group has type `str`; used as `None`.
def configure(handler: MetricHandler, group: str = None):
def configure(handler: MetricHandler, group: Optional[str] = None):
if group is None:
global _default_metrics_handler
# pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used

View File

@ -158,7 +158,7 @@ class FileTimerServer:
file_path: str,
max_interval: float = 10,
daemon: bool = True,
log_event: Callable[[str, Optional[FileTimerRequest]], None] = None
log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None
) -> None:
self._file_path = file_path
self._max_interval = max_interval

View File

@ -476,7 +476,7 @@ def _flatten_optim_state_dict(
len(fqns) == 1
), f"use_orig_params is True but there are multiple FQNs, {fqns}."
if optim is not None: # NamedOptimizer or KeyedOptimizer case.
state = optim.state.get(param, None)
state = optim.state.get(param, None) # type: ignore[call-overload]
if state is not None:
flat_osd_state[key] = copy.deepcopy(state)
else:

View File

@ -1433,7 +1433,7 @@ def _register_post_backward_reshard_only_hooks(
hook_handle = register_multi_grad_hook(
inp_tensors, functools.partial(_post_backward_reshard, state, handle)
)
handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined]
handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment]
@no_type_check
@ -1468,11 +1468,11 @@ def _wait_for_computation_stream(
For example, this should be called in the FSDP root's pre-forward to
respect optimizer step computation.
"""
unshard_stream.wait_stream(computation_stream)
unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
# Having the pre-all-gather stream wait for the current stream even if we
# do not leverage the pre-all-gather stream is tolerable since this only
# runs once per iteration
pre_unshard_stream.wait_stream(computation_stream)
pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
def _reset_flat_param_grad_info_if_needed(

View File

@ -167,7 +167,7 @@ class _ExecOrderTracer:
kwargs: Dict[str, Any],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None,
proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
) -> torch.fx.Proxy:
"""
Overrides ``create_proxy`` to save execution information to

View File

@ -96,7 +96,7 @@ def _auto_wrap(
)
recursive_wrap_kwargs["auto_wrap_policy"] = policy
_warn_on_overridden_mixed_precision(overridden_module_classes)
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs)
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs) # type: ignore[arg-type]
def _check_nested_wrapping(root_module: nn.Module):

View File

@ -129,8 +129,8 @@ class _RemoteModule(nn.Module):
self,
remote_device: str,
module_cls: Type[nn.Module],
args: Tuple = None,
kwargs: Dict[str, Any] = None,
args: Optional[Tuple] = None,
kwargs: Optional[Dict[str, Any]] = None,
_module_interface_cls: Any = None,
):
"""
@ -358,7 +358,7 @@ class _RemoteModule(nn.Module):
def bfloat16(self: T) -> T: # type: ignore[return]
_raise_not_supported(self.bfloat16.__name__)
def to(self, *args, **kwargs) -> T: # type: ignore[return]
def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var]
_raise_not_supported(self.to.__name__)
def register_backward_hook( # type: ignore[return]
@ -377,7 +377,7 @@ class _RemoteModule(nn.Module):
) -> RemovableHandle:
_raise_not_supported(self.register_forward_pre_hook.__name__)
def register_forward_hook( # type: ignore[return]
def register_forward_hook( # type: ignore[return, override]
self,
hook: Union[
Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
@ -685,8 +685,8 @@ class RemoteModule(_RemoteModule):
self,
remote_device: str,
module_cls: Type[nn.Module],
args: Tuple = None,
kwargs: Dict[str, Any] = None,
args: Optional[Tuple] = None,
kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(remote_device, module_cls, args, kwargs)

View File

@ -2,7 +2,7 @@ import logging
import warnings
from copy import deepcopy
from typing import Any, Collection, Dict, List, Mapping, Union
from typing import Any, Collection, Dict, List, Mapping, Optional, Union
import torch
import torch.nn as nn
@ -63,8 +63,8 @@ class _NamedOptimizer(optim.Optimizer):
self,
named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
optimizer_class: optim.Optimizer,
param_groups: Collection[Mapping[str, Any]] = None,
module: nn.Module = None,
param_groups: Optional[Collection[Mapping[str, Any]]] = None,
module: Optional[nn.Module] = None,
*args,
**kwargs,
) -> None:
@ -154,7 +154,7 @@ class _NamedOptimizer(optim.Optimizer):
self._optimizer.step(closure=closure)
@property
def state(self) -> Mapping[torch.Tensor, Any]:
def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override]
return self._optimizer.state
def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:

View File

@ -1,12 +1,12 @@
from torch import nn
from typing import List
from typing import List, Optional
__all__ = ["partition_model"]
def partition_model(
module: nn.Sequential,
balance: List[int],
devices: List[int] = None):
devices: Optional[List[int]] = None):
"""
Given an :class:`nn.Sequential <torch.nn.Sequential>` module, partitions
the model across multiple GPU devices according the provided ``balance``

View File

@ -49,7 +49,7 @@ def input_reshard(
def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any:
nonlocal cx
cx.__exit__() # type: ignore[name-defined]
cx.__exit__() # type: ignore[name-defined, union-attr]
if input_reshard_dim is None:
return module

View File

@ -1,7 +1,7 @@
import math
import warnings
from numbers import Number
from typing import Union
from typing import Optional, Union
import torch
from torch import nan
@ -72,9 +72,9 @@ class Wishart(ExponentialFamily):
def __init__(self,
df: Union[torch.Tensor, Number],
covariance_matrix: torch.Tensor = None,
precision_matrix: torch.Tensor = None,
scale_tril: torch.Tensor = None,
covariance_matrix: Optional[torch.Tensor] = None,
precision_matrix: Optional[torch.Tensor] = None,
scale_tril: Optional[torch.Tensor] = None,
validate_args=None):
assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
"Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

View File

@ -23,7 +23,7 @@ class FoldedGraphModule(torch.fx.GraphModule):
root: torch.nn.Module,
graph: torch.fx.Graph,
const_subgraph: Optional[torch.fx.Graph] = None,
fx_const_folded_attrs_name: str = None,
fx_const_folded_attrs_name: Optional[str] = None,
device_for_folded_attrs: str = "cuda",
):
super().__init__(root, graph)

View File

@ -259,7 +259,7 @@ class MetaTracer(torch.fx.Tracer):
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
meta_args : Dict[str, torch.Tensor] = None,
meta_args : Optional[Dict[str, torch.Tensor]] = None,
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
tracer = MetaTracer()
graph = tracer.trace(root, meta_args, concrete_args)

View File

@ -1,4 +1,4 @@
# type: ignore[attr-defined]
# mypy: disable-error-code=attr-defined
from .core import unify, reify # noqa: F403
from .more import unifiable # noqa: F403
from .variable import var, isvar, vars, variables, Var # noqa: F403

View File

@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
class Partition:
def __init__(self, id: int = None, nodes: Iterable[Node] = None):
def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
self.id = id
self.nodes: Set[Node] = set(nodes) if nodes is not None else set()

View File

@ -1,6 +1,6 @@
from functools import wraps
from inspect import unwrap
from typing import Callable, List
from typing import Callable, List, Optional
import logging
logger = logging.getLogger(__name__)
@ -76,7 +76,7 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable:
def loop_pass(base_pass: Callable, n_iter: int = None, predicate: Callable = None):
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
"""
Convenience wrapper for passes which need to be applied multiple times.

View File

@ -202,7 +202,7 @@ def replace_pattern_with_filters(
gm: GraphModule,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
"""
@ -222,7 +222,7 @@ def _replace_pattern(
gm: GraphModule,
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
ignore_literals: bool = False,
) -> List[ReplacedPatterns]:

View File

@ -1302,7 +1302,7 @@ def amin(
@_apply_docstring_templates
def argmax(
input: Union[Tensor, MaskedTensor],
dim: int = None,
dim: Optional[int] = None,
*,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
@ -1328,7 +1328,7 @@ def argmax(
@_apply_docstring_templates
def argmin(
input: Union[Tensor, MaskedTensor],
dim: int = None,
dim: Optional[int] = None,
*,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,

View File

@ -103,7 +103,7 @@ class Sequential(Module):
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx) -> T:
def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
"""Get the idx-th item of the iterator"""
size = len(self)
idx = operator.index(idx)

View File

@ -51,7 +51,7 @@ class _ConvNd(Module):
'out_channels', 'kernel_size']
__annotations__ = {'bias': Optional[torch.Tensor]}
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body]
...
in_channels: int

View File

@ -1,7 +1,7 @@
import contextlib
import warnings
from collections import defaultdict
from typing import Any, Dict, Iterator, Set, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union
import torch
from torch import Tensor
@ -148,7 +148,7 @@ def functional_call(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
@ -233,7 +233,7 @@ def _functional_call(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,

View File

@ -340,7 +340,7 @@ class ProtobufExportOutputSerializer:
) -> None:
import onnx
if not isinstance(export_output.model_proto, onnx.ModelProto):
if not isinstance(export_output.model_proto, onnx.ModelProto): # type: ignore[attr-defined]
raise ValueError("export_output.ModelProto is not an onnx.ModelProto")
destination.write(export_output.model_proto.SerializeToString())
@ -348,7 +348,7 @@ class ProtobufExportOutputSerializer:
class ExportOutput:
"""An in-memory representation of a PyTorch model that has been exported to ONNX."""
_model_proto: Final[onnx.ModelProto]
_model_proto: Final[onnx.ModelProto] # type: ignore[name-defined]
_input_adapter: Final[io_adapter.InputAdapter]
_output_adapter: Final[io_adapter.OutputAdapter]
_diagnostic_context: Final[infra.DiagnosticContext]
@ -356,7 +356,7 @@ class ExportOutput:
@_beartype.beartype
def __init__(
self,
model_proto: onnx.ModelProto,
model_proto: onnx.ModelProto, # type: ignore[name-defined]
input_adapter: io_adapter.InputAdapter,
output_adapter: io_adapter.OutputAdapter,
diagnostic_context: infra.DiagnosticContext,
@ -367,7 +367,7 @@ class ExportOutput:
self._diagnostic_context = diagnostic_context
@property
def model_proto(self) -> onnx.ModelProto:
def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined]
"""The exported ONNX model as an ``onnx.ModelProto``."""
return self._model_proto

View File

@ -193,7 +193,7 @@ class DynamoExport(exporter.FXGraphExtractor):
wrapped_model,
*model_args,
tracing_mode=fx_mode,
fake_mode=fake_mode,
fake_mode=fake_mode, # type: ignore[arg-type]
**model_kwargs,
)
del graph_guard # Unused

View File

@ -15,7 +15,7 @@ if TYPE_CHECKING:
@_beartype.beartype
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor, name: str, location: str, basepath: str
) -> onnx.TensorProto:
) -> onnx.TensorProto: # type: ignore[name-defined]
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
@ -38,13 +38,13 @@ def _create_tensor_proto_with_external_data(
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
tensor_proto = onnx.TensorProto()
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
tensor_proto.name = name
tensor_proto.data_type = jit_type_utils.JitScalarType.from_dtype(
tensor.dtype
).onnx_type()
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]
# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
@ -86,7 +86,7 @@ def save_model_with_external_data(
model_location: str,
initializer_location: str,
torch_load_paths: Tuple[Union[str, io.BytesIO], ...],
onnx_model: onnx.ModelProto,
onnx_model: onnx.ModelProto, # type: ignore[name-defined]
rename_initializer: bool = False,
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
@ -121,7 +121,7 @@ def save_model_with_external_data(
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
onnx_model_with_initializers = onnx.ModelProto()
onnx_model_with_initializers = onnx.ModelProto() # type: ignore[attr-defined]
onnx_model_with_initializers.CopyFrom(onnx_model)
onnx_input_names = [input.name for input in onnx_model.graph.input]
@ -159,4 +159,4 @@ def save_model_with_external_data(
onnx_model_with_initializers.graph.initializer.append(tensor_proto)
# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) # type: ignore[attr-defined]

View File

@ -62,7 +62,7 @@ def export_as_test_case(
shutil.rmtree(data_set_dir)
os.makedirs(data_set_dir)
proto = onnx.load_model_from_string(model_bytes)
proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]
for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)):
export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))
@ -112,12 +112,12 @@ def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
inputs = {}
input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
for input_file in input_files:
tensor = onnx.load_tensor(input_file)
tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined]
inputs[tensor.name] = numpy_helper.to_array(tensor)
outputs = {}
output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
for output_file in output_files:
tensor = onnx.load_tensor(output_file)
tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined]
outputs[tensor.name] = numpy_helper.to_array(tensor)
return model_bytes, inputs, outputs
@ -227,7 +227,7 @@ def _add_onnxscript_fn(
# size > 2GB, and if it for some reason did not, the model would fail on
# serialization anyway in terms of the protobuf limitation. So we don't
# need to worry about > 2GB model getting here.
model_proto = onnx.load_model_from_string(model_bytes)
model_proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]
# Iterate graph nodes to insert only the included custom
# function_proto into model_proto

View File

@ -206,14 +206,14 @@ def _ort_session(
def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
try:
import onnx
from onnx import reference as onnx_reference
from onnx import reference as onnx_reference # type: ignore[attr-defined]
except ImportError:
raise ImportError("onnx >= 1.13 is required for reference evaluator.")
proto = (
onnx.load(model)
onnx.load(model) # type: ignore[attr-defined]
if isinstance(model, str)
else onnx.load_model_from_string(model.getvalue())
else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined]
)
onnx_session = onnx_reference.ReferenceEvaluator(proto)
return onnx_session

View File

@ -6,7 +6,7 @@ import warnings
import functools
import math
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Tuple, Optional
from torch import Tensor
import torch.utils.hooks as hooks
@ -208,7 +208,7 @@ class Optimizer:
"an iterable of Tensors or dicts, but got " +
torch.typename(params))
self.state = defaultdict(dict)
self.state: Dict[int, Any] = defaultdict(dict)
self.param_groups = []
param_groups = list(params)
@ -340,8 +340,8 @@ class Optimizer:
self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
hooked = getattr(self.__class__.step, "hooked", None)
if not hooked:
self.__class__.step = self.profile_hook_step(self.__class__.step)
self.__class__.step.hooked = True
self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[method-assign]
self.__class__.step.hooked = True # type: ignore[attr-defined]
def register_step_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
r"""Register an optimizer step pre hook which will be called before
@ -418,14 +418,15 @@ class Optimizer:
}
@staticmethod
def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param_id: int = None,
param_groups: List[Dict[Any, Any]] = None, key=None) -> Tensor:
def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param_id: Optional[int] = None,
param_groups: Optional[List[Dict[Any, Any]]] = None, key=None) -> Tensor:
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
# UNLESS fused or capturable, see note [special device hosting for step]
fused = False
capturable = False
assert param_groups is not None
for pg in param_groups:
if param_id in pg["params"]:
fused = pg["fused"] if "fused" in pg else False
@ -477,14 +478,14 @@ class Optimizer:
elif isinstance(value, dict):
return {k: cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
return type(value)(cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
state: Dict[Any, Dict[Any, Any]] = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]
@ -521,7 +522,7 @@ class Optimizer:
if not hasattr(self, "_zero_grad_profile_name"):
self._patch_step_function()
if foreach:
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
per_device_and_dtype_grads: Dict[Any, Dict[Any, List[Any]]] = defaultdict(lambda: defaultdict(list))
with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
for group in self.param_groups:
for p in group['params']:
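
Many of the optimizer fixes above pair an explicit `Optional[...]` annotation (required by PEP 484 once the default is `None`, and enforced by default in mypy 1.4) with an `assert ... is not None` so the checker can narrow the type before the value is used. A minimal sketch of the pattern, using an illustrative helper rather than the real `_process_value_according_to_param_policy`:

from typing import Any, Dict, List, Optional

def first_fused_flag(param_id: Optional[int] = None,
                     param_groups: Optional[List[Dict[str, Any]]] = None) -> bool:
    # A default of None requires an Optional[...] annotation; mypy 1.4 no
    # longer accepts the implicit form `param_groups: List[...] = None`.
    assert param_groups is not None  # narrows Optional[List[...]] to List[...]
    for pg in param_groups:  # without the assert, mypy flags iterating a possible None
        if param_id in pg.get("params", []):
            return bool(pg.get("fused", False))
    return False

The assert only makes the existing runtime assumption (callers always pass the groups) visible to the type checker; it does not change behavior for valid inputs.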


@ -603,7 +603,7 @@ def input_dtypes(event: _ProfilerEvent):
def report_all_anti_patterns(prof,
should_benchmark: bool = False,
print_enable: bool = True,
json_report_dir: str = None):
json_report_dir: Optional[str] = None):
report_dict: Dict = {}
anti_patterns = [
ExtraCUDACopyPattern(prof, should_benchmark),


@ -242,7 +242,7 @@ class _KinetoProfile:
assert self.profiler is not None and self.profiler.kineto_results is not None
return MemoryProfile(self.profiler.kineto_results)
def export_memory_timeline(self, path: str, device: str = None) -> None:
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
"""Extract the memory information from the memory profile collected
tree for a given device, and export a timeline plot consisting of
[times, [sizes by category]], where times are timestamps and sizes


@ -861,7 +861,7 @@ def load(
pickle_module: Any = None,
*,
weights_only: bool = False,
mmap: bool = None,
mmap: Optional[bool] = None,
**pickle_load_args: Any
) -> Any:
# Reference: https://github.com/pytorch/pytorch/issues/54354


@ -422,9 +422,9 @@ Examples::
def hamming(M: int,
*,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
return general_hamming(M, sym=sym, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
@ -469,9 +469,9 @@ Examples::
def hann(M: int,
*,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
return general_hamming(M,
alpha=0.5,
@ -521,9 +521,9 @@ Examples::
def blackman(M: int,
*,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()
@ -575,9 +575,9 @@ Examples::
def bartlett(M: int,
*,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()
@ -644,9 +644,9 @@ Examples::
def general_cosine(M, *,
a: Iterable,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()
@ -721,9 +721,9 @@ def general_hamming(M,
*,
alpha: float = 0.54,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
return general_cosine(M,
a=[alpha, 1. - alpha],


@ -199,7 +199,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
self.compressed_tensor = compressed_tensor
self.transposed = transposed
def __repr__(self) -> str:
def __repr__(self) -> str: # type: ignore[override]
"""Return string representation of SparseSemiStructuredTensor
Returns:


@ -3,7 +3,7 @@ import io
import torch
from ._utils import _type, _cuda, _hpu
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast, Dict as _Dict
from typing import cast, Any, Dict as _Dict, Optional as _Optional, TypeVar, Type, Union
import copy
import collections
from functools import lru_cache
@ -27,51 +27,51 @@ class _StorageBase:
device: torch.device
def __init__(self, *args, **kwargs): ... # noqa: E704
def __len__(self) -> int: ... # noqa: E704
def __len__(self) -> int: ... # type: ignore[empty-body] # noqa: E704
def __getitem__(self, idx): ... # noqa: E704
def __setitem__(self, *args, **kwargs): ... # noqa: E704
def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704
def new(self) -> T: ... # noqa: E704
def nbytes(self) -> int: ... # noqa: E704
def copy_(self, source: T, non_blocking: _Optional[bool] = None) -> T: ... # type: ignore[empty-body] # noqa: E704
def new(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def nbytes(self) -> int: ... # type: ignore[empty-body] # noqa: E704
def size(self) -> int:
return self.nbytes()
def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
def element_size(self) -> int: ... # noqa: E704
def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def element_size(self) -> int: ... # type: ignore[empty-body, type-var] # noqa: E704
def get_device(self) -> int:
return self.device.index
def data_ptr(self) -> int: ... # noqa: E704
def data_ptr(self) -> int: ... # type: ignore[empty-body] # noqa: E704
# Defined in torch/csrc/generic/StorageSharing.cpp
def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
@classmethod
def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
def from_buffer(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
def _new_shared_filename_cpu(cls: Type[T], manager, obj, size, *, device=None, dtype=None) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
def _shared_decref(self) -> T: ... # noqa: E704
def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
def _shared_decref(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _write_file(self, *args, **kwargs): ... # noqa: E704
def resize_(self, size: int): ... # noqa: E704
def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704
def _weak_ref(self, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _set_from_file(self, *args, **kwargs): ... # noqa: E704
def _set_cdata(self, *args, **kwargs): ... # noqa: E704
def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
def is_shared(self) -> bool: ... # noqa: E704
def is_shared(self) -> bool: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
def _shared_incref(self, *args, **kwargs): ... # noqa: E704
@classmethod
def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704
@ -80,9 +80,9 @@ class _StorageBase:
@property
def is_hpu(self): ... # noqa: E704
@classmethod
def from_file(cls, filename, shared, nbytes) -> T: ... # noqa: E704
def from_file(cls, filename, shared, nbytes) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
@classmethod
def _expired(cls, *args, **kwargs) -> T: ... # noqa: E704
def _expired(cls, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _byteswap(self, *args, **kwargs): ... # noqa: E704
def __str__(self):
@ -712,12 +712,12 @@ class TypedStorage:
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
return tmp_tensor[idx_wrapped].item()
def copy_(self, source: T, non_blocking: bool = None):
def copy_(self, source: T, non_blocking: _Optional[bool] = None):
_warn_typed_storage_removal()
if isinstance(source, TypedStorage):
self._untyped_storage.copy_(source._untyped_storage, non_blocking)
self._untyped_storage.copy_(source._untyped_storage, non_blocking) # type: ignore[arg-type]
else:
self._untyped_storage.copy_(source, non_blocking)
self._untyped_storage.copy_(source, non_blocking) # type: ignore[arg-type]
return self
def nbytes(self):
@ -728,7 +728,7 @@ class TypedStorage:
def _nbytes(self):
return self._untyped_storage.nbytes()
def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> Union[T, str]:
_warn_typed_storage_removal()
if dtype is None:
legacy_class = self._get_legacy_storage_class()
@ -741,14 +741,14 @@ class TypedStorage:
else:
return self._untyped_storage.type(dtype, non_blocking)
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var]
_warn_typed_storage_removal()
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
return self._new_wrapped_storage(cuda_storage)
def hpu(self, device=None, non_blocking=False, **kwargs) -> T:
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var]
_warn_typed_storage_removal()
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create HPU storage with quantized dtype")
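
Several of the `_StorageBase` classmethods above now annotate `cls: Type[T]` next to the `-> T` return. Presumably this is to satisfy newer mypy's rule that a TypeVar used in a return annotation must also appear somewhere in the arguments (reported as `type-var`); typing `cls` gives the checker that anchor and lets subclasses get their own type back. A small sketch under that assumption, with illustrative names:

from typing import Type, TypeVar

T = TypeVar("T", bound="Buf")

class Buf:
    def __init__(self) -> None:
        self.data = b""

    @classmethod
    def from_buffer(cls: Type[T], data: bytes) -> T:
        # With a bare `cls`, T would occur only in the return type and mypy
        # warns [type-var]; `cls: Type[T]` ties the return to the receiving class.
        obj = cls()
        obj.data = data
        return obj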


@ -43,37 +43,37 @@ class Storage:
dtype: torch.dtype
_torch_load_uninitialized: bool
def __deepcopy__(self, memo) -> 'Storage':
def __deepcopy__(self, memo) -> 'Storage': # type: ignore[empty-body]
...
def _new_shared(self, int) -> 'Storage':
def _new_shared(self, int) -> 'Storage': # type: ignore[empty-body]
...
def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None:
...
def element_size(self) -> int:
def element_size(self) -> int: # type: ignore[empty-body]
...
def is_shared(self) -> bool:
def is_shared(self) -> bool: # type: ignore[empty-body]
...
def share_memory_(self) -> 'Storage':
def share_memory_(self) -> 'Storage': # type: ignore[empty-body]
...
def nbytes(self) -> int:
def nbytes(self) -> int: # type: ignore[empty-body]
...
def cpu(self) -> 'Storage':
def cpu(self) -> 'Storage': # type: ignore[empty-body]
...
def data_ptr(self) -> int:
def data_ptr(self) -> int: # type: ignore[empty-body]
...
def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage':
def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage': # type: ignore[empty-body]
...
def _new_with_file(self, f: Any, element_size: int) -> 'Storage':
def _new_with_file(self, f: Any, element_size: int) -> 'Storage': # type: ignore[empty-body]
...
...


@ -100,7 +100,7 @@ def report_compile_source_on_error():
# specifically _PyCode_InitAddressRange, reveals that
# this iterator is initialized from co_linetable and
# co_firstfileno. So copy these we must!
code = code.replace(
code = code.replace( # type: ignore[call-arg]
co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined]
co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined]
)


@ -184,7 +184,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
unsupported_dtype: List[torch.dtype] = None) -> None:
unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
# Attribute is registered in the _StorageBase class
# and UntypedStorage obtains through inheritance.
@property # type: ignore[misc]
@ -255,7 +255,7 @@ def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
for_storage: bool = False,
unsupported_dtype: List[torch.dtype] = None) -> None:
unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
r"""
generate_methods_for_privateuse1_backend(for_tensor, for_module, for_storage, unsupported_dtype) -> None


@ -37,8 +37,8 @@ if HAS_TABULATE:
model: Union[torch.nn.Module, Callable],
sample_input: Union[torch.Tensor, Any],
num_iters: int = 5,
optimizer: torch.optim.Optimizer = None,
loss_fn: Callable = None,
optimizer: Optional[torch.optim.Optimizer] = None,
loss_fn: Optional[Callable] = None,
):
# Define the statement and setup for the benchmark
if optimizer and loss_fn:
@ -73,8 +73,8 @@ if HAS_TABULATE:
num_iters: int = 5,
backend: Optional[str] = None,
mode: Optional[str] = "default",
optimizer: torch.optim.Optimizer = None,
loss_fn : Union[torch.nn.Module, Callable] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
loss_fn : Union[torch.nn.Module, Callable, None] = None,
):
"""
Use this utility to benchmark torch.compile
@ -117,7 +117,7 @@ if HAS_TABULATE:
sample_input: Union[torch.Tensor, Any],
num_iters : int = 5,
optimizer: Optional[torch.optim.Optimizer] = None,
loss_fn : Union[torch.nn.Module, Callable] = None,
loss_fn : Union[torch.nn.Module, Callable, None] = None,
):
"""
This is a simple utility that can be used to benchmark torch.compile
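
Where the parameter already accepts a union, the optional form above is spelled flat as `Union[torch.nn.Module, Callable, None]`; that is the same type as wrapping the union in `Optional[...]`, so either spelling satisfies mypy once `None` is a member. A short equivalence sketch with illustrative aliases:

from typing import Callable, Optional, Union

import torch

LossFnFlat = Union[torch.nn.Module, Callable, None]
LossFnWrapped = Optional[Union[torch.nn.Module, Callable]]  # same set of values

def bench(loss_fn: LossFnFlat = None) -> None:
    # The None default is fine because None is part of the union.
    if loss_fn is None:
        return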


@ -1066,7 +1066,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue
self._data_queue = self._worker_result_queue # type: ignore[assignment]
# In some rare cases, persistent workers (daemonic processes)
# would be terminated before `__del__` of iterator is invoked


@ -415,7 +415,7 @@ class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
@functional_datapipe('trace_as_dataframe')
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc]
source_datapipe = None
# TODO(VitalyFedyunin): Must implement all special functions of datapipes


@ -376,7 +376,7 @@ class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataP
self._datapipe_iter = iter(self._datapipe)
return self
def __next__(self) -> T_co:
def __next__(self) -> T_co: # type: ignore[type-var]
assert self._datapipe_iter is not None
return next(self._datapipe_iter)


@ -421,7 +421,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
self.main_datapipe_exhausted = False
self._child_stop: List[bool] = [True for _ in range(num_instances)]
def _find_next(self, instance_id: int) -> T_co:
def _find_next(self, instance_id: int) -> T_co: # type: ignore[type-var]
while True:
if self.main_datapipe_exhausted or self._child_stop[instance_id]:
raise StopIteration


@ -40,7 +40,7 @@ class ConcaterMapDataPipe(MapDataPipe):
raise TypeError("Expected all inputs to be `Sized`")
self.datapipes = datapipes # type: ignore[assignment]
def __getitem__(self, index) -> T_co:
def __getitem__(self, index) -> T_co: # type: ignore[type-var]
offset = 0
for dp in self.datapipes:
if index - offset < len(dp):


@ -418,5 +418,5 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[call-overload]
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload]
return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]


@ -236,7 +236,7 @@ class FlopCounterMode(TorchDispatchMode):
mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
depth: int = 2,
display: bool = True,
custom_mapping: Dict[Any, Any] = None):
custom_mapping: Optional[Dict[Any, Any]] = None):
self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
self.depth = depth
self.parents = ["Global"]


@ -1,6 +1,6 @@
import gc
import sys
from typing import NamedTuple, Tuple, List, Optional
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import types
import weakref
import json
@ -100,7 +100,7 @@ def annotated_references(obj):
need for a list. Descriptions are currently strings.
"""
references = {}
references: Dict[int, List[str]] = {}
def add_reference(name, obj):
references.setdefault(id(obj), []).append(name)
@ -272,7 +272,7 @@ def create_graph(objects, *, context=None, filter=None):
filter = is_cuda_tensor
nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects]
node_referrers = [[] for obj in objects]
node_referrers: List[List[int]] = [[] for obj in objects]
id_to_node = {id(obj): i for i, obj in enumerate(objects)}
for obj in objects:
@ -299,8 +299,8 @@ def create_graph(objects, *, context=None, filter=None):
to_keep.add(idx)
referrers = node_referrers[idx]
to_search.extend(referrers)
id_to_filtered_id = {}
filtered = []
id_to_filtered_id: Dict[int, int] = {}
filtered: List[Any] = []
for i, n in enumerate(nodes):
if i in to_keep:
id_to_filtered_id[i] = len(id_to_filtered_id)
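
The new annotations on `references`, `node_referrers`, `id_to_filtered_id`, and `filtered` all address the same mypy complaint: an empty literal gives the checker nothing to infer an element type from, so it asks for an explicit annotation ("Need type annotation", error code `var-annotated`). A compact sketch of the idiom with the same shapes:

from typing import Any, Dict, List

# Empty containers declare up front what will eventually go into them.
references: Dict[int, List[str]] = {}
id_to_filtered_id: Dict[int, int] = {}
filtered: List[Any] = []

references.setdefault(id(filtered), []).append("name of the reference")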


@ -50,7 +50,7 @@ class WeakIdRef(weakref.ref):
# cache the id of the key as we know this is definitely the hash
# method
self._id = id(key)
super().__init__(key, callback)
super().__init__(key, callback) # type: ignore[call-arg]
def __call__(self):
r = super().__call__()

Some files were not shown because too many files have changed in this diff.