Mirror of https://github.com/pytorch/pytorch.git — synced 2025-10-20 12:54:11 +08:00
[Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands:
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

which were reverted due to a conflict with the internal source repo.

Most of the changes fix PEP-484 violations (i.e. a default argument is set to None, but the type is not annotated as Optional). Plus a few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export.deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add an assert in `torch/optim/optimizer.py` that the Optional list is not None

TODO (in a follow-up PR):
- Fix the erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc-9 dependency update in Ubuntu 18.04:
- Add a hack to `.ci/docker/install_conda.sh` that discards the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 the bazel builds lose the ability to catch exceptions (probably because they link cupti statically, but I could not find where that is done)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 5cd861fcf7
Commit: 5837e95d30
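Most of the diff below mechanically applies the PEP 484 rule mentioned above: a parameter whose default is None must be annotated as Optional, since mypy 1.4.1 no longer treats such defaults as implicitly optional. A minimal before/after sketch of that pattern (hypothetical function, not taken from the PyTorch code base):

```python
from typing import Optional

# Before: implicit Optional — rejected by mypy 1.4.1 under PEP 484, because the
# annotated type `float` does not admit the None default.
# def scale(data: list, factor: float = None) -> list: ...

# After: the None default is reflected in the annotation.
def scale(data: list, factor: Optional[float] = None) -> list:
    """Multiply every element by `factor`, defaulting to 1.0 when not given."""
    f = 1.0 if factor is None else factor
    return [x * f for x in data]

print(scale([1, 2, 3]))        # [1.0, 2.0, 3.0]
print(scale([1, 2, 3], 2.0))   # [2.0, 4.0, 6.0]
```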
@@ -160,6 +160,20 @@ case "$image" in
     TRITON=yes
     INDUCTOR_BENCHMARKS=yes
     ;;
+  pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9)
+    CUDA_VERSION=12.1.0
+    CUDNN_VERSION=8
+    ANACONDA_PYTHON_VERSION=3.10
+    GCC_VERSION=9
+    PROTOBUF=yes
+    DB=yes
+    VISION=yes
+    KATEX=yes
+    UCX_COMMIT=${_UCX_COMMIT}
+    UCC_COMMIT=${_UCC_COMMIT}
+    CONDA_CMAKE=yes
+    TRITON=yes
+    ;;
   pytorch-linux-focal-py3-clang7-asan)
     ANACONDA_PYTHON_VERSION=3.9
     CLANG_VERSION=7
@@ -105,5 +105,13 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
     pip_install -r /opt/conda/requirements-docs.txt
   fi
+
+  # HACK HACK HACK
+  # gcc-9 for ubuntu-18.04 from http://ppa.launchpad.net/ubuntu-toolchain-r/test/ubuntu
+  # Pulls llibstdc++6 13.1.0-8ubuntu1~18.04 which is too new for conda
+  # So remove libstdc++6.so.3.29 installed by https://anaconda.org/anaconda/libstdcxx-ng/files?version=11.2.0
+  if grep 18.04.6 /etc/issue >/dev/null; then
+    rm /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/lib/libstdc++.so.6
+  fi

  popd
fi
@@ -75,10 +75,10 @@ librosa>=0.6.2 ; python_version < "3.11"
 #Pinned versions:
 #test that import:

-mypy==0.960
+mypy==1.4.1
 # Pin MyPy version because new errors are likely to appear with each release
 #Description: linter
-#Pinned versions: 0.960
+#Pinned versions: 1.4.1
 #test that import: test_typing.py, test_type_hints.py

 networkx==2.8.8
.github/workflows/docker-builds.yml (vendored): 1 line changed
@@ -40,6 +40,7 @@ jobs:
           - docker-image-name: pytorch-linux-bionic-cuda11.8-cudnn8-py3-gcc7-inductor-benchmarks
           - docker-image-name: pytorch-linux-bionic-py3.8-clang9
           - docker-image-name: pytorch-linux-bionic-py3.11-clang9
+          - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9
           - docker-image-name: pytorch-linux-focal-rocm-n-1-py3
           - docker-image-name: pytorch-linux-focal-rocm-n-py3
           - docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12
.github/workflows/pull.yml (vendored): 8 lines changed
@@ -305,12 +305,12 @@ jobs:
         { config: "default", shard: 1, num_shards: 1, runner: "linux.4xlarge" },
       ]}

-  linux-bionic-cuda12_1-py3_10-gcc9-bazel-test:
-    name: linux-bionic-cuda12.1-py3.10-gcc9-bazel-test
+  linux-focal-cuda12_1-py3_10-gcc9-bazel-test:
+    name: linux-focal-cuda12.1-py3.10-gcc9-bazel-test
     uses: ./.github/workflows/_bazel-build-test.yml
     with:
-      build-environment: linux-bionic-cuda12.1-py3.10-gcc9-bazel-test
-      docker-image-name: pytorch-linux-bionic-cuda12.1-cudnn8-py3-gcc9
+      build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test
+      docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9
       cuda-version: "12.1"
       test-matrix: |
         { include: [
@@ -152,7 +152,7 @@ init_command = [
     '--dry-run={{DRYRUN}}',
     'numpy==1.24.3',
     'expecttest==0.1.3',
-    'mypy==0.960',
+    'mypy==1.4.1',
     'types-requests==2.27.25',
     'types-PyYAML==6.0.7',
     'types-tabulate==0.8.8',
@@ -49,7 +49,7 @@ try:
     # use faster C loader if available
     from yaml import CSafeLoader as Loader
 except ImportError:
-    from yaml import SafeLoader as Loader  # type: ignore[misc]
+    from yaml import SafeLoader as Loader  # type: ignore[assignment, misc]


 def write(filename, s):
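The widened ignore comment above recurs throughout this commit: mypy 1.4.1 now reports the try/except import fallback as an `assignment` error (the fallback rebinds the same name to a different class) in addition to the older `misc` code. A minimal, self-contained sketch of that pattern (not PyTorch code, just the idiom):

```python
# Import the fast C-backed YAML loader when the libyaml extension is available,
# otherwise fall back to the pure-Python loader; mypy 1.4.1 flags the rebinding.
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[assignment, misc]

import yaml

def load_config(text: str) -> object:
    """Parse YAML text with whichever Loader was importable."""
    return yaml.load(text, Loader=Loader)

print(load_config("mypy: 1.4.1"))  # {'mypy': '1.4.1'}
```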
@@ -106,7 +106,7 @@ def graph_def_to_event(step, graph_def):
         wall_time=step, step=step, graph_def=graph_def.SerializeToString())


-@cli.command("tensorboard-graphs")
+@cli.command("tensorboard-graphs")  # type: ignore[arg-type, attr-defined]
 @click.option("--c2-netdef", type=click.Path(exists=True, dir_okay=False),
               multiple=True)
 @click.option("--tf-dir", type=click.Path(exists=True))
@@ -129,7 +129,7 @@ def tensorboard_graphs(c2_netdef, tf_dir):
     log.info("Wrote %s graphs to logdir %s", len(events), tf_dir)


-@cli.command("tensorboard-events")
+@cli.command("tensorboard-events")  # type: ignore[arg-type, attr-defined]
 @click.option("--c2-dir", type=click.Path(exists=True, file_okay=False),
               help="Root directory of the Caffe2 run")
 @click.option("--tf-dir", type=click.Path(writable=True),
@@ -209,4 +209,4 @@ def tensorboard_events(c2_dir, tf_dir):


 if __name__ == "__main__":
-    cli()
+    cli()  # type: ignore[misc]
@@ -21,7 +21,7 @@ class TestFuture(TestCase):
         error_msg = "Intentional Value Error"
         value_error = ValueError(error_msg)

-        f = Future[T]()
+        f = Future[T]()  # type: ignore[valid-type]
         # Set exception
         f.set_exception(value_error)
         # Exception should throw on wait
@@ -29,7 +29,7 @@ class TestFuture(TestCase):
             f.wait()

         # Exception should also throw on value
-        f = Future[T]()
+        f = Future[T]()  # type: ignore[valid-type]
         f.set_exception(value_error)
         with self.assertRaisesRegex(ValueError, "Intentional"):
             f.value()
@@ -37,7 +37,7 @@ class TestFuture(TestCase):
         def cb(fut):
             fut.value()

-        f = Future[T]()
+        f = Future[T]()  # type: ignore[valid-type]
         f.set_exception(value_error)

         with self.assertRaisesRegex(RuntimeError, "Got the following error"):
@@ -54,7 +54,7 @@ class TestFuture(TestCase):
         with self.assertRaisesRegex(ValueError, "Intentional"):
             f.wait()

-        f = Future[T]()
+        f = Future[T]()  # type: ignore[valid-type]
         t = threading.Thread(target=wait_future, args=(f, ))
         t.start()
         f.set_exception(value_error)
@@ -68,7 +68,7 @@ class TestFuture(TestCase):
         with self.assertRaisesRegex(RuntimeError, "Got the following error"):
             fut.wait()

-        f = Future[T]()
+        f = Future[T]()  # type: ignore[valid-type]
         t = threading.Thread(target=then_future, args=(f, ))
         t.start()
         f.set_exception(value_error)
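The `valid-type` ignores added above are needed because the tests instantiate `Future[T]` where `T` is an unbound TypeVar; that is fine at runtime but not a valid type expression for mypy 1.4.1. A minimal sketch of the pattern the tests exercise (assuming a local PyTorch install):

```python
from typing import TypeVar
from torch.futures import Future

T = TypeVar("T")

# Subscripting with an unbound TypeVar works at runtime, but mypy reports
# error code "valid-type", hence the ignore comment mirrored from the tests.
f = Future[T]()  # type: ignore[valid-type]
f.set_result(42)
print(f.wait())  # 42
```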
@@ -29,7 +29,7 @@ from itertools import product, combinations, permutations
 from functools import partial
 from torch import multiprocessing as mp
 from torch.testing import make_tensor
-from torch.testing._internal.common_utils import (
+from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
     TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
     IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
     IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest,
@@ -8490,7 +8490,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         except RuntimeError:
             pass
         with mp.Pool(1) as pool:
-            out: list = pool.map(method, [arg])
+            out = pool.map(method, [arg])
             self.assertTrue(out[0])

 def _test_multinomial_invalid_probs(probs):
@ -1,11 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
import time
|
||||
|
||||
from package.oss.cov_json import get_json_report
|
||||
from package.oss.init import initialization
|
||||
from package.tool.summarize_jsons import summarize_jsons
|
||||
from package.util.setting import TestPlatform
|
||||
from package.util.utils import print_time
|
||||
from package.oss.cov_json import get_json_report # type: ignore[import]
|
||||
from package.oss.init import initialization # type: ignore[import]
|
||||
from package.tool.summarize_jsons import summarize_jsons # type: ignore[import]
|
||||
from package.util.setting import TestPlatform # type: ignore[import]
|
||||
from package.util.utils import print_time # type: ignore[import]
|
||||
|
||||
|
||||
def report_coverage() -> None:
|
||||
|
@ -45,7 +45,7 @@ def transform_file_name(
|
||||
return file_path[file_path.find(folder) :]
|
||||
# remove pytorch base folder path
|
||||
if platform == TestPlatform.OSS:
|
||||
from package.oss.utils import get_pytorch_folder
|
||||
from package.oss.utils import get_pytorch_folder # type: ignore[import]
|
||||
|
||||
pytorch_foler = get_pytorch_folder()
|
||||
assert file_path.startswith(pytorch_foler)
|
||||
|
@ -89,7 +89,9 @@ def get_raw_profiles_folder() -> str:
|
||||
|
||||
def detect_compiler_type(platform: TestPlatform) -> CompilerType:
|
||||
if platform == TestPlatform.OSS:
|
||||
from package.oss.utils import detect_compiler_type # type: ignore[misc]
|
||||
from package.oss.utils import ( # type: ignore[assignment, import, misc]
|
||||
detect_compiler_type,
|
||||
)
|
||||
|
||||
cov_type = detect_compiler_type() # type: ignore[call-arg]
|
||||
else:
|
||||
@ -100,7 +102,7 @@ def detect_compiler_type(platform: TestPlatform) -> CompilerType:
|
||||
cov_type = detect_compiler_type()
|
||||
|
||||
check_compiler_type(cov_type)
|
||||
return cov_type
|
||||
return cov_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def get_test_name_from_whole_path(path: str) -> str:
|
||||
|
@ -20,7 +20,7 @@ from yaml.nodes import MappingNode
|
||||
try:
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import Loader # type: ignore[misc]
|
||||
from yaml import Loader # type: ignore[assignment, misc]
|
||||
|
||||
H_NAME = "spv.h"
|
||||
CPP_NAME = "spv.cpp"
|
||||
|
@ -16,7 +16,7 @@ from yaml import dump, load
|
||||
try:
|
||||
from yaml import CSafeLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import SafeLoader as Loader # type: ignore[misc]
|
||||
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
class LintSeverity(str, Enum):
|
||||
|
@ -11,7 +11,7 @@ from torchgen.selective_build.selector import SelectiveBuilder
|
||||
try:
|
||||
from yaml import CSafeLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import SafeLoader as Loader # type: ignore[misc]
|
||||
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
|
||||
|
@ -10,7 +10,7 @@ try:
|
||||
# use faster C loader if available
|
||||
from yaml import CSafeLoader as YamlLoader
|
||||
except ImportError:
|
||||
from yaml import SafeLoader as YamlLoader # type: ignore[misc]
|
||||
from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc]
|
||||
|
||||
NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml"
|
||||
TAGS_PATH = "aten/src/ATen/native/tags.yaml"
|
||||
|
@@ -429,6 +429,7 @@ def _jit_is_script_object(obj: Any) -> _bool: ...
 def _last_executed_optimized_graph() -> Graph: ...
 def parse_type_comment(comment: str) -> Decl: ...
 def _get_upgraders_map_size() -> _int: ...
+def _get_upgraders_entry_map() -> Dict[str, str]: ...
 def _dump_upgraders_map() -> Dict[str, str]: ...
 def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
 def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...
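With the stub entry added above, callers of the existing private binding type-check cleanly instead of triggering `attr-defined`. A small hedged sketch of such a call (assumes a PyTorch build where this private binding is present; it is internal API and may change):

```python
import torch

# The new .pyi declaration gives this call a typed signature: Dict[str, str].
entry_map: dict = torch._C._get_upgraders_entry_map()
print(type(entry_map), len(entry_map))
```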
@ -417,7 +417,7 @@ def sym_int(a):
|
||||
if isinstance(a, SymInt):
|
||||
return a
|
||||
elif isinstance(a, SymFloat):
|
||||
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type]
|
||||
return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type, call-overload]
|
||||
return py_int(a) # type: ignore[operator]
|
||||
|
||||
def sym_max(a, b):
|
||||
@ -1320,7 +1320,7 @@ if TYPE_CHECKING:
|
||||
# Some type signatures pulled in from _VariableFunctions here clash with
|
||||
# signatures already imported. For now these clashes are ignored; see
|
||||
# PR #43339 for details.
|
||||
from torch._C._VariableFunctions import * # type: ignore[misc] # noqa: F403
|
||||
from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403
|
||||
# Fixup segment_reduce visibility
|
||||
_segment_reduce = segment_reduce
|
||||
del segment_reduce
|
||||
|
@ -10,20 +10,20 @@ class _Union:
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
assert len(kwargs) == 1
|
||||
return cls(**{**{f.name: None for f in fields(cls)}, **kwargs})
|
||||
return cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type]
|
||||
|
||||
def __post_init__(self):
|
||||
assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1
|
||||
assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1 # type: ignore[arg-type, misc]
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None)
|
||||
val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
|
||||
assert val is not None
|
||||
return val
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None)
|
||||
val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type]
|
||||
assert val_type is not None
|
||||
return val_type
|
||||
|
||||
|
@@ -9,7 +9,7 @@ import typing
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

 import sympy

@@ -749,7 +749,7 @@ class GraphModuleDeserializer:
         self.module = torch.nn.Module()

     @contextmanager
-    def save_graph_module(self) -> None:
+    def save_graph_module(self) -> Iterator[None]:
         saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta
         self.graph = torch.fx.Graph()
         self.module = torch.nn.Module()
@@ -773,7 +773,7 @@ class GraphModuleDeserializer:

         if vr := self.symbol_name_to_range.get(val.expr_str):
             symbolic_shapes._constrain_symbol_range(
-                self.shape_env, sym, vr.lower, vr.upper
+                self.shape_env, sym, vr.lower, vr.upper  # type: ignore[arg-type]
             )

         return self.shape_env.create_symintnode(sym, hint=val.hint)
@@ -855,6 +855,7 @@ class GraphModuleDeserializer:
             output_node.meta["val"] = tuple(
                 arg.meta["val"] for arg in output_node.args[0]
             )
+        return output_node

     def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
         if target.__module__ == "_operator":  # TODO(zhxchen17) Follow up on this.
@@ -1050,7 +1051,7 @@ class GraphModuleDeserializer:
         self.serialized_name_to_node[fx_node.name] = fx_node

     def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
-        ret = {}
+        ret: Dict[str, Any] = {}
         if stack_trace := metadata.get("stack_trace"):
             ret["stack_trace"] = stack_trace
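The `save_graph_module` change above follows a general mypy rule: a generator-based context manager yields, so its declared return type must be `Iterator[None]` rather than `None`. A minimal runnable sketch of the same annotation pattern (hypothetical names, not the PyTorch class):

```python
from contextlib import contextmanager
from typing import Iterator

@contextmanager
def swapped_value(store: dict, key: str, value: object) -> Iterator[None]:
    """Temporarily replace store[key], restoring the saved value on exit."""
    saved = store.get(key)
    store[key] = value
    try:
        yield
    finally:
        store[key] = saved

cfg = {"mode": "eager"}
with swapped_value(cfg, "mode", "compile"):
    print(cfg["mode"])  # compile
print(cfg["mode"])      # eager
```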
@@ -33,7 +33,7 @@ def get_upgraders() -> Dict[str, Tuple[str, str]]:
     """Getting upgraders entry map and operator version map and merge them into one dict."""
     upgraders = torch._C._get_upgraders_entry_map()
     op_version_map = torch._C._get_operator_version_map()
-    output = defaultdict(tuple)
+    output: Dict[str, Tuple[str, str]] = defaultdict(tuple)  # type: ignore[arg-type]
     for opname, entry_list in op_version_map.items():
         if not entry_list:
             raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
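The annotation added above needs the `arg-type` ignore because `defaultdict(tuple)` uses `tuple` as the default factory, and the empty tuple it produces is not a `Tuple[str, str]`. A minimal sketch of that trade-off (illustrative names only):

```python
from collections import defaultdict
from typing import Dict, Tuple

# Without the annotation mypy infers a loose type; with it, the factory `tuple`
# (which yields ()) does not match Tuple[str, str], hence the ignore comment.
versions: Dict[str, Tuple[str, str]] = defaultdict(tuple)  # type: ignore[arg-type]
versions["aten::div"] = ("div_0_3", "<upgrader source>")
print(versions["aten::div"][0])  # div_0_3
print(versions["missing::op"])   # () — the factory default, not a Tuple[str, str]
```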
@ -1,5 +1,5 @@
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -12,7 +12,7 @@ def functional_call(
|
||||
module: "torch.nn.Module",
|
||||
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
|
||||
args: Union[Any, Tuple],
|
||||
kwargs: Dict[str, Any] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
tie_weights: bool = True,
|
||||
strict: bool = False,
|
||||
|
@ -1004,7 +1004,7 @@ def _linalg_svd_meta(
|
||||
A: Tensor,
|
||||
full_matrices: bool = False,
|
||||
compute_uv: bool = True,
|
||||
driver: str = None,
|
||||
driver: Optional[str] = None,
|
||||
):
|
||||
checkIsMatrix(A, "linalg.svd")
|
||||
checkFloatingOrComplex(A, "linalg.svd")
|
||||
@ -1147,7 +1147,7 @@ def linalg_solve_triangular_meta(
|
||||
upper: bool,
|
||||
left: bool = True,
|
||||
unitriangular: bool = False,
|
||||
out: Tensor = None,
|
||||
out: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
if out is None:
|
||||
out = A.new_empty([0])
|
||||
@ -4695,8 +4695,8 @@ def upsample_nearest2d_backward(
|
||||
grad_output: Tensor,
|
||||
output_size: Sequence[Union[int, torch.types.SymInt]],
|
||||
input_size: Sequence[Union[int, torch.types.SymInt]],
|
||||
scales_h: float = None,
|
||||
scales_w: float = None,
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
):
|
||||
full_output_size = upsample_common_check(
|
||||
input_size, output_size, num_spatial_dims=2
|
||||
|
@ -331,7 +331,7 @@ class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
|
||||
def _elementwise_meta(
|
||||
*args,
|
||||
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
|
||||
args_with_fixed_dtypes: Tuple[TensorLikeType, ...] = None,
|
||||
args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
|
||||
) -> FakeTensor:
|
||||
"""
|
||||
Meta function for elementwise operations that produce outputs in the same dtype
|
||||
|
@ -1,6 +1,6 @@
|
||||
import functools
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
from typing import Any, Callable, Dict, Optional, Sequence
|
||||
from warnings import warn
|
||||
|
||||
import torch
|
||||
@ -111,7 +111,7 @@ class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
|
||||
orig_func: Callable,
|
||||
types: Sequence,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Dict = None,
|
||||
kwargs: Optional[Dict] = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
@ -161,7 +161,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
||||
orig_func: Callable,
|
||||
types: Sequence,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Dict = None,
|
||||
kwargs: Optional[Dict] = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
@ -374,7 +374,7 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
|
||||
orig_func: Callable,
|
||||
types: Sequence,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Dict = None,
|
||||
kwargs: Optional[Dict] = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
@ -27,7 +27,7 @@ def load_tensor_reader(loc):
|
||||
|
||||
def register_debug_prims():
|
||||
@custom_op("debugprims::load_tensor")
|
||||
def load_tensor(
|
||||
def load_tensor( # type: ignore[empty-body]
|
||||
name: str,
|
||||
size: Sequence[int],
|
||||
stride: Sequence[int],
|
||||
|
@ -10,7 +10,7 @@ from torch._prims_common import (
|
||||
import torch._prims_common as utils
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
from typing import Callable, Sequence, Tuple, NamedTuple, overload
|
||||
from typing import Callable, Sequence, Tuple, NamedTuple, Optional, overload
|
||||
import inspect
|
||||
from functools import wraps
|
||||
import warnings
|
||||
@ -97,7 +97,7 @@ class elementwise_type_promotion_wrapper:
|
||||
self,
|
||||
*,
|
||||
type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
type_promoting_args: Sequence[str] = None,
|
||||
type_promoting_args: Optional[Sequence[str]] = None,
|
||||
):
|
||||
self.type_promoting_arg_names = type_promoting_args
|
||||
self.type_promotion_kind = type_promotion_kind
|
||||
|
@ -1832,7 +1832,7 @@ def clamp(
|
||||
@out_wrapper()
|
||||
def clamp_min(
|
||||
self: TensorLikeType,
|
||||
min: TensorOrNumberLikeType = None,
|
||||
min: Optional[TensorOrNumberLikeType] = None,
|
||||
) -> TensorLikeType:
|
||||
return torch.clamp(self, min=min) # type: ignore[arg-type]
|
||||
|
||||
@ -1841,7 +1841,7 @@ def clamp_min(
|
||||
@out_wrapper()
|
||||
def clamp_max(
|
||||
self: TensorLikeType,
|
||||
max: TensorOrNumberLikeType = None,
|
||||
max: Optional[TensorOrNumberLikeType] = None,
|
||||
) -> TensorLikeType:
|
||||
return torch.clamp(self, max=max) # type: ignore[arg-type]
|
||||
|
||||
@ -4654,7 +4654,7 @@ def logspace(
|
||||
ret = torch.linspace(
|
||||
start,
|
||||
end,
|
||||
steps,
|
||||
steps, # type: ignore[arg-type]
|
||||
dtype=torch.float64,
|
||||
layout=layout,
|
||||
device=device,
|
||||
|
@ -63,7 +63,7 @@ class MultiheadAttention(nn.MultiheadAttention):
|
||||
def __init__(self, embed_dim: int, num_heads: int,
|
||||
dropout: float = 0., bias: bool = True,
|
||||
add_bias_kv: bool = False, add_zero_attn: bool = False,
|
||||
kdim: int = None, vdim: int = None, batch_first: bool = False,
|
||||
kdim: Optional[int] = None, vdim: Optional[int] = None, batch_first: bool = False,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__(embed_dim, num_heads, dropout,
|
||||
|
@ -856,7 +856,7 @@ def prepare_n_shadows_model(
|
||||
create_n_transformed_and_logged_copies_of_subgraph(
|
||||
mt, subgraph_idx, match_name, nodes_in_this_subgraph,
|
||||
qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig,
|
||||
custom_prepare_fn, custom_prepare_kwargs
|
||||
custom_prepare_fn, custom_prepare_kwargs # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
return mt
|
||||
|
@ -443,7 +443,7 @@ def create_one_transformed_and_logged_copy_of_subgraph(
|
||||
example_inputs: Any,
|
||||
last_added_shadow_node_list: List[Optional[Node]],
|
||||
custom_prepare_fn: Optional[Callable] = None,
|
||||
custom_prepare_kwargs: Dict[str, Any] = None,
|
||||
custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Given a subgraph in `mt` and a subgraph candidate idx, inserts the
|
||||
@ -575,7 +575,7 @@ def create_n_transformed_and_logged_copies_of_subgraph(
|
||||
qconfig_mappings: List[QConfigMapping],
|
||||
list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
|
||||
custom_prepare_fn: Optional[Callable] = None,
|
||||
custom_prepare_kwargs: Dict[str, Any] = None,
|
||||
custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Given a model `mt` and a subgraph_idx, creates the needed copies
|
||||
@ -756,7 +756,7 @@ def create_add_loggers_graph(
|
||||
create_n_transformed_and_logged_copies_of_subgraph(
|
||||
model, cur_subgraph_idx, match_name, maybe_subgraph,
|
||||
[qconfig_mapping], [node_name_to_qconfig],
|
||||
None, None
|
||||
None, None # type: ignore[arg-type]
|
||||
)
|
||||
# find the created shadow module and record it so we
|
||||
# can find it easily in step 2
|
||||
|
@@ -78,7 +78,7 @@ def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
             res.append(param_value)
         return res
     else:
-        assert isinstance(mod, nnqd.LSTM), f"type {type(res)} not handled yet"
+        assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
         res = []
         for weight_value in mod._all_weight_values:
             res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
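This is the error-message fix called out in the commit message: the assertion should report the type of the module being checked, not of the result list that has not been built yet. A tiny hedged sketch of the corrected pattern (hypothetical helper, not PyTorch code):

```python
import torch.nn as nn

def require_lstm(mod: nn.Module) -> nn.LSTM:
    # The message names the object under test (`mod`), so a failure is actionable.
    assert isinstance(mod, nn.LSTM), f"type {type(mod)} not handled yet"
    return mod

print(require_lstm(nn.LSTM(4, 4)))
```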
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Any, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
from torch import nn
|
||||
@ -205,7 +205,7 @@ class ActivationSparsifier:
|
||||
# or sparsify_hook()
|
||||
self.data_groups[name]['hook_state'] = "aggregate" # aggregate hook is attached
|
||||
|
||||
def get_mask(self, name: str = None, layer: nn.Module = None):
|
||||
def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None):
|
||||
"""
|
||||
Returns mask associated to the layer.
|
||||
|
||||
|
@ -3,7 +3,7 @@ import torch
|
||||
from dlrm_s_pytorch import unpack_batch # type: ignore[import]
|
||||
import numpy as np # type: ignore[import]
|
||||
import sklearn # type: ignore[import]
|
||||
from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model
|
||||
from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model # type: ignore[import]
|
||||
import pandas as pd # type: ignore[import]
|
||||
import argparse
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from functools import reduce
|
||||
from typing import Tuple, Any, List
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from .base_data_sparsifier import BaseDataSparsifier
|
||||
|
||||
@ -31,9 +31,9 @@ class DataNormSparsifier(BaseDataSparsifier):
|
||||
arguments and could be overriden by the configuration provided in the
|
||||
`add_data` step.
|
||||
"""
|
||||
def __init__(self, data_list: List[Tuple[str, Any]] = None, sparsity_level: float = 0.5,
|
||||
def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, sparsity_level: float = 0.5,
|
||||
sparse_block_shape: Tuple[int, int] = (1, 4),
|
||||
zeros_per_block: int = None, norm: str = 'L1'):
|
||||
zeros_per_block: Optional[int] = None, norm: str = 'L1'):
|
||||
if zeros_per_block is None:
|
||||
zeros_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
SUPPORTED_MODULES = {
|
||||
nn.Embedding,
|
||||
@ -28,7 +28,7 @@ def _fetch_all_embeddings(model):
|
||||
def post_training_sparse_quantize(model,
|
||||
data_sparsifier_class,
|
||||
sparsify_first=True,
|
||||
select_embeddings: List[nn.Module] = None,
|
||||
select_embeddings: Optional[List[nn.Module]] = None,
|
||||
**sparse_config):
|
||||
"""Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags.
|
||||
The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.
|
||||
|
@ -44,7 +44,7 @@ class APoTQuantizer():
|
||||
|
||||
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
|
||||
|
||||
result = TensorAPoT(self, tensor2quantize)
|
||||
result = TensorAPoT(self, tensor2quantize) # type: ignore[assignment]
|
||||
|
||||
return result
|
||||
|
||||
@ -83,7 +83,7 @@ class APoTQuantizer():
|
||||
def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
|
||||
levels_lst = list(self.quantization_levels)
|
||||
|
||||
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
|
||||
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst)) # type: ignore[call-arg]
|
||||
|
||||
return result
|
||||
|
||||
|
@ -85,9 +85,9 @@ def _find_matches(
|
||||
modules: Dict[str, torch.nn.Module],
|
||||
patterns: Dict[Pattern, QuantizeHandler],
|
||||
root_node_getter_mapping: Dict[Pattern, Callable],
|
||||
standalone_module_names: List[str] = None,
|
||||
standalone_module_classes: List[Type] = None,
|
||||
custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
|
||||
standalone_module_names: Optional[List[str]] = None,
|
||||
standalone_module_classes: Optional[List[Type]] = None,
|
||||
custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
|
||||
"""
|
||||
Matches the nodes in the input graph to quantization patterns, and
|
||||
outputs the information needed to quantize them in future steps.
|
||||
|
@ -18,7 +18,7 @@ from torch.ao.quantization.utils import (
|
||||
)
|
||||
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, List, Type
|
||||
from typing import Callable, Dict, List, Type, Optional
|
||||
|
||||
__all__ = [
|
||||
"QuantizeHandler",
|
||||
@ -52,7 +52,7 @@ class QuantizeHandler(ABC):
|
||||
self,
|
||||
node_pattern: NodePattern,
|
||||
modules: Dict[str, torch.nn.Module],
|
||||
root_node_getter: Callable = None,
|
||||
root_node_getter: Optional[Callable] = None,
|
||||
is_custom_module=False,
|
||||
is_standalone_module=False):
|
||||
""" Records pattern information in __init__, which will be used
|
||||
@ -113,7 +113,7 @@ def _get_quantize_handler_cls(
|
||||
self,
|
||||
node_pattern: NodePattern,
|
||||
modules: Dict[str, torch.nn.Module],
|
||||
root_node_getter: Callable = None):
|
||||
root_node_getter: Optional[Callable] = None):
|
||||
super().__init__(node_pattern, modules, root_node_getter)
|
||||
if num_tensor_args_to_observation_type:
|
||||
assert self.num_tensor_args in num_tensor_args_to_observation_type, \
|
||||
|
@ -458,8 +458,10 @@ def _update_special_qspecs_after_replacement(
|
||||
if isinstance(edge_or_node, Node):
|
||||
_node = edge_or_node
|
||||
return original_to_replacement_node.get(_node, _node)
|
||||
elif isinstance(edge_or_node, Tuple[Node, Node]):
|
||||
src, dest = edge_or_node
|
||||
# TODO: It's really should be
|
||||
# isinstance(edge_or_node, tuple) and len(edge_or_node) == 2 and all(isinstance(x, Node) for x in edge_or_node)
|
||||
elif isinstance(edge_or_node, Tuple[Node, Node]): # type: ignore[arg-type]
|
||||
src, dest = edge_or_node # type: ignore[misc]
|
||||
return (
|
||||
original_to_replacement_node.get(src, src),
|
||||
original_to_replacement_node.get(dest, dest),
|
||||
|
@ -58,13 +58,13 @@ def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
|
||||
for conv_op, add_op, relu_op in conv_add_relu_options:
|
||||
if add_op is None:
|
||||
# Append Conv ReLU
|
||||
supported_operators["conv2d"].append([conv_op, relu_op])
|
||||
supported_operators["conv2d"].append([conv_op, relu_op]) # type: ignore[list-item]
|
||||
elif relu_op is None:
|
||||
# Append Conv Add
|
||||
supported_operators["conv2d"].append([conv_op, add_op])
|
||||
supported_operators["conv2d"].append([conv_op, add_op]) # type: ignore[list-item]
|
||||
else:
|
||||
# Append Conv Add ReLU
|
||||
supported_operators["conv2d"].append([conv_op, add_op, relu_op])
|
||||
supported_operators["conv2d"].append([conv_op, add_op, relu_op]) # type: ignore[list-item]
|
||||
|
||||
return copy.deepcopy(supported_operators)
|
||||
|
||||
@ -222,17 +222,17 @@ class X86InductorQuantizer(Quantizer):
|
||||
"""
|
||||
conv_gemm_node_idx = None
|
||||
extra_input_node_idx = None
|
||||
if (binary_node.args[0].op == "call_function") and (
|
||||
if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr]
|
||||
binary_node.args[0] == conv_gemm_node
|
||||
):
|
||||
conv_gemm_node_idx = 0
|
||||
extra_input_node_idx = 1
|
||||
elif (binary_node.args[1].op == "call_function") and (
|
||||
elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr]
|
||||
binary_node.args[1] == conv_gemm_node
|
||||
):
|
||||
conv_gemm_node_idx = 1
|
||||
extra_input_node_idx = 0
|
||||
extra_input_node = binary_node.args[extra_input_node_idx]
|
||||
extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index]
|
||||
assert isinstance(extra_input_node, Node)
|
||||
return conv_gemm_node_idx, extra_input_node_idx
|
||||
|
||||
|
@ -119,9 +119,9 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
|
||||
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
|
||||
remove_tensor_overload_for_qdq_ops(model)
|
||||
for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
|
||||
pattern = get_aten_graph_module(pattern, example_inputs)
|
||||
remove_tensor_overload_for_qdq_ops(pattern)
|
||||
replacement = get_aten_graph_module(replacement, example_inputs)
|
||||
remove_tensor_overload_for_qdq_ops(replacement)
|
||||
pattern = get_aten_graph_module(pattern, example_inputs) # type: ignore[arg-type, assignment]
|
||||
remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type]
|
||||
replacement = get_aten_graph_module(replacement, example_inputs) # type: ignore[arg-type, assignment]
|
||||
remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type]
|
||||
matches = replace_pattern(model, pattern, replacement)
|
||||
return model
|
||||
|
@ -455,7 +455,7 @@ def _profile_to_snapshot(profile):
|
||||
device = to_device(tensor_key.device)
|
||||
addr = tensor_key.storage.ptr
|
||||
|
||||
seg = snapshot['segments'][device]
|
||||
seg = snapshot['segments'][device] # type: ignore[index]
|
||||
if seg['address'] is None or seg['address'] > addr:
|
||||
seg['address'] = addr
|
||||
seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later
|
||||
@ -465,12 +465,12 @@ def _profile_to_snapshot(profile):
|
||||
stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
|
||||
r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
|
||||
if during_trace:
|
||||
snapshot['device_traces'][device].append(r)
|
||||
snapshot['device_traces'][device].append(r) # type: ignore[index]
|
||||
return r
|
||||
|
||||
def free(alloc, device):
|
||||
for e in ('free_requested', 'free_completed'):
|
||||
snapshot['device_traces'][device].append({'action': e,
|
||||
snapshot['device_traces'][device].append({'action': e, # type: ignore[index]
|
||||
'addr': alloc['addr'],
|
||||
'size': alloc['size'],
|
||||
'stream': 0,
|
||||
@ -499,7 +499,7 @@ def _profile_to_snapshot(profile):
|
||||
blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
|
||||
for (tensor_key, version), event in kv_to_elem.items()]
|
||||
for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]):
|
||||
seg = snapshot['segments'][device]
|
||||
seg = snapshot['segments'][device] # type: ignore[index]
|
||||
last_addr = seg['address']
|
||||
for _, addr, size, frames in blocks:
|
||||
if last_addr < addr:
|
||||
@ -510,8 +510,8 @@ def _profile_to_snapshot(profile):
|
||||
if last_addr < seg['total_size']:
|
||||
seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})
|
||||
|
||||
snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']]
|
||||
for seg in snapshot['segments']:
|
||||
snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined]
|
||||
for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef]
|
||||
seg['total_size'] -= seg['address']
|
||||
if not seg['blocks']:
|
||||
seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})
|
||||
|
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import cast, Iterable, List, Union
|
||||
from typing import Iterable, List, Union
|
||||
from . import _lazy_init, _lazy_call, device_count, current_device
|
||||
from .. import Tensor
|
||||
|
||||
@ -56,7 +56,7 @@ def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cu
|
||||
device = torch.device('cuda', device)
|
||||
|
||||
def cb():
|
||||
idx = cast(torch.device, device).index
|
||||
idx = device.index
|
||||
if idx is None:
|
||||
idx = current_device()
|
||||
default_generator = torch.cuda.default_generators[idx]
|
||||
|
@ -732,7 +732,7 @@ class ShardedTensor(ShardedTensorBase):
|
||||
local_tensor: torch.Tensor,
|
||||
sharding_spec: shard_spec.ShardingSpec,
|
||||
*global_size: Sequence[int],
|
||||
process_group: dist.ProcessGroup = None,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
init_rrefs=False,
|
||||
) -> "ShardedTensor":
|
||||
"""
|
||||
|
@ -389,7 +389,7 @@ def _compile(
|
||||
# can trace operations applied to them.
|
||||
def stateless_func(func, params, buffers, named_states, args, kwargs):
|
||||
with stateless._reparametrize_module(
|
||||
cast(nn.Module, mod), {**params, **buffers}
|
||||
mod, {**params, **buffers}
|
||||
), _rematerialize_optimizer(
|
||||
opt, named_states, params
|
||||
) if opt else nullcontext():
|
||||
|
@ -210,7 +210,7 @@ def full(
|
||||
def zeros(
|
||||
*size,
|
||||
requires_grad: bool = False,
|
||||
dtype: torch.dtype = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
layout: torch.layout = torch.strided,
|
||||
device_mesh: Optional[DeviceMesh] = None,
|
||||
placements: Optional[Sequence[Placement]] = None,
|
||||
|
@ -133,7 +133,7 @@ def gen_model_param_in_submesh(model: nn.Module, sub_mesh: DeviceMesh) -> nn.Mod
|
||||
)
|
||||
|
||||
|
||||
def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module:
|
||||
def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module: # type: ignore[empty-body]
|
||||
"""
|
||||
checkpoint save/load models with DTensor parameters
|
||||
"""
|
||||
|
@ -40,7 +40,7 @@ def register_op_strategy(op):
|
||||
def as_list(
|
||||
x: Union[List[object], object]
|
||||
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
|
||||
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]:
|
||||
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
|
||||
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
|
||||
# which is an object but treated as a list by the tracer. Therefore, keep
|
||||
# `immutable_list` intact here as well.
|
||||
|
@ -1,6 +1,7 @@
|
||||
import functools
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DefaultState:
|
||||
@ -127,7 +128,7 @@ def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch
|
||||
allreduce_hook(state, grad)
|
||||
_decompress(state, grad)
|
||||
|
||||
def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None):
|
||||
def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
|
||||
r"""
|
||||
This FSDP communication hook implements a simple gradient compression
|
||||
approach that casts ``grad`` to half-precision floating-point format (``torch.float16``).
|
||||
@ -144,7 +145,7 @@ def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: tor
|
||||
fp16_hook = functools.partial(_low_precision_hook, torch.float16)
|
||||
return fp16_hook(state, grad, output)
|
||||
|
||||
def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None):
|
||||
def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None):
|
||||
r"""
|
||||
This FSDP communication hook implements a simple gradient compression
|
||||
approach that casts ``grad`` to half-precision floating-point format (``torch.float16``).
|
||||
|
@ -20,7 +20,7 @@ def load_state_dict(
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
coordinator_rank: int = 0,
|
||||
no_dist: bool = False,
|
||||
planner: LoadPlanner = None,
|
||||
planner: Optional[LoadPlanner] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Loads a distributed ``state_dict`` in SPMD style.
|
||||
|
@ -22,7 +22,7 @@ def save_state_dict(
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
coordinator_rank: int = 0,
|
||||
no_dist: bool = False,
|
||||
planner: SavePlanner = None,
|
||||
planner: Optional[SavePlanner] = None,
|
||||
) -> Metadata:
|
||||
"""
|
||||
Saves a distributed model in SPMD style.
|
||||
|
@ -58,7 +58,7 @@ def broadcast(
|
||||
raise AssertionError("Data or Function is expected to be None if not successful")
|
||||
|
||||
payload: Optional[T] = None
|
||||
exception : Exception = None
|
||||
exception : Optional[Exception] = None
|
||||
# if no pg is passed then execute if rank is 0
|
||||
if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
|
||||
# determine if it is an executable function or data payload only
|
||||
@ -119,7 +119,7 @@ def all_gather(
|
||||
>> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
|
||||
"""
|
||||
payload: Optional[T] = None
|
||||
exception : Exception = None
|
||||
exception : Optional[Exception] = None
|
||||
success = True
|
||||
# determine if it is an executable function or data payload only
|
||||
if callable(data_or_fn):
|
||||
@ -143,7 +143,7 @@ def all_gather(
|
||||
total_list = [None] * dist.get_world_size(pg)
|
||||
all_gather_object_enforce_type(pg, total_list, sync_obj)
|
||||
# Each rank will throw RuntimeError in case of failure on any rank.
|
||||
stage_name: Optional[str] = cast(SyncPayload[T], total_list[0]).stage_name
|
||||
stage_name = cast(SyncPayload[T], total_list[0]).stage_name
|
||||
exception_list: List[Tuple[int, Exception]] = []
|
||||
ret_list: List[T] = []
|
||||
error_msg: str = ""
|
||||
@ -160,7 +160,7 @@ def all_gather(
|
||||
ret_list.append(sp.payload)
|
||||
|
||||
if len(exception_list) > 0:
|
||||
raise RuntimeError(
|
||||
raise RuntimeError( # type: ignore[misc]
|
||||
error_msg, exception_list) from exception_list[0]
|
||||
return ret_list
|
||||
else:
|
||||
@ -168,7 +168,7 @@ def all_gather(
|
||||
raise RuntimeError(
|
||||
f"all_gather failed with exception {sync_obj.exception}",
|
||||
) from sync_obj.exception
|
||||
return [sync_obj.payload]
|
||||
return [sync_obj.payload] # type: ignore[list-item]
|
||||
|
||||
|
||||
# Note: use Any for typing for now so users can pass in
|
||||
|
@ -68,7 +68,7 @@ _default_metrics_handler: MetricHandler = NullMetricHandler()
|
||||
|
||||
|
||||
# pyre-fixme[9]: group has type `str`; used as `None`.
|
||||
def configure(handler: MetricHandler, group: str = None):
|
||||
def configure(handler: MetricHandler, group: Optional[str] = None):
|
||||
if group is None:
|
||||
global _default_metrics_handler
|
||||
# pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
|
||||
|
@ -158,7 +158,7 @@ class FileTimerServer:
|
||||
file_path: str,
|
||||
max_interval: float = 10,
|
||||
daemon: bool = True,
|
||||
log_event: Callable[[str, Optional[FileTimerRequest]], None] = None
|
||||
log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None
|
||||
) -> None:
|
||||
self._file_path = file_path
|
||||
self._max_interval = max_interval
|
||||
|
@ -476,7 +476,7 @@ def _flatten_optim_state_dict(
|
||||
len(fqns) == 1
|
||||
), f"use_orig_params is True but there are multiple FQNs, {fqns}."
|
||||
if optim is not None: # NamedOptimizer or KeyedOptimizer case.
|
||||
state = optim.state.get(param, None)
|
||||
state = optim.state.get(param, None) # type: ignore[call-overload]
|
||||
if state is not None:
|
||||
flat_osd_state[key] = copy.deepcopy(state)
|
||||
else:
|
||||
|
@ -1433,7 +1433,7 @@ def _register_post_backward_reshard_only_hooks(
|
||||
hook_handle = register_multi_grad_hook(
|
||||
inp_tensors, functools.partial(_post_backward_reshard, state, handle)
|
||||
)
|
||||
handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined]
|
||||
handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment]
|
||||
|
||||
|
||||
@no_type_check
|
||||
@ -1468,11 +1468,11 @@ def _wait_for_computation_stream(
|
||||
For example, this should be called in the FSDP root's pre-forward to
|
||||
respect optimizer step computation.
|
||||
"""
|
||||
unshard_stream.wait_stream(computation_stream)
|
||||
unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
|
||||
# Having the pre-all-gather stream wait for the current stream even if we
|
||||
# do not leverage the pre-all-gather stream is tolerable since this only
|
||||
# runs once per iteration
|
||||
pre_unshard_stream.wait_stream(computation_stream)
|
||||
pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _reset_flat_param_grad_info_if_needed(
|
||||
|
@ -167,7 +167,7 @@ class _ExecOrderTracer:
|
||||
kwargs: Dict[str, Any],
|
||||
name: Optional[str] = None,
|
||||
type_expr: Optional[Any] = None,
|
||||
proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None,
|
||||
proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
|
||||
) -> torch.fx.Proxy:
|
||||
"""
|
||||
Overrides ``create_proxy`` to save execution information to
|
||||
|
@ -96,7 +96,7 @@ def _auto_wrap(
|
||||
)
|
||||
recursive_wrap_kwargs["auto_wrap_policy"] = policy
|
||||
_warn_on_overridden_mixed_precision(overridden_module_classes)
|
||||
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs)
|
||||
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _check_nested_wrapping(root_module: nn.Module):
|
||||
|
@ -129,8 +129,8 @@ class _RemoteModule(nn.Module):
|
||||
self,
|
||||
remote_device: str,
|
||||
module_cls: Type[nn.Module],
|
||||
args: Tuple = None,
|
||||
kwargs: Dict[str, Any] = None,
|
||||
args: Optional[Tuple] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
_module_interface_cls: Any = None,
|
||||
):
|
||||
"""
|
||||
@ -358,7 +358,7 @@ class _RemoteModule(nn.Module):
|
||||
def bfloat16(self: T) -> T: # type: ignore[return]
|
||||
_raise_not_supported(self.bfloat16.__name__)
|
||||
|
||||
def to(self, *args, **kwargs) -> T: # type: ignore[return]
|
||||
def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var]
|
||||
_raise_not_supported(self.to.__name__)
|
||||
|
||||
def register_backward_hook( # type: ignore[return]
|
||||
@ -377,7 +377,7 @@ class _RemoteModule(nn.Module):
|
||||
) -> RemovableHandle:
|
||||
_raise_not_supported(self.register_forward_pre_hook.__name__)
|
||||
|
||||
def register_forward_hook( # type: ignore[return]
|
||||
def register_forward_hook( # type: ignore[return, override]
|
||||
self,
|
||||
hook: Union[
|
||||
Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
|
||||
@ -685,8 +685,8 @@ class RemoteModule(_RemoteModule):
|
||||
self,
|
||||
remote_device: str,
|
||||
module_cls: Type[nn.Module],
|
||||
args: Tuple = None,
|
||||
kwargs: Dict[str, Any] = None,
|
||||
args: Optional[Tuple] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(remote_device, module_cls, args, kwargs)
|
||||
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
import warnings
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Collection, Dict, List, Mapping, Union
|
||||
from typing import Any, Collection, Dict, List, Mapping, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -63,8 +63,8 @@ class _NamedOptimizer(optim.Optimizer):
|
||||
self,
|
||||
named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
|
||||
optimizer_class: optim.Optimizer,
|
||||
param_groups: Collection[Mapping[str, Any]] = None,
|
||||
module: nn.Module = None,
|
||||
param_groups: Optional[Collection[Mapping[str, Any]]] = None,
|
||||
module: Optional[nn.Module] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -154,7 +154,7 @@ class _NamedOptimizer(optim.Optimizer):
|
||||
self._optimizer.step(closure=closure)
|
||||
|
||||
@property
|
||||
def state(self) -> Mapping[torch.Tensor, Any]:
|
||||
def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override]
|
||||
return self._optimizer.state
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
|
||||
|
@ -1,12 +1,12 @@
|
||||
from torch import nn
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
__all__ = ["partition_model"]
|
||||
|
||||
def partition_model(
|
||||
module: nn.Sequential,
|
||||
balance: List[int],
|
||||
devices: List[int] = None):
|
||||
devices: Optional[List[int]] = None):
|
||||
"""
|
||||
Given an :class:`nn.Sequential <torch.nn.Sequential>` module, partitions
|
||||
the model across multiple GPU devices according the provided ``balance``
|
||||
|
@ -49,7 +49,7 @@ def input_reshard(
|
||||
|
||||
def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any:
|
||||
nonlocal cx
|
||||
cx.__exit__() # type: ignore[name-defined]
|
||||
cx.__exit__() # type: ignore[name-defined, union-attr]
|
||||
|
||||
if input_reshard_dim is None:
|
||||
return module
|
||||
|
@ -1,7 +1,7 @@
|
||||
import math
|
||||
import warnings
|
||||
from numbers import Number
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nan
|
||||
@ -72,9 +72,9 @@ class Wishart(ExponentialFamily):
|
||||
|
||||
def __init__(self,
|
||||
df: Union[torch.Tensor, Number],
|
||||
covariance_matrix: torch.Tensor = None,
|
||||
precision_matrix: torch.Tensor = None,
|
||||
scale_tril: torch.Tensor = None,
|
||||
covariance_matrix: Optional[torch.Tensor] = None,
|
||||
precision_matrix: Optional[torch.Tensor] = None,
|
||||
scale_tril: Optional[torch.Tensor] = None,
|
||||
validate_args=None):
|
||||
assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
|
||||
"Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
|
||||
|
@ -23,7 +23,7 @@ class FoldedGraphModule(torch.fx.GraphModule):
|
||||
root: torch.nn.Module,
|
||||
graph: torch.fx.Graph,
|
||||
const_subgraph: Optional[torch.fx.Graph] = None,
|
||||
fx_const_folded_attrs_name: str = None,
|
||||
fx_const_folded_attrs_name: Optional[str] = None,
|
||||
device_for_folded_attrs: str = "cuda",
|
||||
):
|
||||
super().__init__(root, graph)
|
||||
|
@ -259,7 +259,7 @@ class MetaTracer(torch.fx.Tracer):
|
||||
|
||||
|
||||
def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
|
||||
meta_args : Dict[str, torch.Tensor] = None,
|
||||
meta_args : Optional[Dict[str, torch.Tensor]] = None,
|
||||
concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
|
||||
tracer = MetaTracer()
|
||||
graph = tracer.trace(root, meta_args, concrete_args)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# type: ignore[attr-defined]
|
||||
# mypy: disable-error-code=attr-defined
|
||||
from .core import unify, reify # noqa: F403
|
||||
from .more import unifiable # noqa: F403
|
||||
from .variable import var, isvar, vars, variables, Var # noqa: F403
|
||||
|
@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
class Partition:
|
||||
def __init__(self, id: int = None, nodes: Iterable[Node] = None):
|
||||
def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
|
||||
self.id = id
|
||||
self.nodes: Set[Node] = set(nodes) if nodes is not None else set()
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from functools import wraps
|
||||
from inspect import unwrap
|
||||
from typing import Callable, List
|
||||
from typing import Callable, List, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -76,7 +76,7 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable:
|
||||
|
||||
|
||||
|
||||
def loop_pass(base_pass: Callable, n_iter: int = None, predicate: Callable = None):
|
||||
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
|
||||
"""
|
||||
Convenience wrapper for passes which need to be applied multiple times.
|
||||
|
||||
|
@ -202,7 +202,7 @@ def replace_pattern_with_filters(
|
||||
gm: GraphModule,
|
||||
pattern: Union[Callable, GraphModule],
|
||||
replacement: Union[Callable, GraphModule],
|
||||
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
|
||||
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
|
||||
ignore_literals: bool = False,
|
||||
) -> List[ReplacedPatterns]:
|
||||
"""
|
||||
@ -222,7 +222,7 @@ def _replace_pattern(
|
||||
gm: GraphModule,
|
||||
pattern: Union[Callable, GraphModule],
|
||||
replacement: Union[Callable, GraphModule],
|
||||
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
|
||||
match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined]
|
||||
ignore_literals: bool = False,
|
||||
) -> List[ReplacedPatterns]:
|
||||
|
||||
|
@ -1302,7 +1302,7 @@ def amin(
|
||||
@_apply_docstring_templates
|
||||
def argmax(
|
||||
input: Union[Tensor, MaskedTensor],
|
||||
dim: int = None,
|
||||
dim: Optional[int] = None,
|
||||
*,
|
||||
keepdim: Optional[bool] = False,
|
||||
dtype: Optional[DType] = None,
|
||||
@ -1328,7 +1328,7 @@ def argmax(
|
||||
@_apply_docstring_templates
|
||||
def argmin(
|
||||
input: Union[Tensor, MaskedTensor],
|
||||
dim: int = None,
|
||||
dim: Optional[int] = None,
|
||||
*,
|
||||
keepdim: Optional[bool] = False,
|
||||
dtype: Optional[DType] = None,
|
||||
|
@ -103,7 +103,7 @@ class Sequential(Module):
|
||||
for idx, module in enumerate(args):
|
||||
self.add_module(str(idx), module)
|
||||
|
||||
def _get_item_by_idx(self, iterator, idx) -> T:
|
||||
def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
|
||||
"""Get the idx-th item of the iterator"""
|
||||
size = len(self)
|
||||
idx = operator.index(idx)
|
||||
|
@ -51,7 +51,7 @@ class _ConvNd(Module):
|
||||
'out_channels', 'kernel_size']
|
||||
__annotations__ = {'bias': Optional[torch.Tensor]}
|
||||
|
||||
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
||||
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body]
|
||||
...
|
||||
|
||||
in_channels: int
|
||||
|
@ -1,7 +1,7 @@
|
||||
import contextlib
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Iterator, Set, Tuple, Union
|
||||
from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -148,7 +148,7 @@ def functional_call(
|
||||
module: "torch.nn.Module",
|
||||
parameters_and_buffers: Dict[str, Tensor],
|
||||
args: Union[Any, Tuple],
|
||||
kwargs: Dict[str, Any] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
tie_weights: bool = True,
|
||||
strict: bool = False,
|
||||
@ -233,7 +233,7 @@ def _functional_call(
|
||||
module: "torch.nn.Module",
|
||||
parameters_and_buffers: Dict[str, Tensor],
|
||||
args: Union[Any, Tuple],
|
||||
kwargs: Dict[str, Any] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
tie_weights: bool = True,
|
||||
strict: bool = False,
|
||||
|
@ -340,7 +340,7 @@ class ProtobufExportOutputSerializer:
) -> None:
import onnx

if not isinstance(export_output.model_proto, onnx.ModelProto):
if not isinstance(export_output.model_proto, onnx.ModelProto): # type: ignore[attr-defined]
raise ValueError("export_output.ModelProto is not an onnx.ModelProto")
destination.write(export_output.model_proto.SerializeToString())

@ -348,7 +348,7 @@ class ProtobufExportOutputSerializer:
class ExportOutput:
"""An in-memory representation of a PyTorch model that has been exported to ONNX."""

_model_proto: Final[onnx.ModelProto]
_model_proto: Final[onnx.ModelProto] # type: ignore[name-defined]
_input_adapter: Final[io_adapter.InputAdapter]
_output_adapter: Final[io_adapter.OutputAdapter]
_diagnostic_context: Final[infra.DiagnosticContext]
@ -356,7 +356,7 @@ class ExportOutput:
@_beartype.beartype
def __init__(
self,
model_proto: onnx.ModelProto,
model_proto: onnx.ModelProto, # type: ignore[name-defined]
input_adapter: io_adapter.InputAdapter,
output_adapter: io_adapter.OutputAdapter,
diagnostic_context: infra.DiagnosticContext,
@ -367,7 +367,7 @@ class ExportOutput:
self._diagnostic_context = diagnostic_context

@property
def model_proto(self) -> onnx.ModelProto:
def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined]
"""The exported ONNX model as an ``onnx.ModelProto``."""

return self._model_proto
@ -193,7 +193,7 @@ class DynamoExport(exporter.FXGraphExtractor):
wrapped_model,
*model_args,
tracing_mode=fx_mode,
fake_mode=fake_mode,
fake_mode=fake_mode, # type: ignore[arg-type]
**model_kwargs,
)
del graph_guard # Unused
@ -15,7 +15,7 @@ if TYPE_CHECKING:
@_beartype.beartype
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor, name: str, location: str, basepath: str
) -> onnx.TensorProto:
) -> onnx.TensorProto: # type: ignore[name-defined]
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).

@ -38,13 +38,13 @@ def _create_tensor_proto_with_external_data(
# FIXME: Avoid importing onnx into torch.onnx.
import onnx

tensor_proto = onnx.TensorProto()
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
tensor_proto.name = name
tensor_proto.data_type = jit_type_utils.JitScalarType.from_dtype(
tensor.dtype
).onnx_type()
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]

# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
@ -86,7 +86,7 @@ def save_model_with_external_data(
model_location: str,
initializer_location: str,
torch_load_paths: Tuple[Union[str, io.BytesIO], ...],
onnx_model: onnx.ModelProto,
onnx_model: onnx.ModelProto, # type: ignore[name-defined]
rename_initializer: bool = False,
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
@ -121,7 +121,7 @@ def save_model_with_external_data(
# FIXME: Avoid importing onnx into torch.onnx.
import onnx

onnx_model_with_initializers = onnx.ModelProto()
onnx_model_with_initializers = onnx.ModelProto() # type: ignore[attr-defined]
onnx_model_with_initializers.CopyFrom(onnx_model)
onnx_input_names = [input.name for input in onnx_model.graph.input]

@ -159,4 +159,4 @@ def save_model_with_external_data(
onnx_model_with_initializers.graph.initializer.append(tensor_proto)

# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))
onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) # type: ignore[attr-defined]
@ -62,7 +62,7 @@ def export_as_test_case(
shutil.rmtree(data_set_dir)
os.makedirs(data_set_dir)

proto = onnx.load_model_from_string(model_bytes)
proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]

for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)):
export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))
@ -112,12 +112,12 @@ def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
inputs = {}
input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
for input_file in input_files:
tensor = onnx.load_tensor(input_file)
tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined]
inputs[tensor.name] = numpy_helper.to_array(tensor)
outputs = {}
output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
for output_file in output_files:
tensor = onnx.load_tensor(output_file)
tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined]
outputs[tensor.name] = numpy_helper.to_array(tensor)

return model_bytes, inputs, outputs
@ -227,7 +227,7 @@ def _add_onnxscript_fn(
# size > 2GB, and if it for some reason did not, the model would fail on
# serialization anyway in terms of the protobuf limitation. So we don't
# need to worry about > 2GB model getting here.
model_proto = onnx.load_model_from_string(model_bytes)
model_proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]

# Iterate graph nodes to insert only the included custom
# function_proto into model_proto
@ -206,14 +206,14 @@ def _ort_session(
def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
try:
import onnx
from onnx import reference as onnx_reference
from onnx import reference as onnx_reference # type: ignore[attr-defined]
except ImportError:
raise ImportError("onnx >= 1.13 is required for reference evaluator.")

proto = (
onnx.load(model)
onnx.load(model) # type: ignore[attr-defined]
if isinstance(model, str)
else onnx.load_model_from_string(model.getvalue())
else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined]
)
onnx_session = onnx_reference.ReferenceEvaluator(proto)
return onnx_session
@ -6,7 +6,7 @@ import warnings
import functools
import math

from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Tuple, Optional
from torch import Tensor

import torch.utils.hooks as hooks
@ -208,7 +208,7 @@ class Optimizer:
"an iterable of Tensors or dicts, but got " +
torch.typename(params))

self.state = defaultdict(dict)
self.state: Dict[int, Any] = defaultdict(dict)
self.param_groups = []

param_groups = list(params)
@ -340,8 +340,8 @@ class Optimizer:
self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
hooked = getattr(self.__class__.step, "hooked", None)
if not hooked:
self.__class__.step = self.profile_hook_step(self.__class__.step)
self.__class__.step.hooked = True
self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[method-assign]
self.__class__.step.hooked = True # type: ignore[attr-defined]

def register_step_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
r"""Register an optimizer step pre hook which will be called before
@ -418,14 +418,15 @@ class Optimizer:
}

@staticmethod
def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param_id: int = None,
param_groups: List[Dict[Any, Any]] = None, key=None) -> Tensor:
def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param_id: Optional[int] = None,
param_groups: Optional[List[Dict[Any, Any]]] = None, key=None) -> Tensor:
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
# UNLESS fused or capturable, see note [special device hosting for step]
fused = False
capturable = False
assert param_groups is not None
for pg in param_groups:
if param_id in pg["params"]:
fused = pg["fused"] if "fused" in pg else False
@ -477,14 +478,14 @@ class Optimizer:
elif isinstance(value, dict):
return {k: cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
return type(value)(cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
else:
return value

# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
state: Dict[Any, Dict[Any, Any]] = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]
@ -521,7 +522,7 @@ class Optimizer:
if not hasattr(self, "_zero_grad_profile_name"):
self._patch_step_function()
if foreach:
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
per_device_and_dtype_grads: Dict[Any, Dict[Any, List[Any]]] = defaultdict(lambda: defaultdict(list))
with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
for group in self.param_groups:
for p in group['params']:
@ -603,7 +603,7 @@ def input_dtypes(event: _ProfilerEvent):
def report_all_anti_patterns(prof,
should_benchmark: bool = False,
print_enable: bool = True,
json_report_dir: str = None):
json_report_dir: Optional[str] = None):
report_dict: Dict = {}
anti_patterns = [
ExtraCUDACopyPattern(prof, should_benchmark),
@ -242,7 +242,7 @@ class _KinetoProfile:
assert self.profiler is not None and self.profiler.kineto_results is not None
return MemoryProfile(self.profiler.kineto_results)

def export_memory_timeline(self, path: str, device: str = None) -> None:
def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None:
"""Extract the memory information from the memory profile collected
tree for a given device, and export a timeline plot consisting of
[times, [sizes by category]], where times are timestamps and sizes
@ -861,7 +861,7 @@ def load(
pickle_module: Any = None,
*,
weights_only: bool = False,
mmap: bool = None,
mmap: Optional[bool] = None,
**pickle_load_args: Any
) -> Any:
# Reference: https://github.com/pytorch/pytorch/issues/54354
@ -422,9 +422,9 @@ Examples::
def hamming(M: int,
*,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
return general_hamming(M, sym=sym, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)

@ -469,9 +469,9 @@ Examples::
def hann(M: int,
*,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
return general_hamming(M,
alpha=0.5,
@ -521,9 +521,9 @@ Examples::
def blackman(M: int,
*,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()
@ -575,9 +575,9 @@ Examples::
def bartlett(M: int,
*,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()
@ -644,9 +644,9 @@ Examples::
def general_cosine(M, *,
a: Iterable,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
if dtype is None:
dtype = torch.get_default_dtype()
@ -721,9 +721,9 @@ def general_hamming(M,
*,
alpha: float = 0.54,
sym: bool = True,
dtype: torch.dtype = None,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: torch.device = None,
device: Optional[torch.device] = None,
requires_grad: bool = False) -> Tensor:
return general_cosine(M,
a=[alpha, 1. - alpha],
@ -199,7 +199,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
self.compressed_tensor = compressed_tensor
self.transposed = transposed

def __repr__(self) -> str:
def __repr__(self) -> str: # type: ignore[override]
"""Return string representation of SparseSemiStructuredTensor

Returns:
@ -3,7 +3,7 @@ import io
import torch
from ._utils import _type, _cuda, _hpu
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast, Dict as _Dict
from typing import cast, Any, Dict as _Dict, Optional as _Optional, TypeVar, Type, Union
import copy
import collections
from functools import lru_cache
@ -27,51 +27,51 @@ class _StorageBase:
device: torch.device

def __init__(self, *args, **kwargs): ... # noqa: E704
def __len__(self) -> int: ... # noqa: E704
def __len__(self) -> int: ... # type: ignore[empty-body] # noqa: E704
def __getitem__(self, idx): ... # noqa: E704
def __setitem__(self, *args, **kwargs): ... # noqa: E704
def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704
def new(self) -> T: ... # noqa: E704
def nbytes(self) -> int: ... # noqa: E704
def copy_(self, source: T, non_blocking: _Optional[bool] = None) -> T: ... # type: ignore[empty-body] # noqa: E704
def new(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def nbytes(self) -> int: ... # type: ignore[empty-body] # noqa: E704

def size(self) -> int:
return self.nbytes()

def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
def element_size(self) -> int: ... # noqa: E704
def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def element_size(self) -> int: ... # type: ignore[empty-body, type-var] # noqa: E704

def get_device(self) -> int:
return self.device.index

def data_ptr(self) -> int: ... # noqa: E704
def data_ptr(self) -> int: ... # type: ignore[empty-body] # noqa: E704

# Defined in torch/csrc/generic/StorageSharing.cpp
def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
@classmethod
def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
def from_buffer(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
def _new_shared_filename_cpu(cls: Type[T], manager, obj, size, *, device=None, dtype=None) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
def _shared_decref(self) -> T: ... # noqa: E704
def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
def _shared_decref(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _write_file(self, *args, **kwargs): ... # noqa: E704
def resize_(self, size: int): ... # noqa: E704
def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704
def _weak_ref(self, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _set_from_file(self, *args, **kwargs): ... # noqa: E704
def _set_cdata(self, *args, **kwargs): ... # noqa: E704
def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
def is_shared(self) -> bool: ... # noqa: E704
def is_shared(self) -> bool: ... # type: ignore[empty-body] # noqa: E704
@classmethod
def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704
def _shared_incref(self, *args, **kwargs): ... # noqa: E704
@classmethod
def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704
@ -80,9 +80,9 @@ class _StorageBase:
@property
def is_hpu(self): ... # noqa: E704
@classmethod
def from_file(cls, filename, shared, nbytes) -> T: ... # noqa: E704
def from_file(cls, filename, shared, nbytes) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
@classmethod
def _expired(cls, *args, **kwargs) -> T: ... # noqa: E704
def _expired(cls, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704
def _byteswap(self, *args, **kwargs): ... # noqa: E704

def __str__(self):
@ -712,12 +712,12 @@ class TypedStorage:
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
return tmp_tensor[idx_wrapped].item()

def copy_(self, source: T, non_blocking: bool = None):
def copy_(self, source: T, non_blocking: _Optional[bool] = None):
_warn_typed_storage_removal()
if isinstance(source, TypedStorage):
self._untyped_storage.copy_(source._untyped_storage, non_blocking)
self._untyped_storage.copy_(source._untyped_storage, non_blocking) # type: ignore[arg-type]
else:
self._untyped_storage.copy_(source, non_blocking)
self._untyped_storage.copy_(source, non_blocking) # type: ignore[arg-type]
return self

def nbytes(self):
@ -728,7 +728,7 @@ class TypedStorage:
def _nbytes(self):
return self._untyped_storage.nbytes()

def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> Union[T, str]:
_warn_typed_storage_removal()
if dtype is None:
legacy_class = self._get_legacy_storage_class()
@ -741,14 +741,14 @@ class TypedStorage:
else:
return self._untyped_storage.type(dtype, non_blocking)

def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var]
_warn_typed_storage_removal()
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
return self._new_wrapped_storage(cuda_storage)

def hpu(self, device=None, non_blocking=False, **kwargs) -> T:
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var]
_warn_typed_storage_removal()
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create HPU storage with quantized dtype")
@ -43,37 +43,37 @@ class Storage:
dtype: torch.dtype
_torch_load_uninitialized: bool

def __deepcopy__(self, memo) -> 'Storage':
def __deepcopy__(self, memo) -> 'Storage': # type: ignore[empty-body]
...

def _new_shared(self, int) -> 'Storage':
def _new_shared(self, int) -> 'Storage': # type: ignore[empty-body]
...

def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None:
...

def element_size(self) -> int:
def element_size(self) -> int: # type: ignore[empty-body]
...

def is_shared(self) -> bool:
def is_shared(self) -> bool: # type: ignore[empty-body]
...

def share_memory_(self) -> 'Storage':
def share_memory_(self) -> 'Storage': # type: ignore[empty-body]
...

def nbytes(self) -> int:
def nbytes(self) -> int: # type: ignore[empty-body]
...

def cpu(self) -> 'Storage':
def cpu(self) -> 'Storage': # type: ignore[empty-body]
...

def data_ptr(self) -> int:
def data_ptr(self) -> int: # type: ignore[empty-body]
...

def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage':
def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage': # type: ignore[empty-body]
...

def _new_with_file(self, f: Any, element_size: int) -> 'Storage':
def _new_with_file(self, f: Any, element_size: int) -> 'Storage': # type: ignore[empty-body]
...

...
@ -100,7 +100,7 @@ def report_compile_source_on_error():
# specifically _PyCode_InitAddressRange, reveals that
# this iterator is initialized from co_linetable and
# co_firstfileno. So copy these we must!
code = code.replace(
code = code.replace( # type: ignore[call-arg]
co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined]
co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined]
)
@ -184,7 +184,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -


def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
unsupported_dtype: List[torch.dtype] = None) -> None:
unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
# Attribute is registered in the _StorageBase class
# and UntypedStorage obtains through inheritance.
@property # type: ignore[misc]
@ -255,7 +255,7 @@ def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,

def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
for_storage: bool = False,
unsupported_dtype: List[torch.dtype] = None) -> None:
unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
r"""
generate_methods_for_privateuse1_backend(for_tensor, for_module, for_storage, unsupported_dtype) -> None
@ -37,8 +37,8 @@ if HAS_TABULATE:
model: Union[torch.nn.Module, Callable],
sample_input: Union[torch.Tensor, Any],
num_iters: int = 5,
optimizer: torch.optim.Optimizer = None,
loss_fn: Callable = None,
optimizer: Optional[torch.optim.Optimizer] = None,
loss_fn: Optional[Callable] = None,
):
# Define the statement and setup for the benchmark
if optimizer and loss_fn:
@ -73,8 +73,8 @@ if HAS_TABULATE:
num_iters: int = 5,
backend: Optional[str] = None,
mode: Optional[str] = "default",
optimizer: torch.optim.Optimizer = None,
loss_fn : Union[torch.nn.Module, Callable] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
loss_fn : Union[torch.nn.Module, Callable, None] = None,
):
"""
Use this utility to benchmark torch.compile
@ -117,7 +117,7 @@ if HAS_TABULATE:
sample_input: Union[torch.Tensor, Any],
num_iters : int = 5,
optimizer: Optional[torch.optim.Optimizer] = None,
loss_fn : Union[torch.nn.Module, Callable] = None,
loss_fn : Union[torch.nn.Module, Callable, None] = None,
):
"""
This is a simple utility that can be used to benchmark torch.compile
@ -1066,7 +1066,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue
self._data_queue = self._worker_result_queue # type: ignore[assignment]

# In some rare cases, persistent workers (daemonic processes)
# would be terminated before `__del__` of iterator is invoked
@ -415,7 +415,7 @@ class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):


@functional_datapipe('trace_as_dataframe')
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc]
source_datapipe = None

# TODO(VitalyFedyunin): Must implement all special functions of datapipes
@ -376,7 +376,7 @@ class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataP
self._datapipe_iter = iter(self._datapipe)
return self

def __next__(self) -> T_co:
def __next__(self) -> T_co: # type: ignore[type-var]
assert self._datapipe_iter is not None
return next(self._datapipe_iter)

@ -421,7 +421,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
self.main_datapipe_exhausted = False
self._child_stop: List[bool] = [True for _ in range(num_instances)]

def _find_next(self, instance_id: int) -> T_co:
def _find_next(self, instance_id: int) -> T_co: # type: ignore[type-var]
while True:
if self.main_datapipe_exhausted or self._child_stop[instance_id]:
raise StopIteration
@ -40,7 +40,7 @@ class ConcaterMapDataPipe(MapDataPipe):
raise TypeError("Expected all inputs to be `Sized`")
self.datapipes = datapipes # type: ignore[assignment]

def __getitem__(self, index) -> T_co:
def __getitem__(self, index) -> T_co: # type: ignore[type-var]
offset = 0
for dp in self.datapipes:
if index - offset < len(dp):
@ -418,5 +418,5 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
if sum(lengths) != len(dataset): # type: ignore[arg-type]
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[call-overload]
indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload]
return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
@ -236,7 +236,7 @@ class FlopCounterMode(TorchDispatchMode):
mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
depth: int = 2,
display: bool = True,
custom_mapping: Dict[Any, Any] = None):
custom_mapping: Optional[Dict[Any, Any]] = None):
self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
self.depth = depth
self.parents = ["Global"]
@ -1,6 +1,6 @@
import gc
import sys
from typing import NamedTuple, Tuple, List, Optional
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import types
import weakref
import json
@ -100,7 +100,7 @@ def annotated_references(obj):
need for a list. Descriptions are currently strings.

"""
references = {}
references: Dict[int, List[str]] = {}

def add_reference(name, obj):
references.setdefault(id(obj), []).append(name)
@ -272,7 +272,7 @@ def create_graph(objects, *, context=None, filter=None):
filter = is_cuda_tensor

nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects]
node_referrers = [[] for obj in objects]
node_referrers: List[List[int]] = [[] for obj in objects]

id_to_node = {id(obj): i for i, obj in enumerate(objects)}
for obj in objects:
@ -299,8 +299,8 @@ def create_graph(objects, *, context=None, filter=None):
to_keep.add(idx)
referrers = node_referrers[idx]
to_search.extend(referrers)
id_to_filtered_id = {}
filtered = []
id_to_filtered_id: Dict[int, int] = {}
filtered: List[Any] = []
for i, n in enumerate(nodes):
if i in to_keep:
id_to_filtered_id[i] = len(id_to_filtered_id)
@ -50,7 +50,7 @@ class WeakIdRef(weakref.ref):
# cache the id of the key as we know this is definitely the hash
# method
self._id = id(key)
super().__init__(key, callback)
super().__init__(key, callback) # type: ignore[call-arg]

def __call__(self):
r = super().__call__()
Some files were not shown because too many files have changed in this diff.