Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Apply Ruff fixes and pyupgrade to torch/jit (#144208)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144208
Approved by: https://github.com/davidberard98
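Most of the changes in this diff are mechanical: pyupgrade rewrites `typing.List`, `Dict`, `Tuple`, `Set`, and `Type` annotations to the PEP 585 builtin generics (`list`, `dict`, `tuple`, `set`, `type`) and trims the now-unused names from the `typing` imports, while Ruff reflows a few overlong lines. A minimal sketch of the pattern, on a made-up function rather than one taken from the diff (builtin generics require Python 3.9+):

# Before: container aliases imported from `typing`
from typing import Dict, List, Optional

def index_names(names: List[str]) -> Dict[str, Optional[int]]:
    return {name: i for i, name in enumerate(names)}

# After pyupgrade: builtin generics, smaller `typing` import
from typing import Optional

def index_names(names: list[str]) -> dict[str, Optional[int]]:
    return {name: i for i, name in enumerate(names)}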
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch._C
|
||||
|
||||
|
@ -50,11 +50,17 @@ def fork(func, *args, **kwargs):
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
def foo(a : Tensor, b : int) -> Tensor:
|
||||
|
||||
|
||||
def foo(a: Tensor, b: int) -> Tensor:
|
||||
return a + b
|
||||
|
||||
|
||||
def bar(a):
|
||||
fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
|
||||
fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
|
||||
return torch.jit.wait(fut)
|
||||
|
||||
|
||||
script_bar = torch.jit.script(bar)
|
||||
input = torch.tensor(2)
|
||||
# only the scripted version executes asynchronously
|
||||
@ -69,16 +75,23 @@ def fork(func, *args, **kwargs):
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class AddMod(torch.nn.Module):
|
||||
def forward(self, a: Tensor, b : int):
|
||||
def forward(self, a: Tensor, b: int):
|
||||
return a + b
|
||||
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.mod = AddMod()
|
||||
|
||||
def forward(self, input):
|
||||
fut = torch.jit.fork(self.mod, input, b=2)
|
||||
return torch.jit.wait(fut)
|
||||
|
||||
|
||||
input = torch.tensor(2)
|
||||
mod = Mod()
|
||||
assert mod(input) == torch.jit.script(mod).forward(input)
|
||||
|
@ -3,7 +3,7 @@ import cmath
|
||||
import math
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
@ -16,7 +16,7 @@ from torch.nn.modules.utils import (
|
||||
)
|
||||
|
||||
|
||||
_builtin_table: Optional[Dict[int, str]] = None
|
||||
_builtin_table: Optional[dict[int, str]] = None
|
||||
|
||||
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950
|
||||
|
||||
|
@ -5,7 +5,7 @@ import dataclasses
|
||||
import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List
|
||||
from typing import Callable
|
||||
|
||||
from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
|
||||
from torch._sources import ParsedDef, SourceContext
|
||||
@ -15,7 +15,7 @@ def _get_fake_filename(cls, method_name):
|
||||
return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)
|
||||
|
||||
|
||||
def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
|
||||
def compose_fn(cls, name: str, body_lines: list[str], signature: str) -> ParsedDef:
|
||||
body = "\n".join(f" {b}" for b in body_lines)
|
||||
decl = f"def {name}{signature}:\n{body}"
|
||||
|
||||
@ -59,7 +59,7 @@ def synthesize__init__(cls) -> ParsedDef:
|
||||
|
||||
# Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
|
||||
# see CPython commit here https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
|
||||
init_vars: List[str] = []
|
||||
init_vars: list[str] = []
|
||||
params = []
|
||||
for name, param in signature.parameters.items():
|
||||
ann = param.annotation
|
||||
@ -144,7 +144,7 @@ def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
|
||||
|
||||
|
||||
def synthesize_comparison(
|
||||
cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]
|
||||
cls, name: str, allow_eq: bool, raise_on_none: bool, inner: list[str]
|
||||
) -> ParsedDef:
|
||||
body = []
|
||||
for field in dataclasses.fields(cls):
|
||||
@ -177,7 +177,7 @@ def synthesize_comparison(
|
||||
)
|
||||
|
||||
|
||||
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
|
||||
DATACLASS_MAGIC_METHODS: dict[str, Callable] = {
|
||||
"__init__": synthesize__init__,
|
||||
"__repr__": synthesize__repr__,
|
||||
"__hash__": synthesize__hash__,
|
||||
|
@ -6,14 +6,14 @@ from torch import Tensor
|
||||
aten = torch.ops.aten
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Callable, Dict, List, Optional, Set, TypeVar
|
||||
from typing import Callable, Optional, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from torch.types import Number
|
||||
|
||||
|
||||
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
|
||||
function_name_set: Set[str] = set()
|
||||
decomposition_table: dict[str, torch.jit.ScriptFunction] = {}
|
||||
function_name_set: set[str] = set()
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
@ -65,7 +65,7 @@ def signatures_match(decomposition_sig, torch_op_sig):
|
||||
|
||||
def register_decomposition(
|
||||
aten_op: torch._ops.OpOverload,
|
||||
registry: Optional[Dict[str, torch.jit.ScriptFunction]] = None,
|
||||
registry: Optional[dict[str, torch.jit.ScriptFunction]] = None,
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
nonlocal registry
|
||||
@ -99,12 +99,12 @@ def register_decomposition(
|
||||
@register_decomposition(aten.var.correction)
|
||||
def var_decomposition(
|
||||
input: Tensor,
|
||||
dim: Optional[List[int]] = None,
|
||||
dim: Optional[list[int]] = None,
|
||||
correction: Optional[Number] = None,
|
||||
keepdim: bool = False,
|
||||
) -> Tensor:
|
||||
if dim is None:
|
||||
dim_i: List[int] = []
|
||||
dim_i: list[int] = []
|
||||
dim = dim_i
|
||||
|
||||
if isinstance(dim, (tuple, list)) and len(dim) == 0:
|
||||
|
@ -5,14 +5,14 @@ This is not intended to be imported directly; please use the exposed
|
||||
functionalities in `torch.jit`.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.jit._script import RecursiveScriptModule, ScriptModule
|
||||
|
||||
|
||||
def freeze(
|
||||
mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True
|
||||
mod, preserved_attrs: Optional[list[str]] = None, optimize_numerics: bool = True
|
||||
):
|
||||
r"""Freeze ScriptModule, inline submodules, and attributes as constants.
|
||||
|
||||
@ -124,7 +124,7 @@ def freeze(
|
||||
|
||||
|
||||
def run_frozen_optimizations(
|
||||
mod, optimize_numerics: bool = True, preserved_methods: Optional[List[str]] = None
|
||||
mod, optimize_numerics: bool = True, preserved_methods: Optional[list[str]] = None
|
||||
):
|
||||
r"""
|
||||
Run a series of optimizations looking for patterns that occur in frozen graphs.
|
||||
@ -155,9 +155,12 @@ def run_frozen_optimizations(
|
||||
Example (Freezing a module with Conv->Batchnorm)
|
||||
.. code-block:: python
|
||||
import torch
|
||||
|
||||
in_channels, out_channels = 3, 32
|
||||
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
|
||||
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
|
||||
conv = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=2, bias=True
|
||||
)
|
||||
bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
mod = torch.nn.Sequential(conv, bn)
|
||||
# set optimize to False here; by default, freezing runs run_frozen_optimizations
|
||||
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
|
||||
@ -180,7 +183,7 @@ def run_frozen_optimizations(
|
||||
|
||||
|
||||
def optimize_for_inference(
|
||||
mod: ScriptModule, other_methods: Optional[List[str]] = None
|
||||
mod: ScriptModule, other_methods: Optional[list[str]] = None
|
||||
) -> ScriptModule:
|
||||
"""
|
||||
Perform a set of optimization passes to optimize a model for the purposes of inference.
|
||||
@ -202,9 +205,12 @@ def optimize_for_inference(
|
||||
Example (optimizing a module with Conv->Batchnorm)::
|
||||
|
||||
import torch
|
||||
|
||||
in_channels, out_channels = 3, 32
|
||||
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
|
||||
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
|
||||
conv = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=2, bias=True
|
||||
)
|
||||
bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
mod = torch.nn.Sequential(conv, bn)
|
||||
frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
|
||||
assert "batch_norm" not in str(frozen_mod.graph)
|
||||
|
@ -1,6 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -106,7 +105,7 @@ def _script_method_graph_for(self, parent, *args, **kwargs):
|
||||
|
||||
# graph_executor_states for differentiable node
|
||||
fw_states = eps[0].code.differentiable_op_executor_states()
|
||||
diff_nodes: List[torch._C.Node] = []
|
||||
diff_nodes: list[torch._C.Node] = []
|
||||
for n in graph.nodes():
|
||||
_get_differentiable_graph_node(n, diff_nodes)
|
||||
|
||||
@ -128,7 +127,7 @@ def _script_method_graph_for(self, parent, *args, **kwargs):
|
||||
return last_executed_optimized_graph()
|
||||
|
||||
|
||||
def set_fusion_strategy(strategy: List[Tuple[str, int]]):
|
||||
def set_fusion_strategy(strategy: list[tuple[str, int]]):
|
||||
"""Set the type and number of specializations that can occur during fusion.
|
||||
|
||||
Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC"
|
||||
|
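The docstring above spells out the expected argument shape for set_fusion_strategy; a brief usage sketch (the depths here are arbitrary example values, not recommendations):

import torch

# Allow up to 2 static-shape and 10 dynamic-shape specializations during fusion.
torch.jit.set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 10)])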
@ -1,5 +1,5 @@
|
||||
from types import TracebackType
|
||||
from typing import Optional, Type, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -20,7 +20,7 @@ class _InsertPoint:
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
|
@ -3,9 +3,10 @@ import inspect
|
||||
import sys
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from types import CodeType
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -94,7 +95,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
# A dictionary keeping all collected CallTrace
|
||||
# key is fully qualified name of called function
|
||||
# value is list of all CallTrace
|
||||
self.trace_records: Dict[str, list] = defaultdict(list)
|
||||
self.trace_records: dict[str, list] = defaultdict(list)
|
||||
|
||||
def add(self, traces: Iterable[CallTrace]):
|
||||
for t in traces:
|
||||
@ -106,10 +107,10 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
qualified_name: str,
|
||||
qualname_prefix: Optional[str] = None,
|
||||
limit: int = 2000,
|
||||
) -> List[CallTraceThunk]:
|
||||
) -> list[CallTraceThunk]:
|
||||
return self.trace_records[qualified_name]
|
||||
|
||||
def analyze(self, qualified_name: str) -> Dict:
|
||||
def analyze(self, qualified_name: str) -> dict:
|
||||
# Analyze the types for the given module
|
||||
# and create a dictionary of all the types
|
||||
# for arguments.
|
||||
@ -120,7 +121,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
all_args[arg].add(arg_type)
|
||||
return all_args
|
||||
|
||||
def consolidate_types(self, qualified_name: str) -> Dict:
|
||||
def consolidate_types(self, qualified_name: str) -> dict:
|
||||
all_args = self.analyze(qualified_name)
|
||||
# If there are more types for an argument,
|
||||
# then consolidate the type to `Any` and replace the entry
|
||||
@ -137,7 +138,7 @@ if _IS_MONKEYTYPE_INSTALLED:
|
||||
all_args[arg] = get_type(types[0])
|
||||
return all_args
|
||||
|
||||
def get_args_types(self, qualified_name: str) -> Dict:
|
||||
def get_args_types(self, qualified_name: str) -> dict:
|
||||
return self.consolidate_types(qualified_name)
|
||||
|
||||
class JitTypeTraceConfig(monkeytype.config.Config):
|
||||
|
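The comments above describe how the trace store consolidates observed argument types: a single consistent type is kept, and any disagreement widens the entry to Any. A standalone sketch of that rule, independent of the MonkeyType-backed implementation:

from typing import Any

def consolidate(observed_types: set) -> Any:
    # Keep the sole observed type; widen to Any when the traces disagree.
    return observed_types.pop() if len(observed_types) == 1 else Any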
@ -5,14 +5,14 @@ This is not intended to be imported directly; please use the exposed
|
||||
functionalities in `torch.jit`.
|
||||
"""
|
||||
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import TensorType
|
||||
from torch._C import Graph
|
||||
|
||||
|
||||
def apply_input_props_using_example(graph: Graph, example_input: List[Any]) -> None:
|
||||
def apply_input_props_using_example(graph: Graph, example_input: list[Any]) -> None:
|
||||
"""
|
||||
Applies properties for each tensor in the graph inputs
|
||||
using the example supplied.
|
||||
|
@ -6,7 +6,6 @@ import sys
|
||||
import textwrap
|
||||
import types
|
||||
import warnings
|
||||
from typing import Dict, List, Set, Type
|
||||
|
||||
import torch
|
||||
import torch._jit_internal as _jit_internal
|
||||
@ -421,8 +420,8 @@ def infer_concrete_type_builder(nn_module, share_types=True):
|
||||
|
||||
|
||||
class ConcreteTypeStore:
|
||||
type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
|
||||
methods_compiled: Set[torch._C.ConcreteModuleType]
|
||||
type_store: dict[type[Module], list[torch._C.ConcreteModuleType]]
|
||||
methods_compiled: set[torch._C.ConcreteModuleType]
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Python module type => List[ConcreteModuleType]
|
||||
@ -766,7 +765,7 @@ def get_overload_annotations(mod, jit_ignored_properties):
|
||||
def get_overload_name_mapping(overload_info):
|
||||
# Same format as __overloads__
|
||||
# original function => [overload names]
|
||||
overload_name_mappings: Dict[str, List[str]] = {}
|
||||
overload_name_mappings: dict[str, list[str]] = {}
|
||||
for orig_fn, overloads in overload_info.items():
|
||||
original_name = orig_fn.__name__
|
||||
if original_name not in overload_name_mappings:
|
||||
@ -836,7 +835,7 @@ def infer_methods_to_compile(nn_module):
|
||||
check_module_initialized(nn_module)
|
||||
ignored_properties = jit_ignored_properties(nn_module)
|
||||
|
||||
methods: List[str] = []
|
||||
methods: list[str] = []
|
||||
if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn(
|
||||
nn_module.forward
|
||||
):
|
||||
@ -873,7 +872,7 @@ def infer_methods_to_compile(nn_module):
|
||||
|
||||
# Unique the methods. We don't want to use a set to store the methods because it
|
||||
# introduces non-determinism to compile order.
|
||||
uniquer: Set[str] = set()
|
||||
uniquer: set[str] = set()
|
||||
uniqued_methods = []
|
||||
for name in filtered_methods:
|
||||
if name in uniquer:
|
||||
@ -888,7 +887,7 @@ def infer_methods_to_compile(nn_module):
|
||||
def get_hook_stubs(nn_module):
|
||||
"""Return forward hook and pre_hook ScriptModuleStubs."""
|
||||
check_module_initialized(nn_module)
|
||||
hook_map: Dict = {}
|
||||
hook_map: dict = {}
|
||||
|
||||
hook_stubs = []
|
||||
for hook in nn_module._forward_hooks.values():
|
||||
|
@ -6,6 +6,7 @@ This module contains functionality to support the JIT's scripting frontend, nota
|
||||
This is not intended to be imported directly; please use the exposed
|
||||
functionalities in `torch.jit`.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import enum
|
||||
@ -13,7 +14,7 @@ import functools
|
||||
import inspect
|
||||
import pickle
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Union
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
import torch._jit_internal as _jit_internal
|
||||
@ -277,12 +278,12 @@ class OrderedModuleDict(OrderedDictWrapper):
|
||||
class ScriptMeta(type):
|
||||
def __init__(cls, name, bases, attrs): # noqa: B902
|
||||
# Aggregate all the ScriptMethods and constants from superclasses
|
||||
cls._methods: Dict[str, Any] = {}
|
||||
cls._methods: dict[str, Any] = {}
|
||||
cls._constants_set = set(getattr(cls, "__constants__", ()))
|
||||
for base in reversed(bases):
|
||||
for k, v in getattr(base, "_methods", {}).items():
|
||||
cls._methods[k] = v
|
||||
base_constants: Set = getattr(base, "_constants_set", set())
|
||||
base_constants: set = getattr(base, "_constants_set", set())
|
||||
cls._constants_set = cls._constants_set.union(base_constants)
|
||||
|
||||
# find all the script methods of the current class
|
||||
@ -1020,7 +1021,9 @@ def call_prepare_scriptable_func_impl(obj, memo):
|
||||
if obj_id in memo:
|
||||
return memo[id(obj)]
|
||||
|
||||
obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj # type: ignore[operator]
|
||||
obj = (
|
||||
obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj
|
||||
) # type: ignore[operator]
|
||||
# Record obj in memo to avoid infinite recursion in the case of cycles in the module
|
||||
# hierarchy when recursing below.
|
||||
memo[obj_id] = obj
|
||||
@ -1046,7 +1049,7 @@ def call_prepare_scriptable_func_impl(obj, memo):
|
||||
|
||||
|
||||
def call_prepare_scriptable_func(obj):
|
||||
memo: Dict[int, torch.nn.Module] = {}
|
||||
memo: dict[int, torch.nn.Module] = {}
|
||||
return call_prepare_scriptable_func_impl(obj, memo)
|
||||
|
||||
|
||||
@ -1089,7 +1092,7 @@ def _script_impl(
|
||||
optimize=None,
|
||||
_frames_up=0,
|
||||
_rcb=None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
example_inputs: Union[list[tuple], dict[Callable, list[tuple]], None] = None,
|
||||
):
|
||||
global type_trace_db
|
||||
|
||||
@ -1118,7 +1121,7 @@ def _script_impl(
|
||||
if monkeytype_trace:
|
||||
monkeytype_config = JitTypeTraceConfig(type_trace_db)
|
||||
with monkeytype_trace(monkeytype_config):
|
||||
if isinstance(example_inputs, Dict):
|
||||
if isinstance(example_inputs, dict):
|
||||
# If the obj is an nn.Module or a class, then each method is
|
||||
# executed with the arguments provided in the example inputs.
|
||||
# example inputs here will be of type Dict(class.method, (arguments))
|
||||
@ -1127,7 +1130,7 @@ def _script_impl(
|
||||
for module, example_input in example_inputs.items():
|
||||
for example in example_input:
|
||||
module(*example)
|
||||
elif isinstance(example_inputs, List):
|
||||
elif isinstance(example_inputs, list):
|
||||
for examples in example_inputs:
|
||||
obj(*examples)
|
||||
else:
|
||||
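The branch above replays example_inputs so MonkeyType can record argument types before compilation: a list of argument tuples for a plain callable, or a dict mapping each callable to its own list of tuples. A rough usage sketch of the public entry point (the function and shapes are placeholders; MonkeyType must be installed for the traces to be used):

import torch

def scale(x, factor):
    return x * factor

# Each tuple is one example call; the recorded types feed the scripted signature.
scripted = torch.jit.script(scale, example_inputs=[(torch.randn(2, 2), 2.0)])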
@ -1148,7 +1151,11 @@ def _script_impl(
|
||||
obj, torch.jit._recursive.infer_methods_to_compile
|
||||
)
|
||||
else:
|
||||
obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj # type: ignore[operator]
|
||||
obj = (
|
||||
obj.__prepare_scriptable__()
|
||||
if hasattr(obj, "__prepare_scriptable__")
|
||||
else obj
|
||||
) # type: ignore[operator]
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return create_script_dict(obj)
|
||||
@ -1220,7 +1227,7 @@ def script(
|
||||
optimize=None,
|
||||
_frames_up=0,
|
||||
_rcb=None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
example_inputs: Union[list[tuple], dict[Callable, list[tuple]], None] = None,
|
||||
):
|
||||
r"""Script the function.
|
||||
|
||||
@ -1374,6 +1381,7 @@ def script(
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MyModule(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -1387,6 +1395,7 @@ def script(
|
||||
# This function won't be compiled, so any
|
||||
# Python APIs can be used
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
def forward(self, input):
|
||||
@ -1394,6 +1403,7 @@ def script(
|
||||
self.python_only_fn(input)
|
||||
return input * 99
|
||||
|
||||
|
||||
scripted_module = torch.jit.script(MyModule())
|
||||
print(scripted_module.some_entry_point(torch.randn(2, 2)))
|
||||
print(scripted_module(torch.randn(2, 2)))
|
||||
@ -1619,14 +1629,14 @@ class _ScriptProfileColumn:
|
||||
self.header = header
|
||||
self.alignment = alignment
|
||||
self.offset = offset
|
||||
self.rows: Dict[int, Any] = {}
|
||||
self.rows: dict[int, Any] = {}
|
||||
|
||||
def add_row(self, lineno: int, value: Any):
|
||||
self.rows[lineno] = value
|
||||
|
||||
def materialize(self):
|
||||
max_length = len(self.header)
|
||||
rows: List[Tuple[int, str]] = []
|
||||
rows: list[tuple[int, str]] = []
|
||||
for key, value in self.rows.items():
|
||||
cell = str(value)
|
||||
rows.append((key, cell))
|
||||
@ -1643,13 +1653,13 @@ class _ScriptProfileColumn:
|
||||
|
||||
|
||||
class _ScriptProfileTable:
|
||||
def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]):
|
||||
def __init__(self, cols: list[_ScriptProfileColumn], source_range: list[int]):
|
||||
self.cols = cols
|
||||
self.source_range = source_range
|
||||
|
||||
def dump_string(self):
|
||||
outputs: List[str] = []
|
||||
cells: List[Tuple[str, Dict[int, str]]] = []
|
||||
outputs: list[str] = []
|
||||
cells: list[tuple[str, dict[int, str]]] = []
|
||||
header_buffer = ""
|
||||
for col in self.cols:
|
||||
header, rows = col.materialize()
|
||||
@ -1681,7 +1691,7 @@ class _ScriptProfile:
|
||||
self.profile.disable()
|
||||
|
||||
def dump_string(self) -> str:
|
||||
outputs: List[str] = []
|
||||
outputs: list[str] = []
|
||||
for source_stats in self.profile._dump_stats():
|
||||
source_ref = source_stats.source()
|
||||
source_lines = source_ref.text().splitlines()
|
||||
|
@ -165,7 +165,9 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
|
||||
|
||||
cu = torch._C.CompilationUnit()
|
||||
if isinstance(f, (str, os.PathLike)):
|
||||
cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg]
|
||||
cpp_module = torch._C.import_ir_module(
|
||||
cu, os.fspath(f), map_location, _extra_files, _restore_shapes
|
||||
) # type: ignore[call-arg]
|
||||
else:
|
||||
cpp_module = torch._C.import_ir_module_from_buffer(
|
||||
cu, f.read(), map_location, _extra_files, _restore_shapes
|
||||
|
@ -24,11 +24,11 @@ number = Union[int, float]
|
||||
import torch
|
||||
|
||||
|
||||
def broadcast(a: List[int], b: List[int]):
|
||||
def broadcast(a: list[int], b: list[int]):
|
||||
dimsA = len(a)
|
||||
dimsB = len(b)
|
||||
ndim = max(dimsA, dimsB)
|
||||
expandedSizes: List[int] = []
|
||||
expandedSizes: list[int] = []
|
||||
|
||||
for i in range(ndim):
|
||||
offset = ndim - 1 - i
|
||||
@ -48,21 +48,21 @@ def broadcast(a: List[int], b: List[int]):
|
||||
return expandedSizes
|
||||
|
||||
|
||||
def broadcast_three(a: List[int], b: List[int], c: List[int]):
|
||||
def broadcast_three(a: list[int], b: list[int], c: list[int]):
|
||||
return broadcast(broadcast(a, b), c)
|
||||
|
||||
|
||||
def broadcast_one_three(a: List[int], b: Any, c: List[int]):
|
||||
def broadcast_one_three(a: list[int], b: Any, c: list[int]):
|
||||
return broadcast(a, c)
|
||||
|
||||
|
||||
def adaptive_avg_pool2d(self: List[int], out: List[int]):
|
||||
def adaptive_avg_pool2d(self: list[int], out: list[int]):
|
||||
assert len(out) == 2
|
||||
assert len(self) == 3 or len(self) == 4
|
||||
for i in range(1, len(self)):
|
||||
assert self[i] != 0
|
||||
|
||||
shape: List[int] = []
|
||||
shape: list[int] = []
|
||||
for i in range(0, len(self) - 2):
|
||||
shape.append(self[i])
|
||||
for elem in out:
|
||||
@ -70,18 +70,18 @@ def adaptive_avg_pool2d(self: List[int], out: List[int]):
|
||||
return shape
|
||||
|
||||
|
||||
def _copy(self: List[int]):
|
||||
out: List[int] = []
|
||||
def _copy(self: list[int]):
|
||||
out: list[int] = []
|
||||
for elem in self:
|
||||
out.append(elem)
|
||||
return out
|
||||
|
||||
|
||||
def unary(self: List[int]):
|
||||
def unary(self: list[int]):
|
||||
return _copy(self)
|
||||
|
||||
|
||||
def broadcast_inplace(a: List[int], b: List[int]):
|
||||
def broadcast_inplace(a: list[int], b: list[int]):
|
||||
dimsA = len(a)
|
||||
dimsB = len(b)
|
||||
if dimsB > dimsA:
|
||||
@ -101,13 +101,13 @@ def broadcast_inplace(a: List[int], b: List[int]):
|
||||
return _copy(a)
|
||||
|
||||
|
||||
def expand(self: List[int], sizes: List[int]):
|
||||
def expand(self: list[int], sizes: list[int]):
|
||||
assert len(sizes) >= len(self)
|
||||
ndim = len(sizes)
|
||||
tensor_dim = len(self)
|
||||
if ndim == 0:
|
||||
return _copy(sizes)
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
for i in range(ndim):
|
||||
offset = ndim - 1 - i
|
||||
dim = tensor_dim - 1 - offset
|
||||
@ -123,11 +123,11 @@ def expand(self: List[int], sizes: List[int]):
|
||||
return out
|
||||
|
||||
|
||||
def expand_one_unused(self: List[int], sizes: List[int], inp0: Any):
|
||||
def expand_one_unused(self: list[int], sizes: list[int], inp0: Any):
|
||||
return expand(self, sizes)
|
||||
|
||||
|
||||
def infer_size_impl(shape: List[int], numel: int) -> List[int]:
|
||||
def infer_size_impl(shape: list[int], numel: int) -> list[int]:
|
||||
newsize = 1
|
||||
infer_dim: Optional[int] = None
|
||||
for dim in range(len(shape)):
|
||||
@ -150,27 +150,27 @@ def infer_size_impl(shape: List[int], numel: int) -> List[int]:
|
||||
return out
|
||||
|
||||
|
||||
def numel(sizes: List[int]):
|
||||
def numel(sizes: list[int]):
|
||||
numel = 1
|
||||
for elem in sizes:
|
||||
numel *= elem
|
||||
return numel
|
||||
|
||||
|
||||
def view(self: List[int], sizes: List[int]):
|
||||
def view(self: list[int], sizes: list[int]):
|
||||
return infer_size_impl(sizes, numel(self))
|
||||
|
||||
|
||||
def view_one_unused(self: List[int], sizes: List[int], *, implicit: bool = False):
|
||||
def view_one_unused(self: list[int], sizes: list[int], *, implicit: bool = False):
|
||||
return view(self, sizes)
|
||||
|
||||
|
||||
def sum_mean_dim(
|
||||
self: List[int], opt_dims: Optional[List[int]], keep_dim: bool, dt: Any
|
||||
self: list[int], opt_dims: Optional[list[int]], keep_dim: bool, dt: Any
|
||||
):
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
if opt_dims is None or len(opt_dims) == 0:
|
||||
dims: List[int] = list(range(len(self)))
|
||||
dims: list[int] = list(range(len(self)))
|
||||
else:
|
||||
dims = opt_dims
|
||||
|
||||
@ -187,7 +187,7 @@ def sum_mean_dim(
|
||||
return out
|
||||
|
||||
|
||||
def max_dim(self: List[int], dim: int, keep_dim: bool):
|
||||
def max_dim(self: list[int], dim: int, keep_dim: bool):
|
||||
out = sum_mean_dim(self, [dim], keep_dim, None)
|
||||
return out, out
|
||||
|
||||
@ -239,7 +239,7 @@ def pooling_output_shape(
|
||||
|
||||
|
||||
def pool2d_shape_check(
|
||||
input: List[int],
|
||||
input: list[int],
|
||||
kH: int,
|
||||
kW: int,
|
||||
dH: int,
|
||||
@ -273,11 +273,11 @@ def pool2d_shape_check(
|
||||
|
||||
|
||||
def max_pool2d(
|
||||
input: List[int],
|
||||
kernel_size: List[int],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input: list[int],
|
||||
kernel_size: list[int],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
ceil_mode: bool,
|
||||
):
|
||||
assert (
|
||||
@ -343,11 +343,11 @@ def max_pool2d(
|
||||
|
||||
|
||||
def max_pool2d_with_indices(
|
||||
input: List[int],
|
||||
kernel_size: List[int],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input: list[int],
|
||||
kernel_size: list[int],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
ceil_mode: bool,
|
||||
):
|
||||
out = max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
|
||||
@ -355,11 +355,11 @@ def max_pool2d_with_indices(
|
||||
|
||||
|
||||
def upsample_nearest2d(
|
||||
input: List[int],
|
||||
output_size: Optional[List[int]],
|
||||
scale_factors: Optional[List[float]],
|
||||
input: list[int],
|
||||
output_size: Optional[list[int]],
|
||||
scale_factors: Optional[list[float]],
|
||||
):
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
out.append(input[0])
|
||||
out.append(input[1])
|
||||
|
||||
@ -385,7 +385,7 @@ def upsample_nearest2d(
|
||||
return out
|
||||
|
||||
|
||||
def mm(self: List[int], mat2: List[int]):
|
||||
def mm(self: list[int], mat2: list[int]):
|
||||
assert len(self) == 2, "self must be a matrix"
|
||||
assert len(mat2) == 2, "mat2 must be a matrix"
|
||||
|
||||
@ -393,37 +393,37 @@ def mm(self: List[int], mat2: List[int]):
|
||||
return [self[0], mat2[1]]
|
||||
|
||||
|
||||
def dot(self: List[int], tensor: List[int]):
|
||||
def dot(self: list[int], tensor: list[int]):
|
||||
assert len(self) == 1 and len(tensor) == 1
|
||||
assert self[0] == tensor[0]
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
return out
|
||||
|
||||
|
||||
def mv(self: List[int], vec: List[int]):
|
||||
def mv(self: list[int], vec: list[int]):
|
||||
assert len(self) == 2 and len(vec) == 1
|
||||
assert self[1] == vec[0]
|
||||
# TODO: return self
|
||||
return [self[0]]
|
||||
|
||||
|
||||
def unsqueeze(li: List[int], dim: int):
|
||||
def unsqueeze(li: list[int], dim: int):
|
||||
dim = maybe_wrap_dim(dim, len(li) + 1)
|
||||
out = _copy(li)
|
||||
out.insert(dim, 1)
|
||||
return out
|
||||
|
||||
|
||||
def squeeze_nodim(li: List[int]):
|
||||
out: List[int] = []
|
||||
def squeeze_nodim(li: list[int]):
|
||||
out: list[int] = []
|
||||
for i in range(len(li)):
|
||||
if li[i] != 1:
|
||||
out.append(li[i])
|
||||
return out
|
||||
|
||||
|
||||
def squeeze(li: List[int], dim: int):
|
||||
out: List[int] = []
|
||||
def squeeze(li: list[int], dim: int):
|
||||
out: list[int] = []
|
||||
wrapped_dim = maybe_wrap_dim(dim, len(li))
|
||||
for i in range(len(li)):
|
||||
if i == wrapped_dim:
|
||||
@ -434,13 +434,13 @@ def squeeze(li: List[int], dim: int):
|
||||
return out
|
||||
|
||||
|
||||
def squeeze_dims(li: List[int], dims: List[int]):
|
||||
def squeeze_dims(li: list[int], dims: list[int]):
|
||||
if len(dims) == 0:
|
||||
return li
|
||||
wrapped_dims = _copy(dims)
|
||||
for i in range(len(dims)):
|
||||
wrapped_dims[i] = maybe_wrap_dim(wrapped_dims[i], len(li))
|
||||
result: List[int] = []
|
||||
result: list[int] = []
|
||||
for i in range(len(li)):
|
||||
if li[i] == 1:
|
||||
if i not in wrapped_dims:
|
||||
@ -450,12 +450,12 @@ def squeeze_dims(li: List[int], dims: List[int]):
|
||||
return result
|
||||
|
||||
|
||||
def index_select(self: List[int], dim: int, index: List[int]):
|
||||
def index_select(self: list[int], dim: int, index: list[int]):
|
||||
dim = maybe_wrap_dim(dim, len(self))
|
||||
numel = multiply_integers(index)
|
||||
assert len(index) <= 1
|
||||
assert dim == 0 or dim < len(self)
|
||||
result_size: List[int] = []
|
||||
result_size: list[int] = []
|
||||
for i in range(len(self)):
|
||||
if dim == i:
|
||||
result_size.append(numel)
|
||||
@ -465,8 +465,8 @@ def index_select(self: List[int], dim: int, index: List[int]):
|
||||
|
||||
|
||||
def embedding(
|
||||
weight: List[int],
|
||||
indices: List[int],
|
||||
weight: list[int],
|
||||
indices: list[int],
|
||||
padding_idx: int = -1,
|
||||
scale_grad_by_freq: bool = False,
|
||||
sparse: bool = False,
|
||||
@ -484,7 +484,7 @@ def max_int():
|
||||
|
||||
|
||||
def slice(
|
||||
self: List[int], dim: int, start: Optional[int], end: Optional[int], step: int
|
||||
self: list[int], dim: int, start: Optional[int], end: Optional[int], step: int
|
||||
):
|
||||
ndim = len(self)
|
||||
assert ndim != 0
|
||||
@ -512,12 +512,12 @@ def slice(
|
||||
return out
|
||||
|
||||
|
||||
def check_cat_no_zero_dim(tensors: List[List[int]]):
|
||||
def check_cat_no_zero_dim(tensors: list[list[int]]):
|
||||
for tensor in tensors:
|
||||
assert len(tensor) > 0
|
||||
|
||||
|
||||
def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
|
||||
def legacy_cat_wrap_dim(dim: int, tensor_sizes: list[list[int]]):
|
||||
out_dim: Optional[int] = None
|
||||
for size in tensor_sizes:
|
||||
if not (len(size) == 1 and size[0] == 0):
|
||||
@ -528,12 +528,12 @@ def legacy_cat_wrap_dim(dim: int, tensor_sizes: List[List[int]]):
|
||||
return out_dim
|
||||
|
||||
|
||||
def should_skip(tensor: List[int]):
|
||||
def should_skip(tensor: list[int]):
|
||||
return numel(tensor) == 0 and len(tensor) == 1
|
||||
|
||||
|
||||
def check_cat_shape_except_dim(
|
||||
first: List[int], second: List[int], dimension: int, index: int
|
||||
first: list[int], second: list[int], dimension: int, index: int
|
||||
):
|
||||
first_dims = len(first)
|
||||
second_dims = len(second)
|
||||
@ -545,11 +545,11 @@ def check_cat_shape_except_dim(
|
||||
), "Sizes of tensors must match except in dimension"
|
||||
|
||||
|
||||
def cat(tensors: List[List[int]], dim: int):
|
||||
def cat(tensors: list[list[int]], dim: int):
|
||||
check_cat_no_zero_dim(tensors)
|
||||
dim = legacy_cat_wrap_dim(dim, tensors)
|
||||
assert len(tensors) > 0
|
||||
not_skipped_tensor: Optional[List[int]] = None
|
||||
not_skipped_tensor: Optional[list[int]] = None
|
||||
for tensor in tensors:
|
||||
if not should_skip(tensor):
|
||||
not_skipped_tensor = tensor
|
||||
@ -569,15 +569,15 @@ def cat(tensors: List[List[int]], dim: int):
|
||||
return result_size
|
||||
|
||||
|
||||
def stack(tensors: List[List[int]], dim: int):
|
||||
unsqueezed_tensors: List[List[int]] = []
|
||||
def stack(tensors: list[list[int]], dim: int):
|
||||
unsqueezed_tensors: list[list[int]] = []
|
||||
for tensor in tensors:
|
||||
unsqueezed = unsqueeze(tensor, dim)
|
||||
unsqueezed_tensors.append(unsqueezed)
|
||||
return cat(unsqueezed_tensors, dim)
|
||||
|
||||
|
||||
def select(self: List[int], dim: int, index: int):
|
||||
def select(self: list[int], dim: int, index: int):
|
||||
ndim = len(self)
|
||||
assert ndim != 0
|
||||
dim = maybe_wrap_dim(dim, ndim)
|
||||
@ -585,14 +585,14 @@ def select(self: List[int], dim: int, index: int):
|
||||
assert not (index < -size or index >= size)
|
||||
if index < 0:
|
||||
index += size
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
for i in range(ndim):
|
||||
if i != dim:
|
||||
out.append(self[i])
|
||||
return out
|
||||
|
||||
|
||||
def matmul(tensor1: List[int], tensor2: List[int]):
|
||||
def matmul(tensor1: list[int], tensor2: list[int]):
|
||||
dim_tensor1 = len(tensor1)
|
||||
dim_tensor2 = len(tensor2)
|
||||
if dim_tensor1 == 1 and dim_tensor2 == 1:
|
||||
@ -607,12 +607,12 @@ def matmul(tensor1: List[int], tensor2: List[int]):
|
||||
# We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
|
||||
# we track m1 vs m2 separately even though they must match for nicer error messages
|
||||
n = tensor1[-2] if dim_tensor1 > 1 else 1
|
||||
batch_tensor1: List[int] = []
|
||||
batch_tensor1: list[int] = []
|
||||
# TODO: handling of slice
|
||||
for i in range(dim_tensor1 - 2):
|
||||
batch_tensor1.append(tensor1[i])
|
||||
p = tensor2[-1]
|
||||
batch_tensor2: List[int] = []
|
||||
batch_tensor2: list[int] = []
|
||||
# TODO: handling of slice
|
||||
for i in range(dim_tensor2 - 2):
|
||||
batch_tensor2.append(tensor2[i])
|
||||
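The comment above states the batched matmul shape rule that these list-based shape functions mirror; a concrete instance checked against the real op (shapes chosen arbitrarily):

import torch

# (b=2, n=3, m=4) @ (m=4, p=5): batch dims broadcast, giving (2, 3, 5),
# the same result the shape function derives from [2, 3, 4] and [4, 5].
out = torch.matmul(torch.empty(2, 3, 4), torch.empty(4, 5))
assert out.shape == (2, 3, 5)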
@ -633,11 +633,11 @@ def matmul(tensor1: List[int], tensor2: List[int]):
|
||||
assert False, "both arguments to matmul need to be at least 1D"
|
||||
|
||||
|
||||
def t(self: List[int]):
|
||||
def t(self: list[int]):
|
||||
assert len(self) <= 2
|
||||
self_len = len(self)
|
||||
if self_len == 0:
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
return out
|
||||
elif self_len == 1:
|
||||
return [self[0]]
|
||||
@ -645,13 +645,13 @@ def t(self: List[int]):
|
||||
return [self[1], self[0]]
|
||||
|
||||
|
||||
def transpose(self: List[int], dim0: int, dim1: int):
|
||||
def transpose(self: list[int], dim0: int, dim1: int):
|
||||
ndims = len(self)
|
||||
dim0 = maybe_wrap_dim(dim0, ndims)
|
||||
dim1 = maybe_wrap_dim(dim1, ndims)
|
||||
if dim0 == dim1:
|
||||
return _copy(self)
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
for i in range(ndims):
|
||||
if i == dim0:
|
||||
out.append(self[dim1])
|
||||
@ -662,18 +662,18 @@ def transpose(self: List[int], dim0: int, dim1: int):
|
||||
return out
|
||||
|
||||
|
||||
def linear(input: List[int], weight: List[int], bias: Optional[List[int]]):
|
||||
def linear(input: list[int], weight: list[int], bias: Optional[list[int]]):
|
||||
out = matmul(input, t(weight))
|
||||
if bias is not None:
|
||||
assert broadcast(bias, out) == out
|
||||
return out
|
||||
|
||||
|
||||
def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
|
||||
def addmm(self: list[int], mat1: list[int], mat2: list[int], beta: Any, alpha: Any):
|
||||
return broadcast(self, mm(mat1, mat2))
|
||||
|
||||
|
||||
def check_non_negative(array: List[int]) -> bool:
|
||||
def check_non_negative(array: list[int]) -> bool:
|
||||
# TODO: look into rewriting with early return and getting loop unrolling to fire
|
||||
non_negative = False
|
||||
for val in array:
|
||||
@ -683,12 +683,12 @@ def check_non_negative(array: List[int]) -> bool:
|
||||
|
||||
|
||||
def check_shape_forward(
|
||||
input: List[int],
|
||||
weight_sizes: List[int],
|
||||
bias: Optional[List[int]],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input: list[int],
|
||||
weight_sizes: list[int],
|
||||
bias: Optional[list[int]],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
groups: int,
|
||||
):
|
||||
k = len(input)
|
||||
@ -714,12 +714,12 @@ def check_shape_forward(
|
||||
|
||||
|
||||
def conv_output_size(
|
||||
input_size: List[int],
|
||||
weight_size: List[int],
|
||||
bias: Optional[List[int]],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input_size: list[int],
|
||||
weight_size: list[int],
|
||||
bias: Optional[list[int]],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
groups: int,
|
||||
):
|
||||
check_shape_forward(
|
||||
@ -728,7 +728,7 @@ def conv_output_size(
|
||||
|
||||
has_dilation = len(dilation) > 0
|
||||
dim = len(input_size)
|
||||
output_size: List[int] = []
|
||||
output_size: list[int] = []
|
||||
input_batch_size_dim = 0
|
||||
weight_output_channels_dim = 0
|
||||
output_size.append(input_size[input_batch_size_dim])
|
||||
@ -744,12 +744,12 @@ def conv_output_size(
|
||||
|
||||
|
||||
def conv1d(
|
||||
input: List[int],
|
||||
weight: List[int],
|
||||
bias: Optional[List[int]],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input: list[int],
|
||||
weight: list[int],
|
||||
bias: Optional[list[int]],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
groups: int,
|
||||
):
|
||||
assert len(weight) == 3
|
||||
@ -758,12 +758,12 @@ def conv1d(
|
||||
|
||||
|
||||
def conv2d(
|
||||
input: List[int],
|
||||
weight: List[int],
|
||||
bias: Optional[List[int]],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input: list[int],
|
||||
weight: list[int],
|
||||
bias: Optional[list[int]],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
groups: int,
|
||||
):
|
||||
assert len(weight) == 4
|
||||
@ -772,25 +772,25 @@ def conv2d(
|
||||
|
||||
|
||||
def conv_backwards(
|
||||
grad_output: List[int],
|
||||
input: List[int],
|
||||
weight: List[int],
|
||||
biases: Optional[List[int]],
|
||||
grad_output: list[int],
|
||||
input: list[int],
|
||||
weight: list[int],
|
||||
biases: Optional[list[int]],
|
||||
):
|
||||
# Bias gradient is always generated regardless of whether biases is supplied
|
||||
return _copy(input), _copy(weight), [grad_output[1]]
|
||||
|
||||
|
||||
def conv_transpose2d_input(
|
||||
input: List[int],
|
||||
weight: List[int],
|
||||
bias: Optional[List[int]] = None,
|
||||
stride: Optional[List[int]] = None,
|
||||
padding: Optional[List[int]] = None,
|
||||
output_padding: Optional[List[int]] = None,
|
||||
input: list[int],
|
||||
weight: list[int],
|
||||
bias: Optional[list[int]] = None,
|
||||
stride: Optional[list[int]] = None,
|
||||
padding: Optional[list[int]] = None,
|
||||
output_padding: Optional[list[int]] = None,
|
||||
groups: int = 1,
|
||||
dilation: Optional[List[int]] = None,
|
||||
) -> List[int]:
|
||||
dilation: Optional[list[int]] = None,
|
||||
) -> list[int]:
|
||||
if stride is None:
|
||||
stride = [1, 1]
|
||||
if padding is None:
|
||||
@ -801,7 +801,7 @@ def conv_transpose2d_input(
|
||||
dilation = [1, 1]
|
||||
has_dilation = len(dilation) > 0
|
||||
dim = len(input)
|
||||
output_size: List[int] = []
|
||||
output_size: list[int] = []
|
||||
input_batch_size_dim = 0
|
||||
weight_output_channels_dim = 1
|
||||
output_size.append(input[input_batch_size_dim])
|
||||
@ -821,20 +821,20 @@ def conv_transpose2d_input(
|
||||
|
||||
|
||||
def conv_forwards(
|
||||
input: List[int],
|
||||
weight: List[int],
|
||||
bias: Optional[List[int]],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input: list[int],
|
||||
weight: list[int],
|
||||
bias: Optional[list[int]],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
transposed: bool,
|
||||
output_padding: List[int],
|
||||
output_padding: list[int],
|
||||
groups: int,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
has_dilation = len(dilation) > 0
|
||||
has_output_padding = len(output_padding) > 0
|
||||
dim = len(input)
|
||||
output_size: List[int] = []
|
||||
output_size: list[int] = []
|
||||
input_batch_size_dim = 0
|
||||
weight_output_channels_dim = 1 if transposed else 0
|
||||
output_size.append(input[input_batch_size_dim])
|
||||
@ -864,20 +864,20 @@ def conv_forwards(
|
||||
|
||||
|
||||
def _conv_forwards(
|
||||
input: List[int],
|
||||
weight: List[int],
|
||||
bias: Optional[List[int]],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input: list[int],
|
||||
weight: list[int],
|
||||
bias: Optional[list[int]],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
transposed: bool,
|
||||
output_padding: List[int],
|
||||
output_padding: list[int],
|
||||
groups: int,
|
||||
benchmark: bool,
|
||||
deterministic: bool,
|
||||
cudnn_enabled: bool,
|
||||
allow_tf32: bool,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
return conv_forwards(
|
||||
input,
|
||||
weight,
|
||||
@ -892,29 +892,29 @@ def _conv_forwards(
|
||||
|
||||
|
||||
def batch_norm(
|
||||
input: List[int],
|
||||
weight: Optional[List[int]],
|
||||
bias: Optional[List[int]],
|
||||
running_mean: Optional[List[int]],
|
||||
running_var: Optional[List[int]],
|
||||
input: list[int],
|
||||
weight: Optional[list[int]],
|
||||
bias: Optional[list[int]],
|
||||
running_mean: Optional[list[int]],
|
||||
running_var: Optional[list[int]],
|
||||
training: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
cudnn_enabled: bool,
|
||||
):
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
for elem in input:
|
||||
out.append(elem)
|
||||
return out
|
||||
|
||||
|
||||
def conv3d(
|
||||
input: List[int],
|
||||
weight: List[int],
|
||||
bias: Optional[List[int]],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
input: list[int],
|
||||
weight: list[int],
|
||||
bias: Optional[list[int]],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
dilation: list[int],
|
||||
groups: int,
|
||||
):
|
||||
assert len(weight) == 5
|
||||
@ -935,11 +935,11 @@ def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
|
||||
|
||||
|
||||
def zero_dim_tensor(input: Any):
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
return out
|
||||
|
||||
|
||||
def multiply_integers(li: List[int]):
|
||||
def multiply_integers(li: list[int]):
|
||||
out = 1
|
||||
for elem in li:
|
||||
out = out * elem
|
||||
@ -970,11 +970,11 @@ def arange_start_step(
|
||||
return [int(math.ceil((end - start) / step))]
|
||||
|
||||
|
||||
def permute(input: List[int], dims: List[int]):
|
||||
def permute(input: list[int], dims: list[int]):
|
||||
assert len(input) == len(dims)
|
||||
ndim = len(dims)
|
||||
seen_dims: List[int] = []
|
||||
newSizes: List[int] = []
|
||||
seen_dims: list[int] = []
|
||||
newSizes: list[int] = []
|
||||
for i in range(ndim):
|
||||
dim = maybe_wrap_dim(dims[i], ndim)
|
||||
seen_dims.append(dim)
|
||||
@ -985,12 +985,12 @@ def permute(input: List[int], dims: List[int]):
|
||||
return newSizes
|
||||
|
||||
|
||||
def movedim(self: List[int], source: List[int], destination: List[int]) -> List[int]:
|
||||
def movedim(self: list[int], source: list[int], destination: list[int]) -> list[int]:
|
||||
self_dim = len(self)
|
||||
if self_dim <= 1:
|
||||
return self
|
||||
normalized_src: List[int] = []
|
||||
normalized_dst: List[int] = []
|
||||
normalized_src: list[int] = []
|
||||
normalized_dst: list[int] = []
|
||||
for i in range(len(source)):
|
||||
normalized_src.append(maybe_wrap_dim(source[i], self_dim))
|
||||
normalized_dst.append(maybe_wrap_dim(destination[i], self_dim))
|
||||
@ -1003,8 +1003,8 @@ def movedim(self: List[int], source: List[int], destination: List[int]) -> List[
|
||||
src_dims[normalized_src[i]] = -1
|
||||
dst_dims[normalized_dst[i]] = -1
|
||||
|
||||
source_dims: List[int] = []
|
||||
destination_dims: List[int] = []
|
||||
source_dims: list[int] = []
|
||||
destination_dims: list[int] = []
|
||||
for ele in src_dims:
|
||||
if ele != -1:
|
||||
source_dims.append(ele)
|
||||
@ -1018,7 +1018,7 @@ def movedim(self: List[int], source: List[int], destination: List[int]) -> List[
|
||||
return permute(self, order)
|
||||
|
||||
|
||||
def flatten(input: List[int], start_dim: int, end_dim: int):
|
||||
def flatten(input: list[int], start_dim: int, end_dim: int):
|
||||
start_dim = maybe_wrap_dim(start_dim, len(input))
|
||||
end_dim = maybe_wrap_dim(end_dim, len(input))
|
||||
assert start_dim <= end_dim
|
||||
@ -1026,7 +1026,7 @@ def flatten(input: List[int], start_dim: int, end_dim: int):
|
||||
return [1]
|
||||
if start_dim == end_dim:
|
||||
# TODO: return self
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
for elem in input:
|
||||
out.append(elem)
|
||||
return out
|
||||
@ -1035,7 +1035,7 @@ def flatten(input: List[int], start_dim: int, end_dim: int):
|
||||
slice_numel *= input[i]
|
||||
# TODO: use slicing when slice optimization has landed
|
||||
# slice_numel = multiply_integers(input[start_dim:end_dim - start_dim + 1])
|
||||
shape: List[int] = []
|
||||
shape: list[int] = []
|
||||
for i in range(start_dim):
|
||||
shape.append(input[i])
|
||||
shape.append(slice_numel)
|
||||
@ -1044,17 +1044,17 @@ def flatten(input: List[int], start_dim: int, end_dim: int):
|
||||
return shape
|
||||
|
||||
|
||||
def nonzero_lower_bound(input: List[int]):
|
||||
def nonzero_lower_bound(input: list[int]):
|
||||
return [0, len(input)]
|
||||
|
||||
|
||||
def nonzero_upper_bound(input: List[int]):
|
||||
def nonzero_upper_bound(input: list[int]):
|
||||
return [numel(input), len(input)]
|
||||
|
||||
|
||||
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
||||
def _reduce_along_dim(self: list[int], dim: int, keepdim: bool):
|
||||
dim = maybe_wrap_dim(dim, len(self))
|
||||
out: List[int] = []
|
||||
out: list[int] = []
|
||||
for i, self_dim in enumerate(self):
|
||||
if i == dim:
|
||||
if keepdim:
|
||||
@ -1065,14 +1065,14 @@ def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
||||
|
||||
|
||||
def argmax(
|
||||
self: List[int], dim: Optional[int] = None, keepdim: bool = False
|
||||
) -> List[int]:
|
||||
self: list[int], dim: Optional[int] = None, keepdim: bool = False
|
||||
) -> list[int]:
|
||||
if dim is None:
|
||||
return []
|
||||
return _reduce_along_dim(self, dim, keepdim)
|
||||
|
||||
|
||||
def bmm(self: List[int], mat2: List[int]) -> List[int]:
|
||||
def bmm(self: list[int], mat2: list[int]) -> list[int]:
|
||||
assert len(self) == 3, "bmm only supports 3D tensors"
|
||||
assert len(mat2) == 3, "bmm only supports 3D tensors"
|
||||
assert self[0] == mat2[0], "mismatching batch dimension"
|
||||
@ -1080,13 +1080,13 @@ def bmm(self: List[int], mat2: List[int]) -> List[int]:
|
||||
return [self[0], self[1], mat2[2]]
|
||||
|
||||
|
||||
def _shape_as_tensor(self: List[int]) -> List[int]:
|
||||
def _shape_as_tensor(self: list[int]) -> list[int]:
|
||||
return [len(self)]
|
||||
|
||||
|
||||
def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
|
||||
def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]:
|
||||
if len(self) == 0:
|
||||
result: List[int] = []
|
||||
result: list[int] = []
|
||||
else:
|
||||
assert (
|
||||
k <= self[dim]
|
||||
@ -1097,8 +1097,8 @@ def topk(self: List[int], k: int, dim: int = -1) -> Tuple[List[int], List[int]]:
|
||||
|
||||
|
||||
def nll_loss_forward(
|
||||
self: List[int], target: List[int], weight: Optional[List[int]], reduction: int
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
self: list[int], target: list[int], weight: Optional[list[int]], reduction: int
|
||||
) -> tuple[list[int], list[int]]:
|
||||
# This is taken shamelessly from the meta function in LossNLL.cpp
|
||||
self_dim = len(self)
|
||||
target_dim = len(target)
|
||||
@ -1107,7 +1107,7 @@ def nll_loss_forward(
|
||||
no_batch_dim = self_dim == 1 and target_dim == 0
|
||||
assert no_batch_dim or (self[0] == target[0])
|
||||
n_classes = self[-1]
|
||||
scalar_shape: List[int] = []
|
||||
scalar_shape: list[int] = []
|
||||
assert weight is None or (len(weight) == 1 and weight[0] == n_classes)
|
||||
if reduction == 0 and self_dim == 2:
|
||||
reduction_shape = [self[0]]
|
||||
@ -1117,9 +1117,9 @@ def nll_loss_forward(
|
||||
|
||||
|
||||
def native_layer_norm(
|
||||
input: List[int], normalized_shape: List[int]
|
||||
) -> Tuple[List[int], List[int], List[int]]:
|
||||
reduction_shape: List[int] = []
|
||||
input: list[int], normalized_shape: list[int]
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
reduction_shape: list[int] = []
|
||||
num_unreduced_dimensions = len(input) - len(normalized_shape)
|
||||
assert num_unreduced_dimensions >= 0
|
||||
for i in range(num_unreduced_dimensions):
|
||||
@ -1130,13 +1130,13 @@ def native_layer_norm(
|
||||
|
||||
|
||||
def native_batch_norm(
|
||||
input: List[int],
|
||||
weight: Optional[List[int]],
|
||||
bias: Optional[List[int]],
|
||||
running_mean: Optional[List[int]],
|
||||
running_var: Optional[List[int]],
|
||||
input: list[int],
|
||||
weight: Optional[list[int]],
|
||||
bias: Optional[list[int]],
|
||||
running_mean: Optional[list[int]],
|
||||
running_var: Optional[list[int]],
|
||||
training: bool,
|
||||
) -> Tuple[List[int], List[int], List[int]]:
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
if training:
|
||||
_size = [input[1]]
|
||||
else:
|
||||
@ -1145,24 +1145,24 @@ def native_batch_norm(
|
||||
|
||||
|
||||
def _batch_norm_with_update(
|
||||
input: List[int],
|
||||
weight: Optional[List[int]],
|
||||
bias: Optional[List[int]],
|
||||
running_mean: Optional[List[int]],
|
||||
running_var: Optional[List[int]],
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
input: list[int],
|
||||
weight: Optional[list[int]],
|
||||
bias: Optional[list[int]],
|
||||
running_mean: Optional[list[int]],
|
||||
running_var: Optional[list[int]],
|
||||
) -> tuple[list[int], list[int], list[int], list[int]]:
|
||||
_size = [input[1]]
|
||||
return _copy(input), _size, _size, [0]
|
||||
|
||||
|
||||
def cross_entropy_loss(
|
||||
self: List[int],
|
||||
target: List[int],
|
||||
weight: Optional[List[int]] = None,
|
||||
self: list[int],
|
||||
target: list[int],
|
||||
weight: Optional[list[int]] = None,
|
||||
reduction: int = 1,
|
||||
ignore_index: int = -100,
|
||||
label_smoothing: float = 0.0,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
result_shape = nll_loss_forward(self, target, weight, reduction)[0]
|
||||
return result_shape
|
||||
|
||||
@ -1187,9 +1187,9 @@ def index_Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[in
|
||||
"""
|
||||
|
||||
ScriptFn = torch._C.ScriptFunction
|
||||
shape_compute_graph_mapping: Dict[str, ScriptFn] = {}
|
||||
bounded_compute_graph_mapping: Dict[str, Tuple[ScriptFn, ScriptFn]] = {}
|
||||
script_func_map: Dict[Callable, ScriptFn] = {}
|
||||
shape_compute_graph_mapping: dict[str, ScriptFn] = {}
|
||||
bounded_compute_graph_mapping: dict[str, tuple[ScriptFn, ScriptFn]] = {}
|
||||
script_func_map: dict[Callable, ScriptFn] = {}
|
||||
|
||||
|
||||
def process_func(func: Callable):
|
||||
|
@ -6,9 +6,10 @@ This module stores various pieces of Python-global state relating to the JIT.
|
||||
This is not intended to be imported directly; please use the exposed
|
||||
functionalities in `torch.jit`.
|
||||
"""
|
||||
|
||||
import os
|
||||
import weakref
|
||||
from typing import Any, Dict, Type
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@ -62,8 +63,8 @@ _python_cu = torch._C.CompilationUnit()
|
||||
|
||||
|
||||
# python class => ScriptClass mapping
|
||||
_script_classes: Dict[Type[Any], Type[Any]] = {}
|
||||
_name_to_pyclass: Dict[str, Type[Any]] = {}
|
||||
_script_classes: dict[type[Any], type[Any]] = {}
|
||||
_name_to_pyclass: dict[str, type[Any]] = {}
|
||||
|
||||
|
||||
def _add_script_class(python_class, script_class):
|
||||
|
@ -17,7 +17,7 @@ import os
|
||||
import re
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
@ -70,7 +70,7 @@ def _unique_state_dict(module, keep_vars=False):
|
||||
# as values, and deduplicate the params using Parameters and Buffers
|
||||
state_dict = module.state_dict(keep_vars=True)
|
||||
filtered_dict = type(state_dict)()
|
||||
seen_ids: Set[int] = set()
|
||||
seen_ids: set[int] = set()
|
||||
for k, v in state_dict.items():
|
||||
if id(v) in seen_ids:
|
||||
continue
|
||||
@ -112,7 +112,7 @@ class ONNXTracedModule(torch.nn.Module):
|
||||
outs = []
|
||||
|
||||
def wrapper(*args):
|
||||
in_args: List[torch.Tensor] = []
|
||||
in_args: list[torch.Tensor] = []
|
||||
for i in range(len(in_vars)):
|
||||
if not isinstance(args[i], torch.Tensor):
|
||||
raise RuntimeError("Expected Tensor argument")
|
||||
@ -960,6 +960,7 @@ def trace(
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -968,6 +969,7 @@ def trace(
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
n = Net()
|
||||
example_weight = torch.rand(1, 1, 3, 3)
|
||||
example_forward_input = torch.rand(1, 1, 3, 3)
|
||||
@ -1112,7 +1114,7 @@ def trace(
|
||||
return traced_func
|
||||
|
||||
|
||||
_trace_module_map: Optional[Dict[Any, Any]] = None
|
||||
_trace_module_map: Optional[dict[Any, Any]] = None
|
||||
|
||||
|
||||
def trace_module(
|
||||
@ -1176,6 +1178,7 @@ def trace_module(
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -1202,7 +1205,7 @@ def trace_module(
|
||||
|
||||
# Trace specific methods on a module (specified in `inputs`), constructs
|
||||
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
|
||||
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
|
||||
inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight}
|
||||
module = torch.jit.trace_module(n, inputs)
|
||||
|
||||
"""
|
||||
@ -1226,7 +1229,7 @@ def trace_module(
|
||||
|
||||
old_module_map = torch.jit._trace._trace_module_map
|
||||
try:
|
||||
trace_module_map: Dict[Any, Any] = {}
|
||||
trace_module_map: dict[Any, Any] = {}
|
||||
|
||||
def register_submods(mod, prefix):
|
||||
for name, child in mod.named_children():
|
||||
|
@ -8,7 +8,6 @@ import re
|
||||
import typing
|
||||
import warnings
|
||||
from textwrap import dedent
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from torch._C import (
|
||||
@ -350,7 +349,7 @@ def try_real_annotations(fn, loc):
|
||||
|
||||
# Finds common type for enum values belonging to an Enum class. If not all
|
||||
# values have the same type, AnyType is returned.
|
||||
def get_enum_value_type(e: Type[enum.Enum], loc):
|
||||
def get_enum_value_type(e: type[enum.Enum], loc):
|
||||
enum_values: List[enum.Enum] = list(e)
|
||||
if not enum_values:
|
||||
raise ValueError(f"No enum values defined for: '{e.__class__}'")
|
||||
|
@ -4,7 +4,6 @@ import dataclasses
|
||||
import inspect
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from textwrap import dedent
|
||||
from typing import List, Tuple # noqa: F401
|
||||
@ -1128,10 +1127,7 @@ class ExprBuilder(Builder):
|
||||
return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
|
||||
elif sub_type is ast.ExtSlice:
|
||||
return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
|
||||
elif sys.version_info >= (
|
||||
3,
|
||||
9,
|
||||
): # In Python3.9 array indicies are not wrapped in ast.Index
|
||||
else: # In Python 3.9 array indices are not wrapped in ast.Index
|
||||
if sub_type is ast.Tuple:
|
||||
# N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
|
||||
indices = []
|
||||
@ -1150,8 +1146,6 @@ class ExprBuilder(Builder):
|
||||
indices.append(tup)
|
||||
return Subscript(base, indices)
|
||||
return Subscript(base, [build_expr(ctx, expr.slice)])
|
||||
else: # Ellipsis (can only happen in Python 2)
|
||||
raise NotSupportedError(base.range(), "ellipsis is not supported")
|
||||
|
||||
@staticmethod
|
||||
def build_List(ctx, expr):
|
||||
|
@ -1,5 +1,4 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import List
|
||||
|
||||
from torch._C import _compile_graph_to_code_table, _generate_upgraders_graph
|
||||
|
||||
@ -20,7 +19,7 @@ def format_bytecode(table):
|
||||
return formatted_table
|
||||
|
||||
|
||||
def generate_upgraders_bytecode() -> List:
|
||||
def generate_upgraders_bytecode() -> list:
|
||||
yaml_content = []
|
||||
upgraders_graph_map = _generate_upgraders_graph()
|
||||
for upgrader_name, upgrader_graph in upgraders_graph_map.items():
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from textwrap import dedent
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import torch.jit
|
||||
|
||||
@ -37,7 +37,7 @@ def _gen_unsupported_methods_properties():
|
||||
sorted_tensor_attrs = sorted(tensor_attrs, key=lambda x: x.lower())
|
||||
for attr in sorted_tensor_attrs:
|
||||
funcs_str = funcs_template.format(op=attr)
|
||||
scope: Dict[str, Any] = {}
|
||||
scope: dict[str, Any] = {}
|
||||
execWrapper(funcs_str, globals(), scope)
|
||||
try:
|
||||
torch.jit.CompilationUnit(funcs_str)
|
||||
|