Apply Ruff fixes and pyupgrade to torch/jit (#144208)


Pull Request resolved: https://github.com/pytorch/pytorch/pull/144208
Approved by: https://github.com/davidberard98
Author: cyy
Date: 2025-01-16 00:28:48 +00:00
Committed by: PyTorch MergeBot
Parent: 774f21a370
Commit: ee97d80be2
20 changed files with 298 additions and 271 deletions
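
For context, the bulk of the diff is mechanical: Ruff's pyupgrade rules rewrite typing.List/Dict/Tuple/Set annotations to the PEP 585 builtin generics and drop the now-unused typing imports. A minimal before/after sketch (the function below is hypothetical, not taken from the diff):

# Before: typing aliases
from typing import Dict, List

def bucket_old(xs: List[int]) -> Dict[int, List[int]]:
    out: Dict[int, List[int]] = {}
    for x in xs:
        out.setdefault(x % 2, []).append(x)
    return out

# After: builtin generics, as applied throughout this PR (requires Python 3.9+)
def bucket_new(xs: list[int]) -> dict[int, list[int]]:
    out: dict[int, list[int]] = {}
    for x in xs:
        out.setdefault(x % 2, []).append(x)
    return out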

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import warnings
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, Iterator
from typing import Any
import torch._C

View File

@ -50,11 +50,17 @@ def fork(func, *args, **kwargs):
import torch
from torch import Tensor
def foo(a : Tensor, b : int) -> Tensor:
def foo(a: Tensor, b: int) -> Tensor:
return a + b
def bar(a):
fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
return torch.jit.wait(fut)
script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# only the scripted version executes asynchronously
@ -69,16 +75,23 @@ def fork(func, *args, **kwargs):
import torch
from torch import Tensor
class AddMod(torch.nn.Module):
def forward(self, a: Tensor, b : int):
def forward(self, a: Tensor, b: int):
return a + b
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.mod = AddMod()
def forward(self, input):
fut = torch.jit.fork(self.mod, input, b=2)
return torch.jit.wait(fut)
input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)

View File

@ -3,7 +3,7 @@ import cmath
import math
import warnings
from collections import OrderedDict
from typing import Dict, Optional
from typing import Optional
import torch
import torch.backends.cudnn as cudnn
@ -16,7 +16,7 @@ from torch.nn.modules.utils import (
)
_builtin_table: Optional[Dict[int, str]] = None
_builtin_table: Optional[dict[int, str]] = None
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950

View File

@ -5,7 +5,7 @@ import dataclasses
import inspect
import os
from functools import partial
from typing import Callable, Dict, List
from typing import Callable
from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
from torch._sources import ParsedDef, SourceContext
@ -15,7 +15,7 @@ def _get_fake_filename(cls, method_name):
return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)
def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
def compose_fn(cls, name: str, body_lines: list[str], signature: str) -> ParsedDef:
body = "\n".join(f" {b}" for b in body_lines)
decl = f"def {name}{signature}:\n{body}"
@ -59,7 +59,7 @@ def synthesize__init__(cls) -> ParsedDef:
# Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
# see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
init_vars: List[str] = []
init_vars: list[str] = []
params = []
for name, param in signature.parameters.items():
ann = param.annotation
@ -144,7 +144,7 @@ def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
def synthesize_comparison(
cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]
cls, name: str, allow_eq: bool, raise_on_none: bool, inner: list[str]
) -> ParsedDef:
body = []
for field in dataclasses.fields(cls):
@ -177,7 +177,7 @@ def synthesize_comparison(
)
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
DATACLASS_MAGIC_METHODS: dict[str, Callable] = {
"__init__": synthesize__init__,
"__repr__": synthesize__repr__,
"__hash__": synthesize__hash__,

View File

@ -6,14 +6,14 @@ from torch import Tensor
aten = torch.ops.aten
import inspect
import warnings
from typing import Callable, Dict, List, Optional, Set, TypeVar
from typing import Callable, Optional, TypeVar
from typing_extensions import ParamSpec
from torch.types import Number
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
function_name_set: Set[str] = set()
decomposition_table: dict[str, torch.jit.ScriptFunction] = {}
function_name_set: set[str] = set()
_T = TypeVar("_T")
_P = ParamSpec("_P")
@ -65,7 +65,7 @@ def signatures_match(decomposition_sig, torch_op_sig):
def register_decomposition(
aten_op: torch._ops.OpOverload,
registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
registry: Optional[dict[str, torch.jit.ScriptFunction]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
nonlocal registry
@ -99,12 +99,12 @@ def register_decomposition(
@register_decomposition(aten.var.correction)
def var_decomposition(
input: Tensor,
dim: Optional[List[int]] = None,
dim: Optional[list[int]] = None,
correction: Optional[Number] = None,
keepdim: bool = False,
) -> Tensor:
if dim is None:
dim_i: List[int] = []
dim_i: list[int] = []
dim = dim_i
if isinstance(dim, (tuple, list)) and len(dim) == 0:
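
The decomposition above targets aten.var.correction; for reference, its semantics (mean subtraction plus a divisor of N - correction) can be sanity-checked against eager mode. This check is illustrative, not part of the module:

import torch

x = torch.randn(5, 3)
v = torch.var(x, dim=1, correction=1)
manual = ((x - x.mean(dim=1, keepdim=True)) ** 2).sum(dim=1) / (x.shape[1] - 1)
assert torch.allclose(v, manual)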

View File

@ -5,14 +5,14 @@ This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
from typing import List, Optional
from typing import Optional
import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule
def freeze(
mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True
mod, preserved_attrs: Optional[list[str]] = None, optimize_numerics: bool = True
):
r"""Freeze ScriptModule, inline submodules, and attributes as constants.
@ -124,7 +124,7 @@ def freeze(
def run_frozen_optimizations(
mod, optimize_numerics: bool = True, preserved_methods: Optional[List[str]] = None
mod, optimize_numerics: bool = True, preserved_methods: Optional[list[str]] = None
):
r"""
Run a series of optimizations looking for patterns that occur in frozen graphs.
@ -155,9 +155,12 @@ def run_frozen_optimizations(
Example (Freezing a module with Conv->Batchnorm)
.. code-block:: python
import torch
in_channels, out_channels = 3, 32
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
conv = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, bias=True
)
bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
mod = torch.nn.Sequential(conv, bn)
# set optimize to False here, by default freezing runs run_frozen_optimizations
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
@ -180,7 +183,7 @@ def run_frozen_optimizations(
def optimize_for_inference(
mod: ScriptModule, other_methods: Optional[List[str]] = None
mod: ScriptModule, other_methods: Optional[list[str]] = None
) -> ScriptModule:
"""
Perform a set of optimization passes to optimize a model for the purposes of inference.
@ -202,9 +205,12 @@ def optimize_for_inference(
Example (optimizing a module with Conv->Batchnorm)::
import torch
in_channels, out_channels = 3, 32
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
conv = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, bias=True
)
bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
mod = torch.nn.Sequential(conv, bn)
frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
assert "batch_norm" not in str(frozen_mod.graph)
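
Besides the Conv->BatchNorm folding shown above, freeze accepts preserved_attrs to keep selected attributes mutable on the frozen module. A short sketch with an illustrative module:

import torch

class Counter(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.version = 1

    def forward(self, x):
        return x + self.version

frozen = torch.jit.freeze(torch.jit.script(Counter().eval()), preserved_attrs=["version"])
frozen.version = 2  # still assignable because it was listed in preserved_attrs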

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import contextlib
from typing import List, Tuple
import torch
@ -106,7 +105,7 @@ def _script_method_graph_for(self, parent, *args, **kwargs):
# graph_executor_states for differentiable node
fw_states = eps[0].code.differentiable_op_executor_states()
diff_nodes: List[torch._C.Node] = []
diff_nodes: list[torch._C.Node] = []
for n in graph.nodes():
_get_differentiable_graph_node(n, diff_nodes)
@ -128,7 +127,7 @@ def _script_method_graph_for(self, parent, *args, **kwargs):
return last_executed_optimized_graph()
def set_fusion_strategy(strategy: List[Tuple[str, int]]):
def set_fusion_strategy(strategy: list[tuple[str, int]]):
"""Set the type and number of specializations that can occur during fusion.
Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC"
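
For example (values are illustrative; see the semantics described above):

import torch

# Allow up to two static specializations, then up to ten dynamic ones.
torch.jit.set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 10)])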

View File

@ -1,5 +1,5 @@
from types import TracebackType
from typing import Optional, Type, Union
from typing import Optional, Union
import torch
@ -20,7 +20,7 @@ class _InsertPoint:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:

View File

@ -3,9 +3,10 @@ import inspect
import sys
import typing
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path
from types import CodeType
from typing import Dict, Iterable, List, Optional
from typing import Optional
import torch
@ -94,7 +95,7 @@ if _IS_MONKEYTYPE_INSTALLED:
# A dictionary keeping all collected CallTrace
# key is fully qualified name of called function
# value is list of all CallTrace
self.trace_records: Dict[str, list] = defaultdict(list)
self.trace_records: dict[str, list] = defaultdict(list)
def add(self, traces: Iterable[CallTrace]):
for t in traces:
@ -106,10 +107,10 @@ if _IS_MONKEYTYPE_INSTALLED:
qualified_name: str,
qualname_prefix: Optional[str] = None,
limit: int = 2000,
) -> List[CallTraceThunk]:
) -> list[CallTraceThunk]:
return self.trace_records[qualified_name]
def analyze(self, qualified_name: str) -> Dict:
def analyze(self, qualified_name: str) -> dict:
# Analyze the types for the given module
# and create a dictionary of all the types
# for arguments.
@ -120,7 +121,7 @@ if _IS_MONKEYTYPE_INSTALLED:
all_args[arg].add(arg_type)
return all_args
def consolidate_types(self, qualified_name: str) -> Dict:
def consolidate_types(self, qualified_name: str) -> dict:
all_args = self.analyze(qualified_name)
# If there are more types for an argument,
# then consolidate the type to `Any` and replace the entry
@ -137,7 +138,7 @@ if _IS_MONKEYTYPE_INSTALLED:
all_args[arg] = get_type(types[0])
return all_args
def get_args_types(self, qualified_name: str) -> Dict:
def get_args_types(self, qualified_name: str) -> dict:
return self.consolidate_types(qualified_name)
class JitTypeTraceConfig(monkeytype.config.Config):
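
The consolidation step above collapses conflicting observations for an argument to Any; the idea in isolation is roughly this (a sketch, not the module's code):

from typing import Any

def consolidate(observed: dict[str, set[type]]) -> dict[str, object]:
    # A single observed type is kept; disagreeing traces fall back to Any.
    return {arg: next(iter(ts)) if len(ts) == 1 else Any for arg, ts in observed.items()}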

View File

@ -5,14 +5,14 @@ This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
from typing import Any, List
from typing import Any
import torch
from torch import TensorType
from torch._C import Graph
def apply_input_props_using_example(graph: Graph, example_input: List[Any]) -> None:
def apply_input_props_using_example(graph: Graph, example_input: list[Any]) -> None:
"""
Applies properties for each tensor in the graph inputs
using the example supplied.

View File

@ -6,7 +6,6 @@ import sys
import textwrap
import types
import warnings
from typing import Dict, List, Set, Type
import torch
import torch._jit_internal as _jit_internal
@ -421,8 +420,8 @@ def infer_concrete_type_builder(nn_module, share_types=True):
class ConcreteTypeStore:
type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
methods_compiled: Set[torch._C.ConcreteModuleType]
type_store: dict[type[Module], list[torch._C.ConcreteModuleType]]
methods_compiled: set[torch._C.ConcreteModuleType]
def __init__(self) -> None:
# Python module type => List[ConcreteModuleType]
@ -766,7 +765,7 @@ def get_overload_annotations(mod, jit_ignored_properties):
def get_overload_name_mapping(overload_info):
# Same format as __overloads__
# original function => [overload names]
overload_name_mappings: Dict[str, List[str]] = {}
overload_name_mappings: dict[str, list[str]] = {}
for orig_fn, overloads in overload_info.items():
original_name = orig_fn.__name__
if original_name not in overload_name_mappings:
@ -836,7 +835,7 @@ def infer_methods_to_compile(nn_module):
check_module_initialized(nn_module)
ignored_properties = jit_ignored_properties(nn_module)
methods: List[str] = []
methods: list[str] = []
if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn(
nn_module.forward
):
@ -873,7 +872,7 @@ def infer_methods_to_compile(nn_module):
# Unique the methods. We don't want to use a set to store the methods because it
# introduces non-determinism to compile order.
uniquer: Set[str] = set()
uniquer: set[str] = set()
uniqued_methods = []
for name in filtered_methods:
if name in uniquer:
@ -888,7 +887,7 @@ def infer_methods_to_compile(nn_module):
def get_hook_stubs(nn_module):
"""Return forward hook and pre_hook ScriptModuleStubs."""
check_module_initialized(nn_module)
hook_map: Dict = {}
hook_map: dict = {}
hook_stubs = []
for hook in nn_module._forward_hooks.values():

View File

@ -6,6 +6,7 @@ This module contains functionality to support the JIT's scripting frontend, nota
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import collections
import copy
import enum
@ -13,7 +14,7 @@ import functools
import inspect
import pickle
import warnings
from typing import Any, Callable, Dict, List, Set, Tuple, Union
from typing import Any, Callable, Union
import torch
import torch._jit_internal as _jit_internal
@ -277,12 +278,12 @@ class OrderedModuleDict(OrderedDictWrapper):
class ScriptMeta(type):
def __init__(cls, name, bases, attrs): # noqa: B902
# Aggregate all the ScriptMethods and constants from superclasses
cls._methods: Dict[str, Any] = {}
cls._methods: dict[str, Any] = {}
cls._constants_set = set(getattr(cls, "__constants__", ()))
for base in reversed(bases):
for k, v in getattr(base, "_methods", {}).items():
cls._methods[k] = v
base_constants: Set = getattr(base, "_constants_set", set())
base_constants: set = getattr(base, "_constants_set", set())
cls._constants_set = cls._constants_set.union(base_constants)
# find all the script methods of the current class
@ -1020,7 +1021,9 @@ def call_prepare_scriptable_func_impl(obj, memo):
if obj_id in memo:
return memo[id(obj)]
obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj # type: ignore[operator]
obj = (
obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj
) # type: ignore[operator]
# Record obj in memo to avoid infinite recursion in the case of cycles in the module
# hierarchy when recursing below.
memo[obj_id] = obj
@ -1046,7 +1049,7 @@ def call_prepare_scriptable_func_impl(obj, memo):
def call_prepare_scriptable_func(obj):
memo: Dict[int, torch.nn.Module] = {}
memo: dict[int, torch.nn.Module] = {}
return call_prepare_scriptable_func_impl(obj, memo)
@ -1089,7 +1092,7 @@ def _script_impl(
optimize=None,
_frames_up=0,
_rcb=None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
example_inputs: Union[list[tuple], dict[Callable, list[tuple]], None] = None,
):
global type_trace_db
@ -1118,7 +1121,7 @@ def _script_impl(
if monkeytype_trace:
monkeytype_config = JitTypeTraceConfig(type_trace_db)
with monkeytype_trace(monkeytype_config):
if isinstance(example_inputs, Dict):
if isinstance(example_inputs, dict):
# If the obj is an nn.Module or a class, then each method is
# executed with the arguments provided in the example inputs.
# example inputs here will be of type Dict(class.method, (arguments))
@ -1127,7 +1130,7 @@ def _script_impl(
for module, example_input in example_inputs.items():
for example in example_input:
module(*example)
elif isinstance(example_inputs, List):
elif isinstance(example_inputs, list):
for examples in example_inputs:
obj(*examples)
else:
@ -1148,7 +1151,11 @@ def _script_impl(
obj, torch.jit._recursive.infer_methods_to_compile
)
else:
obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj # type: ignore[operator]
obj = (
obj.__prepare_scriptable__()
if hasattr(obj, "__prepare_scriptable__")
else obj
) # type: ignore[operator]
if isinstance(obj, dict):
return create_script_dict(obj)
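
The example_inputs branches above back the public example_inputs argument of torch.jit.script, which (with the optional MonkeyType dependency installed) runs the example calls under a type tracer to annotate an unannotated callable. A minimal sketch mirroring the documented usage:

import torch

def add(a, b):
    return a + b

# Argument types are inferred from the example call (requires monkeytype to be installed).
scripted = torch.jit.script(add, example_inputs=[(3, 4)])
assert scripted(10, 20) == 30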
@ -1220,7 +1227,7 @@ def script(
optimize=None,
_frames_up=0,
_rcb=None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
example_inputs: Union[list[tuple], dict[Callable, list[tuple]], None] = None,
):
r"""Script the function.
@ -1374,6 +1381,7 @@ def script(
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -1387,6 +1395,7 @@ def script(
# This function won't be compiled, so any
# Python APIs can be used
import pdb
pdb.set_trace()
def forward(self, input):
@ -1394,6 +1403,7 @@ def script(
self.python_only_fn(input)
return input * 99
scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))
@ -1619,14 +1629,14 @@ class _ScriptProfileColumn:
self.header = header
self.alignment = alignment
self.offset = offset
self.rows: Dict[int, Any] = {}
self.rows: dict[int, Any] = {}
def add_row(self, lineno: int, value: Any):
self.rows[lineno] = value
def materialize(self):
max_length = len(self.header)
rows: List[Tuple[int, str]] = []
rows: list[tuple[int, str]] = []
for key, value in self.rows.items():
cell = str(value)
rows.append((key, cell))
@ -1643,13 +1653,13 @@ class _ScriptProfileColumn:
class _ScriptProfileTable:
def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]):
def __init__(self, cols: list[_ScriptProfileColumn], source_range: list[int]):
self.cols = cols
self.source_range = source_range
def dump_string(self):
outputs: List[str] = []
cells: List[Tuple[str, Dict[int, str]]] = []
outputs: list[str] = []
cells: list[tuple[str, dict[int, str]]] = []
header_buffer = ""
for col in self.cols:
header, rows = col.materialize()
@ -1681,7 +1691,7 @@ class _ScriptProfile:
self.profile.disable()
def dump_string(self) -> str:
outputs: List[str] = []
outputs: list[str] = []
for source_stats in self.profile._dump_stats():
source_ref = source_stats.source()
source_lines = source_ref.text().splitlines()
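
_ScriptProfile is an internal, undocumented helper; assuming the enable/disable/dump_string methods suggested by the code above, usage would look roughly like this (the class path and enable() call are assumptions):

import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return x.relu() + 1

prof = torch.jit._script._ScriptProfile()  # internal class defined in this file
prof.enable()  # assumed counterpart of the disable() shown above
f(torch.randn(8))
prof.disable()
print(prof.dump_string())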

View File

@ -165,7 +165,9 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
cu = torch._C.CompilationUnit()
if isinstance(f, (str, os.PathLike)):
cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg]
cpp_module = torch._C.import_ir_module(
cu, os.fspath(f), map_location, _extra_files, _restore_shapes
) # type: ignore[call-arg]
else:
cpp_module = torch._C.import_ir_module_from_buffer(
cu, f.read(), map_location, _extra_files, _restore_shapes
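
Both branches above are exercised by the usual save/load round trip; for reference:

import torch

m = torch.jit.script(torch.nn.Linear(2, 2))
torch.jit.save(m, "linear.pt")
loaded = torch.jit.load("linear.pt", map_location="cpu")  # path-based branch
with open("linear.pt", "rb") as f:
    loaded_from_buffer = torch.jit.load(f)  # file-like (buffer) branch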

View File

@ -24,11 +24,11 @@ number = Union[int, float]
import torch
def broadcast(a: List[int], b: List[int]):
def broadcast(a: list[int], b: list[int]):
dimsA = len(a)
dimsB = len(b)
ndim = max(dimsA, dimsB)
expandedSizes: List[int] = []
expandedSizes: list[int] = []
for i in range(ndim):
offset = ndim - 1 - i
@ -48,21 +48,21 @@ def broadcast(a: List[int], b: List[int]):
return expandedSizes
def broadcast_three(a: List[int], b: List[int], c: List[int]):
def broadcast_three(a: list[int], b: list[int], c: list[int]):
return broadcast(broadcast(a, b), c)
def broadcast_one_three(a: List[int], b: Any, c: List[int]):
def broadcast_one_three(a: list[int], b: Any, c: list[int]):
return broadcast(a, c)
def adaptive_avg_pool2d(self: List[int], out: List[int]):
def adaptive_avg_pool2d(self: list[int], out: list[int]):
assert len(out) == 2
assert len(self) == 3 or len(self) == 4
for i in range(1, len(self)):
assert self[i] != 0
shape: List[int] = []
shape: list[int] = []
for i in range(0, len(self) - 2):
shape.append(self[i])
for elem in out:
@ -70,18 +70,18 @@ def adaptive_avg_pool2d(self: List[int], out: List[int]):
return shape
def _copy(self: List[int]):
out: List[int] = []
def _copy(self: list[int]):
out: list[int] = []
for elem in self:
out.append(elem)
return out
def unary(self: List[int]):
def unary(self: list[int]):
return _copy(self)
def broadcast_inplace(a: List[int], b: List[int]):
def broadcast_inplace(a: list[int], b: list[int]):
dimsA = len(a)
dimsB = len(b)
if dimsB > dimsA:
@ -101,13 +101,13 @@ def broadcast_inplace(a: List[int], b: List[int]):
return _copy(a)
def expand(self: List[int], sizes: List[int]):
def expand(self: list[int], sizes: list[int]):
assert len(sizes) >= len(self)
ndim = len(sizes)
tensor_dim = len(self)
if ndim == 0:
return _copy(sizes)
out: List[int] = []
out: list[int] = []
for i in range(ndim):
offset = ndim - 1 - i
dim = tensor_dim - 1 - offset
@ -123,11 +123,11 @@ def expand(self: List[int], sizes: List[int]):
return out
def expand_one_unused(self: List[int], sizes: List[int], inp0: Any):
def expand_one_unused(self: list[int], sizes: list[int], inp0: Any):
return expand(self, sizes)
def infer_size_impl(shape: List[int], numel: int) -> List[int]:
def infer_size_impl(shape: list[int], numel: int) -> list[int]:
newsize = 1
infer_dim: Optional[int] = None
for dim in range(len(shape)):
@ -150,27 +150,27 @@ def infer_size_impl(shape: List[int], numel: int) -> List[int]:
return out
def numel(sizes: List[int]):
def numel(sizes: list[int]):
numel = 1
for elem in sizes:
numel *= elem
return numel
def view(self: List[int], sizes: List[int]):
def view(self: list[int], sizes: list[int]):
return infer_size_impl(sizes, numel(self))
def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False):
def view_one_unused(self: list[int], sizes: list[int], *, implicit: bool = False):
return view(self, sizes)
def sum_mean_dim(
self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any
self: list[int], opt_dims: Optional[list[int]], keep_dim: bool, dt: Any
):
out: List[int] = []
out: list[int] = []
if opt_dims is None or len(opt_dims) == 0:
dims: List[int] = list(range(len(self)))
dims: list[int] = list(range(len(self)))
else:
dims = opt_dims
@ -187,7 +187,7 @@ def sum_mean_dim(
return out
def max_dim(self: List[int], dim: int, keep_dim: bool):
def max_dim(self: list[int], dim: int, keep_dim: bool):
out = sum_mean_dim(self, [dim], keep_dim, None)
return out, out
@ -239,7 +239,7 @@ def pooling_output_shape(
def pool2d_shape_check(
input: List[int],
input: list[int],
kH: int,
kW: int,
dH: int,
@ -273,11 +273,11 @@ def pool2d_shape_check(
def max_pool2d(
input: List[int],
kernel_size: List[int],
stride: List[int],
padding: List[int],
dilation: List[int],
input: list[int],
kernel_size: list[int],
stride: list[int],
padding: list[int],
dilation: list[int],
ceil_mode: bool,
):
assert (
@ -343,11 +343,11 @@ def max_pool2d(
def max_pool2d_with_indices(
input: List[int],
kernel_size: List[int],
stride: List[int],
padding: List[int],
dilation: List[int],
input: list[int],
kernel_size: list[int],
stride: list[int],
padding: list[int],
dilation: list[int],
ceil_mode: bool,
):
out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
@ -355,11 +355,11 @@ def max_pool2d_with_indices(
def upsample_nearest2d(
input: List[int],
output_size: Optional[List[int]],
scale_factors: Optional[List[float]],
input: list[int],
output_size: Optional[list[int]],
scale_factors: Optional[list[float]],
):
out: List[int] = []
out: list[int] = []
out.append(input[0])
out.append(input[1])
@ -385,7 +385,7 @@ def upsample_nearest2d(
return out
def mm(self: List[int], mat2: List[int]):
def mm(self: list[int], mat2: list[int]):
assert len(self) == 2, "self must be a matrix"
assert len(mat2) == 2, "mat2 must be a matrix"
@ -393,37 +393,37 @@ def mm(self: List[int], mat2: List[int]):
return [self[0], mat2[1]]
def dot(self: List[int], tensor: List[int]):
def dot(self: list[int], tensor: list[int]):
assert len(self) == 1 and len(tensor) == 1
assert self[0] == tensor[0]
out: List[int] = []
out: list[int] = []
return out
def mv(self: List[int], vec: List[int]):
def mv(self: list[int], vec: list[int]):
assert len(self) == 2 and len(vec) == 1
assert self[1] == vec[0]
# TODO: return self
return [self[0]]
def unsqueeze(li: List[int], dim: int):
def unsqueeze(li: list[int], dim: int):
dim = maybe_wrap_dim(dim, len(li) + 1)
out = _copy(li)
out.insert(dim, 1)
return out
def squeeze_nodim(li: List[int]):
out: List[int] = []
def squeeze_nodim(li: list[int]):
out: list[int] = []
for i in range(len(li)):
if li[i] != 1:
out.append(li[i])
return out
def squeeze(li: List[int], dim: int):
out: List[int] = []
def squeeze(li: list[int], dim: int):
out: list[int] = []
wrapped_dim = maybe_wrap_dim(dim, len(li))
for i in range(len(li)):
if i == wrapped_dim:
@ -434,13 +434,13 @@ def squeeze(li: List[int], dim: int):
return out
def squeeze_dims(li: List[int], dims: List[int]):
def squeeze_dims(li: list[int], dims: list[int]):
if len(dims) == 0:
return li
wrapped_dims = _copy(dims)
for i in range(len(dims)):
wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
result: List[int] = []
result: list[int] = []
for i in range(len(li)):
if li[i] == 1:
if i not in wrapped_dims:
@ -450,12 +450,12 @@ def squeeze_dims(li: List[int], dims: List[int]):
return result
def index_select(self: List[int], dim: int, index: List[int]):
def index_select(self: list[int], dim: int, index: list[int]):
dim = maybe_wrap_dim(dim, len(self))
numel = multiply_integers(index)
assert len(index) <= 1
assert dim == 0 or dim < len(self)
result_size: List[int] = []
result_size: list[int] = []
for i in range(len(self)):
if dim == i:
result_size.append(numel)
@ -465,8 +465,8 @@ def index_select(self: List[int], dim: int, index: List[int]):
def embedding(
weight: List[int],
indices: List[int],
weight: list[int],
indices: list[int],
padding_idx: int = -1,
scale_grad_by_freq: bool = False,
sparse: bool = False,
@ -484,7 +484,7 @@ def max_int():
def slice(
self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int
self: list[int], dim: int, start: Optional[int], end: Optional[int], step: int
):
ndim = len(self)
assert ndim != 0
@ -512,12 +512,12 @@ def slice(
return out
def check_cat_no_zero_dim(tensors: List[List[int]]):
def check_cat_no_zero_dim(tensors: list[list[int]]):
for tensor in tensors:
assert len(tensor) > 0
def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
def legacy_cat_wrap_dim(dim: int, tensor_sizes: list[list[int]]):
out_dim: Optional[int] = None
for size in tensor_sizes:
if not (len(size) == 1 and size[0] == 0):
@ -528,12 +528,12 @@ def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
return out_dim
def should_skip(tensor: List[int]):
def should_skip(tensor: list[int]):
return numel(tensor) == 0 and len(tensor) == 1
def check_cat_shape_except_dim(
first: List[int], second: List[int], dimension: int, index: int
first: list[int], second: list[int], dimension: int, index: int
):
first_dims = len(first)
second_dims = len(second)
@ -545,11 +545,11 @@ def check_cat_shape_except_dim(
), "Sizes of tensors must match except in dimension"
def cat(tensors: List[List[int]], dim: int):
def cat(tensors: list[list[int]], dim: int):
check_cat_no_zero_dim(tensors)
dim = legacy_cat_wrap_dim(dim, tensors)
assert len(tensors) > 0
not_skipped_tensor: Optional[List[int]] = None
not_skipped_tensor: Optional[list[int]] = None
for tensor in tensors:
if not should_skip(tensor):
not_skipped_tensor = tensor
@ -569,15 +569,15 @@ def cat(tensors: List[List[int]], dim: int):
return result_size
def stack(tensors: List[List[int]], dim: int):
unsqueezed_tensors: List[List[int]] = []
def stack(tensors: list[list[int]], dim: int):
unsqueezed_tensors: list[list[int]] = []
for tensor in tensors:
unsqueezed = unsqueeze(tensor, dim)
unsqueezed_tensors.append(unsqueezed)
return cat(unsqueezed_tensors, dim)
def select(self: List[int], dim: int, index: int):
def select(self: list[int], dim: int, index: int):
ndim = len(self)
assert ndim != 0
dim = maybe_wrap_dim(dim, ndim)
@ -585,14 +585,14 @@ def select(self: List[int], dim: int, index: int):
assert not (index < -size or index >= size)
if index < 0:
index += size
out: List[int] = []
out: list[int] = []
for i in range(ndim):
if i != dim:
out.append(self[i])
return out
def matmul(tensor1: List[int], tensor2: List[int]):
def matmul(tensor1: list[int], tensor2: list[int]):
dim_tensor1 = len(tensor1)
dim_tensor2 = len(tensor2)
if dim_tensor1 == 1 and dim_tensor2 == 1:
@ -607,12 +607,12 @@ def matmul(tensor1: List[int], tensor2: List[int]):
# We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
# we track m1 vs m2 separately even though they must match for nicer error messages
n = tensor1[-2] if dim_tensor1 > 1 else 1
batch_tensor1: List[int] = []
batch_tensor1: list[int] = []
# TODO: handling of slice
for i in range(dim_tensor1 - 2):
batch_tensor1.append(tensor1[i])
p = tensor2[-1]
batch_tensor2: List[int] = []
batch_tensor2: list[int] = []
# TODO: handling of slice
for i in range(dim_tensor2 - 2):
batch_tensor2.append(tensor2[i])
@ -633,11 +633,11 @@ def matmul(tensor1: List[int], tensor2: List[int]):
assert False, "both arguments to matmul need to be at least 1D"
def t(self: List[int]):
def t(self: list[int]):
assert len(self) <= 2
self_len = len(self)
if self_len == 0:
out: List[int] = []
out: list[int] = []
return out
elif self_len == 1:
return [self[0]]
@ -645,13 +645,13 @@ def t(self: List[int]):
return [self[1], self[0]]
def transpose(self: List[int], dim0: int, dim1: int):
def transpose(self: list[int], dim0: int, dim1: int):
ndims = len(self)
dim0 = maybe_wrap_dim(dim0, ndims)
dim1 = maybe_wrap_dim(dim1, ndims)
if dim0 == dim1:
return _copy(self)
out: List[int] = []
out: list[int] = []
for i in range(ndims):
if i == dim0:
out.append(self[dim1])
@ -662,18 +662,18 @@ def transpose(self: List[int], dim0: int, dim1: int):
return out
def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
def linear(input: list[int], weight: list[int], bias: Optional[list[int]]):
out = matmul(input, t(weight))
if bias is not None:
assert broadcast(bias, out) == out
return out
def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
def addmm(self: list[int], mat1: list[int], mat2: list[int], beta: Any, alpha: Any):
return broadcast(self, mm(mat1, mat2))
def check_non_negative(array: List[int]) -> bool:
def check_non_negative(array: list[int]) -> bool:
# TODO: look into rewriting with early return and getting loop unrolling to fire
non_negative = False
for val in array:
@ -683,12 +683,12 @@ def check_non_negative(array: List[int]) -> bool:
def check_shape_forward(
input: List[int],
weight_sizes: List[int],
bias: Optional[List[int]],
stride: List[int],
padding: List[int],
dilation: List[int],
input: list[int],
weight_sizes: list[int],
bias: Optional[list[int]],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
):
k = len(input)
@ -714,12 +714,12 @@ def check_shape_forward(
def conv_output_size(
input_size: List[int],
weight_size: List[int],
bias: Optional[List[int]],
stride: List[int],
padding: List[int],
dilation: List[int],
input_size: list[int],
weight_size: list[int],
bias: Optional[list[int]],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
):
check_shape_forward(
@ -728,7 +728,7 @@ def conv_output_size(
has_dilation = len(dilation) > 0
dim = len(input_size)
output_size: List[int] = []
output_size: list[int] = []
input_batch_size_dim = 0
weight_output_channels_dim = 0
output_size.append(input_size[input_batch_size_dim])
@ -744,12 +744,12 @@ def conv_output_size(
def conv1d(
input: List[int],
weight: List[int],
bias: Optional[List[int]],
stride: List[int],
padding: List[int],
dilation: List[int],
input: list[int],
weight: list[int],
bias: Optional[list[int]],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
):
assert len(weight) == 3
@ -758,12 +758,12 @@ def conv1d(
def conv2d(
input: List[int],
weight: List[int],
bias: Optional[List[int]],
stride: List[int],
padding: List[int],
dilation: List[int],
input: list[int],
weight: list[int],
bias: Optional[list[int]],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
):
assert len(weight) == 4
@ -772,25 +772,25 @@ def conv2d(
def conv_backwards(
grad_output: List[int],
input: List[int],
weight: List[int],
biases: Optional[List[int]],
grad_output: list[int],
input: list[int],
weight: list[int],
biases: Optional[list[int]],
):
# Bias gradient is always generated regardless of whether biases is supplied
return _copy(input), _copy(weight), [grad_output[1]]
def conv_transpose2d_input(
input: List[int],
weight: List[int],
bias: Optional[List[int]] = None,
stride: Optional[List[int]] = None,
padding: Optional[List[int]] = None,
output_padding: Optional[List[int]] = None,
input: list[int],
weight: list[int],
bias: Optional[list[int]] = None,
stride: Optional[list[int]] = None,
padding: Optional[list[int]] = None,
output_padding: Optional[list[int]] = None,
groups: int = 1,
dilation: Optional[List[int]] = None,
) -> List[int]:
dilation: Optional[list[int]] = None,
) -> list[int]:
if stride is None:
stride = [1, 1]
if padding is None:
@ -801,7 +801,7 @@ def conv_transpose2d_input(
dilation = [1, 1]
has_dilation = len(dilation) > 0
dim = len(input)
output_size: List[int] = []
output_size: list[int] = []
input_batch_size_dim = 0
weight_output_channels_dim = 1
output_size.append(input[input_batch_size_dim])
@ -821,20 +821,20 @@ def conv_transpose2d_input(
def conv_forwards(
input: List[int],
weight: List[int],
bias: Optional[List[int]],
stride: List[int],
padding: List[int],
dilation: List[int],
input: list[int],
weight: list[int],
bias: Optional[list[int]],
stride: list[int],
padding: list[int],
dilation: list[int],
transposed: bool,
output_padding: List[int],
output_padding: list[int],
groups: int,
) -> List[int]:
) -> list[int]:
has_dilation = len(dilation) > 0
has_output_padding = len(output_padding) > 0
dim = len(input)
output_size: List[int] = []
output_size: list[int] = []
input_batch_size_dim = 0
weight_output_channels_dim = 1 if transposed else 0
output_size.append(input[input_batch_size_dim])
@ -864,20 +864,20 @@ def conv_forwards(
def _conv_forwards(
input: List[int],
weight: List[int],
bias: Optional[List[int]],
stride: List[int],
padding: List[int],
dilation: List[int],
input: list[int],
weight: list[int],
bias: Optional[list[int]],
stride: list[int],
padding: list[int],
dilation: list[int],
transposed: bool,
output_padding: List[int],
output_padding: list[int],
groups: int,
benchmark: bool,
deterministic: bool,
cudnn_enabled: bool,
allow_tf32: bool,
) -> List[int]:
) -> list[int]:
return conv_forwards(
input,
weight,
@ -892,29 +892,29 @@ def _conv_forwards(
def batch_norm(
input: List[int],
weight: Optional[List[int]],
bias: Optional[List[int]],
running_mean: Optional[List[int]],
running_var: Optional[List[int]],
input: list[int],
weight: Optional[list[int]],
bias: Optional[list[int]],
running_mean: Optional[list[int]],
running_var: Optional[list[int]],
training: bool,
momentum: float,
eps: float,
cudnn_enabled: bool,
):
out: List[int] = []
out: list[int] = []
for elem in input:
out.append(elem)
return out
def conv3d(
input: List[int],
weight: List[int],
bias: Optional[List[int]],
stride: List[int],
padding: List[int],
dilation: List[int],
input: list[int],
weight: list[int],
bias: Optional[list[int]],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
):
assert len(weight) == 5
@ -935,11 +935,11 @@ def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
def zero_dim_tensor(input: Any):
out: List[int] = []
out: list[int] = []
return out
def multiply_integers(li: List[int]):
def multiply_integers(li: list[int]):
out = 1
for elem in li:
out = out * elem
@ -970,11 +970,11 @@ def arange_start_step(
return [int(math.ceil((end - start) / step))]
def permute(input: List[int], dims: List[int]):
def permute(input: list[int], dims: list[int]):
assert len(input) == len(dims)
ndim = len(dims)
seen_dims: List[int] = []
newSizes: List[int] = []
seen_dims: list[int] = []
newSizes: list[int] = []
for i in range(ndim):
dim = maybe_wrap_dim(dims[i], ndim)
seen_dims.append(dim)
@ -985,12 +985,12 @@ def permute(input: List[int], dims: List[int]):
return newSizes
def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]:
def movedim(self: list[int], source: list[int], destination: list[int]) -> list[int]:
self_dim = len(self)
if self_dim <= 1:
return self
normalized_src: List[int] = []
normalized_dst: List[int] = []
normalized_src: list[int] = []
normalized_dst: list[int] = []
for i in range(len(source)):
normalized_src.append(maybe_wrap_dim(source[i], self_dim))
normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
@ -1003,8 +1003,8 @@ def movedim(self: List[int], source: List[int], destination: List[int]) -> List[
src_dims[normalized_src[i]] = -1
dst_dims[normalized_dst[i]] = -1
source_dims: List[int] = []
destination_dims: List[int] = []
source_dims: list[int] = []
destination_dims: list[int] = []
for ele in src_dims:
if ele != -1:
source_dims.append(ele)
@ -1018,7 +1018,7 @@ def movedim(self: List[int], source: List[int], destination: List[int]) -> List[
return permute(self, order)
def flatten(input: List[int], start_dim: int, end_dim: int):
def flatten(input: list[int], start_dim: int, end_dim: int):
start_dim = maybe_wrap_dim(start_dim, len(input))
end_dim = maybe_wrap_dim(end_dim, len(input))
assert start_dim <= end_dim
@ -1026,7 +1026,7 @@ def flatten(input: List[int], start_dim: int, end_dim: int):
return [1]
if start_dim == end_dim:
# TODO: return self
out: List[int] = []
out: list[int] = []
for elem in input:
out.append(elem)
return out
@ -1035,7 +1035,7 @@ def flatten(input: List[int], start_dim: int, end_dim: int):
slice_numel *= input[i]
# TODO: use slicing when slice optimization has landed
# slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
shape: List[int] = []
shape: list[int] = []
for i in range(start_dim):
shape.append(input[i])
shape.append(slice_numel)
@ -1044,17 +1044,17 @@ def flatten(input: List[int], start_dim: int, end_dim: int):
return shape
def nonzero_lower_bound(input: List[int]):
def nonzero_lower_bound(input: list[int]):
return [0, len(input)]
def nonzero_upper_bound(input: List[int]):
def nonzero_upper_bound(input: list[int]):
return [numel(input), len(input)]
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
def _reduce_along_dim(self: list[int], dim: int, keepdim: bool):
dim = maybe_wrap_dim(dim, len(self))
out: List[int] = []
out: list[int] = []
for i, self_dim in enumerate(self):
if i == dim:
if keepdim:
@ -1065,14 +1065,14 @@ def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
def argmax(
self: List[int], dim: Optional[int] = None, keepdim: bool = False
) -> List[int]:
self: list[int], dim: Optional[int] = None, keepdim: bool = False
) -> list[int]:
if dim is None:
return []
return _reduce_along_dim(self, dim, keepdim)
def bmm(self: List[int], mat2: List[int]) -> List[int]:
def bmm(self: list[int], mat2: list[int]) -> list[int]:
assert len(self) == 3, "bmm only supports 3D tensors"
assert len(mat2) == 3, "bmm only supports 3D tensors"
assert self[0] == mat2[0], "mismatching batch dimension"
@ -1080,13 +1080,13 @@ def bmm(self: List[int], mat2: List[int]) -> List[int]:
return [self[0], self[1], mat2[2]]
def _shape_as_tensor(self: List[int]) -> List[int]:
def _shape_as_tensor(self: list[int]) -> list[int]:
return [len(self)]
def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]:
if len(self) == 0:
result: List[int] = []
result: list[int] = []
else:
assert (
k <= self[dim]
@ -1097,8 +1097,8 @@ def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
def nll_loss_forward(
self: List[int], target: List[int], weight: Optional[List[int]], reduction: int
) -> Tuple[List[int], List[int]]:
self: list[int], target: list[int], weight: Optional[list[int]], reduction: int
) -> tuple[list[int], list[int]]:
# This is taken shamelessly from the meta function in LossNLL.cpp
self_dim = len(self)
target_dim = len(target)
@ -1107,7 +1107,7 @@ def nll_loss_forward(
no_batch_dim = self_dim == 1 and target_dim == 0
assert no_batch_dim or (self[0] == target[0])
n_classes = self[-1]
scalar_shape: List[int] = []
scalar_shape: list[int] = []
assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
if reduction == 0 and self_dim == 2:
reduction_shape = [self[0]]
@ -1117,9 +1117,9 @@ def nll_loss_forward(
def native_layer_norm(
input: List[int], normalized_shape: List[int]
) -> Tuple[List[int], List[int], List[int]]:
reduction_shape: List[int] = []
input: list[int], normalized_shape: list[int]
) -> tuple[list[int], list[int], list[int]]:
reduction_shape: list[int] = []
num_unreduced_dimensions = len(input) - len(normalized_shape)
assert num_unreduced_dimensions >= 0
for i in range(num_unreduced_dimensions):
@ -1130,13 +1130,13 @@ def native_layer_norm(
def native_batch_norm(
input: List[int],
weight: Optional[List[int]],
bias: Optional[List[int]],
running_mean: Optional[List[int]],
running_var: Optional[List[int]],
input: list[int],
weight: Optional[list[int]],
bias: Optional[list[int]],
running_mean: Optional[list[int]],
running_var: Optional[list[int]],
training: bool,
) -> Tuple[List[int], List[int], List[int]]:
) -> tuple[list[int], list[int], list[int]]:
if training:
_size = [input[1]]
else:
@ -1145,24 +1145,24 @@ def native_batch_norm(
def _batch_norm_with_update(
input: List[int],
weight: Optional[List[int]],
bias: Optional[List[int]],
running_mean: Optional[List[int]],
running_var: Optional[List[int]],
) -> Tuple[List[int], List[int], List[int], List[int]]:
input: list[int],
weight: Optional[list[int]],
bias: Optional[list[int]],
running_mean: Optional[list[int]],
running_var: Optional[list[int]],
) -> tuple[list[int], list[int], list[int], list[int]]:
_size = [input[1]]
return _copy(input), _size, _size, [0]
def cross_entropy_loss(
self: List[int],
target: List[int],
weight: Optional[List[int]] = None,
self: list[int],
target: list[int],
weight: Optional[list[int]] = None,
reduction: int = 1,
ignore_index: int = -100,
label_smoothing: float = 0.0,
) -> List[int]:
) -> list[int]:
result_shape = nll_loss_forward(self, target, weight, reduction)[0]
return result_shape
@ -1187,9 +1187,9 @@ def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[in
"""
ScriptFn = torch._C.ScriptFunction
shape_compute_graph_mapping: Dict[str, ScriptFn] = {}
bounded_compute_graph_mapping: Dict[str, Tuple[ScriptFn, ScriptFn]] = {}
script_func_map: Dict[Callable, ScriptFn] = {}
shape_compute_graph_mapping: dict[str, ScriptFn] = {}
bounded_compute_graph_mapping: dict[str, tuple[ScriptFn, ScriptFn]] = {}
script_func_map: dict[Callable, ScriptFn] = {}
def process_func(func: Callable):
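
These shape functions operate on plain list[int] sizes, so they can be sanity-checked directly; a quick check, assuming the internal module path torch.jit._shape_functions and the signatures shown above:

from torch.jit._shape_functions import broadcast, conv2d, matmul

assert broadcast([3, 1, 5], [4, 5]) == [3, 4, 5]
assert matmul([10, 2, 3], [10, 3, 4]) == [10, 2, 4]
# Conv2d sizes: batch 1, 3 -> 16 channels, 3x3 kernel, stride 1, no padding, 8x8 input.
assert conv2d([1, 3, 8, 8], [16, 3, 3, 3], None, [1, 1], [0, 0], [1, 1], 1) == [1, 16, 6, 6]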

View File

@ -6,9 +6,10 @@ This module stores various pieces of Python-global state relating to the JIT.
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import os
import weakref
from typing import Any, Dict, Type
from typing import Any
import torch
@ -62,8 +63,8 @@ _python_cu = torch._C.CompilationUnit()
# python class => ScriptClass mapping
_script_classes: Dict[Type[Any], Type[Any]] = {}
_name_to_pyclass: Dict[str, Type[Any]] = {}
_script_classes: dict[type[Any], type[Any]] = {}
_name_to_pyclass: dict[str, type[Any]] = {}
def _add_script_class(python_class, script_class):

View File

@ -17,7 +17,7 @@ import os
import re
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
from typing import Any, Callable, Optional, TypeVar
from typing_extensions import ParamSpec
import torch
@ -70,7 +70,7 @@ def _unique_state_dict(module, keep_vars=False):
# as values, and deduplicate the params using Parameters and Buffers
state_dict = module.state_dict(keep_vars=True)
filtered_dict = type(state_dict)()
seen_ids: Set[int] = set()
seen_ids: set[int] = set()
for k, v in state_dict.items():
if id(v) in seen_ids:
continue
@ -112,7 +112,7 @@ class ONNXTracedModule(torch.nn.Module):
outs = []
def wrapper(*args):
in_args: List[torch.Tensor] = []
in_args: list[torch.Tensor] = []
for i in range(len(in_vars)):
if not isinstance(args[i], torch.Tensor):
raise RuntimeError("Expected Tensor argument")
@ -960,6 +960,7 @@ def trace(
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -968,6 +969,7 @@ def trace(
def forward(self, x):
return self.conv(x)
n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
@ -1112,7 +1114,7 @@ def trace(
return traced_func
_trace_module_map: Optional[Dict[Any, Any]] = None
_trace_module_map: Optional[dict[Any, Any]] = None
def trace_module(
@ -1176,6 +1178,7 @@ def trace_module(
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -1202,7 +1205,7 @@ def trace_module(
# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight}
module = torch.jit.trace_module(n, inputs)
"""
@ -1226,7 +1229,7 @@ def trace_module(
old_module_map = torch.jit._trace._trace_module_map
try:
trace_module_map: Dict[Any, Any] = {}
trace_module_map: dict[Any, Any] = {}
def register_submods(mod, prefix):
for name, child in mod.named_children():

View File

@ -8,7 +8,6 @@ import re
import typing
import warnings
from textwrap import dedent
from typing import Type
import torch
from torch._C import (
@ -350,7 +349,7 @@ def try_real_annotations(fn, loc):
# Finds common type for enum values belonging to an Enum class. If not all
# values have the same type, AnyType is returned.
def get_enum_value_type(e: Type[enum.Enum], loc):
def get_enum_value_type(e: type[enum.Enum], loc):
enum_values: List[enum.Enum] = list(e)
if not enum_values:
raise ValueError(f"No enum values defined for: '{e.__class__}'")

View File

@ -4,7 +4,6 @@ import dataclasses
import inspect
import re
import string
import sys
from collections import namedtuple
from textwrap import dedent
from typing import List, Tuple # noqa: F401
@ -1128,10 +1127,7 @@ class ExprBuilder(Builder):
return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
elif sub_type is ast.ExtSlice:
return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
elif sys.version_info >= (
3,
9,
): # In Python3.9 array indicies are not wrapped in ast.Index
else:  # In Python 3.9+ array indices are not wrapped in ast.Index
if sub_type is ast.Tuple:
# N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
indices = []
@ -1150,8 +1146,6 @@ class ExprBuilder(Builder):
indices.append(tup)
return Subscript(base, indices)
return Subscript(base, [build_expr(ctx, expr.slice)])
else: # Ellipsis (can only happen in Python 2)
raise NotSupportedError(base.range(), "ellipsis is not supported")
@staticmethod
def build_List(ctx, expr):

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import List
from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph
@ -20,7 +19,7 @@ def format_bytecode(table):
return formatted_table
def generate_upgraders_bytecode() -> List:
def generate_upgraders_bytecode() -> list:
yaml_content = []
upgraders_graph_map = _generate_upgraders_graph()
for upgrader_name, upgrader_graph in upgraders_graph_map.items():

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from textwrap import dedent
from typing import Any, Dict
from typing import Any
import torch.jit
@ -37,7 +37,7 @@ def _gen_unsupported_methods_properties():
sorted_tensor_attrs = sorted(tensor_attrs, key=lambda x: x.lower())
for attr in sorted_tensor_attrs:
funcs_str = funcs_template.format(op=attr)
scope: Dict[str, Any] = {}
scope: dict[str, Any] = {}
execWrapper(funcs_str, globals(), scope)
try:
torch.jit.CompilationUnit(funcs_str)
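
The generator above probes TorchScript support by compiling a tiny function per Tensor attribute; the same probe in isolation looks roughly like this (illustrative source string):

import torch

src = "def fn(x):\n    return x.relu()\n"
cu = torch.jit.CompilationUnit(src)  # raises if the body uses an unsupported op
print(cu.fn(torch.ones(2)))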