[inductor] Finish typing common.py (#146225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146225
Approved by: https://github.com/Skylion007
commit 3a67c0e48d (parent dca5cc0255)
committed by PyTorch MergeBot
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-defs
 from __future__ import annotations
 
 import contextlib
@@ -32,7 +31,6 @@ import sympy
 
 import torch
 import torch.fx
-from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.utils import _pytree as pytree
 from torch.utils._ordered_set import OrderedSet
@@ -42,6 +40,7 @@ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
 
 from .. import config, metrics
+from ..dtype_propagation import DtypePropagationOpsHandler
 from ..utils import (
     boolean_ops,
     DeferredLineBase,
@@ -57,7 +56,9 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
 
 
 if TYPE_CHECKING:
-    from ..ir import FixedLayout, IRNode
+    from collections.abc import Sequence
+
+    from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
     from ..loop_body import LoopBody
     from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
     from .wrapper import PythonWrapperCodegen
@@ -121,7 +122,7 @@ class WorkspaceArg:
     dtype: torch.dtype = torch.uint8
 
     @staticmethod
-    def unique_name(prefix="workspace_") -> str:
+    def unique_name(prefix: str = "workspace_") -> str:
         return f"{prefix}{next(V.graph.workspace_id)}"
 
     @staticmethod
@@ -456,7 +457,9 @@ def init_backend_registration() -> None:
 
 
 def index_prevent_reordering(
-    index: list[sympy.Expr], index_vars, sizes
+    index: Sequence[sympy.Expr],
+    index_vars: Sequence[sympy.Expr],
+    sizes: Sequence[sympy.Expr],
 ) -> list[sympy.Expr]:
     from ..ir import FlexibleLayout
 
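Note: the switch from `list[sympy.Expr]` to `Sequence[sympy.Expr]` is the standard typing idiom for read-only parameters — `Sequence` is covariant and accepts lists and tuples alike, while the concrete `list` return type still signals that a fresh list is allocated. A minimal sketch of the same pattern, with hypothetical names:

    from collections.abc import Sequence

    def scale_all(values: Sequence[float], factor: float) -> list[float]:
        # Accepts any read-only sequence (list, tuple, ...) and
        # returns a newly allocated list.
        return [v * factor for v in values]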
@@ -631,11 +634,11 @@ class DataTypePropagation:
         return self.propagate_graph(self.graphs["root"])
 
     @classmethod
-    def propagate_loopbody(cls, body) -> Optional[torch.dtype]:
+    def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]:
         return cls(body).propagate()
 
     @classmethod
-    def propagate_scheduler_node(cls, node) -> Optional[torch.dtype]:
+    def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
         from ..loop_body import LoopBody
         from ..scheduler import SchedulerNode
 
@@ -673,7 +676,7 @@ class OpDecompositions:
         return ops.mul(x, x)
 
     @staticmethod
-    def erfc(x: OpVarT):
+    def erfc(x: OpVarT) -> OpVarT:
         return ops.sub(ops.constant(1, torch.float32), ops.erf(x))
 
     @staticmethod
@@ -759,7 +762,7 @@ def _all_in_parens(string: str) -> bool:
 
 
 class OpOverrides(OpDecompositions):
-    def __init__(self, parent):
+    def __init__(self, parent: OpsHandler[OpVarT]) -> None:
         super().__init__()
         self._parent = parent
 
@@ -1109,12 +1112,12 @@ def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[OpVarT]:
 class DeferredLine(DeferredLineBase):
     """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
 
-    def __init__(self, name, line):
+    def __init__(self, name: str, line: str):
         super().__init__(line)
         self.name = name
         assert not isinstance(line, DeferredLineBase)
 
-    def __call__(self):
+    def __call__(self) -> Optional[str]:
         if all(
             self.name not in x
             for x in (
@@ -1127,14 +1130,14 @@ class DeferredLine(DeferredLineBase):
             return self.line
         return None
 
-    def _new_line(self, line):
+    def _new_line(self, line: str) -> DeferredLine:
         return DeferredLine(self.name, line)
 
 
 class BracesBuffer(IndentedBuffer):
-    def indent(self, offset=1) -> contextlib.AbstractContextManager[None]:
+    def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
         @contextlib.contextmanager
-        def ctx():
+        def ctx() -> Iterator[None]:
             for _ in range(offset):
                 self.writeline("{")
             self._indent += 1
@@ -1163,7 +1166,7 @@ class ArgName:
     # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
     is_constexpr: bool = False
 
-    def full_name(self):
+    def full_name(self) -> str:
        return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
 
 
@@ -1509,8 +1512,8 @@ class CSEVariable:
     def __hash__(self) -> int:
         return hash(self.name)
 
-    def __eq__(self, other) -> bool:
-        return type(other) == type(self) and other.name == self.name
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, CSEVariable) and other.name == self.name
 
     def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
         pass
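Note: annotating `other: object` matches the signature of `object.__eq__`, which mypy requires for a clean override; the `isinstance` check then narrows `other` so `.name` type-checks, and it keeps equality consistent with the name-based `__hash__` above. (It also changes behavior slightly: subclasses can now compare equal, where `type(other) == type(self)` ruled them out.) A standalone sketch of the idiom, with a hypothetical class name:

    class Named:
        def __init__(self, name: str) -> None:
            self.name = name

        def __hash__(self) -> int:
            return hash(self.name)

        def __eq__(self, other: object) -> bool:
            # isinstance() narrows `other` from object, so `.name` type-checks.
            return isinstance(other, Named) and other.name == self.name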
@@ -1557,7 +1560,7 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
         self.invalidated_stores: OrderedSet[str] = OrderedSet()
         self.varname_map: dict[str, CSEVariableType] = varname_map or {}
 
-    def invalidate(self, keep_vars: OrderedSet[CSEVariable]):
+    def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
         for name, tmp in [*self.store_cache.items()]:
             if tmp not in keep_vars:
                 del self.store_cache[name]
@@ -1578,6 +1581,14 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
             reduction_cache=self.reduction_cache,
         )
 
+    def scoped_copy(self) -> typing.Self:
+        """Return a copy of using ScopedDict so changes to *_cache aren't visible in self"""
+        new_cse = self.clone()
+        new_cse._cache = ScopedDict(self._cache)
+        new_cse.reduction_cache = ScopedDict(self.reduction_cache)
+        new_cse.store_cache = ScopedDict(self.store_cache)
+        return new_cse
+
     def augment_key(self, cache_key: str) -> AugmentedKeyT:
         "Override this method to augment cache key with backend specifics"
         return cast(AugmentedKeyT, cache_key)
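Note: `scoped_copy` depends on `ScopedDict` semantics — reads fall through to the wrapped parent mapping, writes land only in the copy — so cache entries created inside a scope never leak back into `self`. A minimal sketch of those semantics (this `ScopedDict` is illustrative, not the inductor implementation):

    from typing import Generic, TypeVar

    K = TypeVar("K")
    V = TypeVar("V")

    class ScopedDict(Generic[K, V]):
        def __init__(self, parent: dict[K, V]) -> None:
            self.parent = parent          # shared; never mutated through us
            self.local: dict[K, V] = {}   # scope-private writes

        def __getitem__(self, key: K) -> V:
            if key in self.local:
                return self.local[key]
            return self.parent[key]       # fall through to the parent scope

        def __setitem__(self, key: K, value: V) -> None:
            self.local[key] = value       # the parent never sees this

        def __contains__(self, key: object) -> bool:
            return key in self.local or key in self.parent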
@@ -1730,7 +1741,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         self.kernel_name: Optional[str] = None
 
     @contextlib.contextmanager
-    def set_current_node(self, node):
+    def set_current_node(self, node: SchedulerNode) -> Iterator[None]:
         prior = self.current_node
         self.current_node = node
         self.node_to_bounds = node._body.bounds().get_bounds()
@@ -1740,16 +1751,16 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             self.current_node = prior
 
     @contextlib.contextmanager
-    def swap_buffers(self, lb, cb=None, sb=None):
-        def scope_cse(cse: CSE[CSEVariableType, Any]):
-            new_cse = cse.clone()
-            new_cse._cache = ScopedDict(cse._cache)
-            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
-            new_cse.store_cache = ScopedDict(cse.store_cache)
-            return new_cse
-
+    def swap_buffers(
+        self,
+        lb: IndentedBuffer,
+        cb: Optional[IndentedBuffer] = None,
+        sb: Optional[IndentedBuffer] = None,
+    ) -> Iterator[None]:
         if cb is None:
             cb = lb
         if disallow_stores := sb is None:
             sb = IndentedBuffer()
         loads = self.loads
         compute = self.compute
         stores = self.stores
@@ -1757,7 +1768,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         self.loads = lb
         self.compute = cb
         self.stores = sb
-        self.cse = scope_cse(cse)
+        self.cse = cse.scoped_copy()
         try:
             yield
         finally:
@@ -1765,11 +1776,13 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             self.compute = compute
             self.stores = stores
             self.cse = cse
+            if disallow_stores:
+                assert not sb, "unexpected store inside swap_buffers"
 
     def load(self, name: str, index: sympy.Expr) -> CSEVariable:
         raise NotImplementedError
 
-    def indirect_load(self, name: str, index: sympy.Expr):
+    def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable:
         """A load the depends on an index we have read"""
         prior = self.loads
         try:
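Note: a generator decorated with `@contextlib.contextmanager` is annotated as returning `Iterator[None]`; the decorator converts it into a context manager. A self-contained sketch of the swap-and-restore pattern `swap_buffers` uses, with a hypothetical buffer field:

    import contextlib
    from collections.abc import Iterator

    class Codegen:
        def __init__(self) -> None:
            self.loads: list[str] = []

        @contextlib.contextmanager
        def swap_loads(self, replacement: list[str]) -> Iterator[None]:
            prior = self.loads
            self.loads = replacement   # redirect writes for this scope
            try:
                yield
            finally:
                self.loads = prior     # always restore, even on error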
@@ -1779,7 +1792,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         finally:
             self.loads = prior
 
-    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
+    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
         raise NotImplementedError
 
     def store(
@@ -1815,7 +1828,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     ) -> tuple[CSEVariable, ...]:
         raise NotImplementedError
 
-    def var_ranges(self):
+    def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
         raise NotImplementedError
 
     def bucketize(
@@ -1869,21 +1882,21 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
 
     def check_bounds(
         self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
-    ):
+    ) -> None:
         raise NotImplementedError
 
     def index_to_str(self, index: sympy.Expr) -> str:
         raise NotImplementedError
 
-    def __enter__(self):
+    def __enter__(self) -> typing.Self:
         # TODO: hoist this to top level
         class CSEProxy:
-            self.name = "CSEProxy"
+            name = "CSEProxy"
             vr_analysis = ValueRangeAnalysis()
 
             @staticmethod
             def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
-                def inner(*args, **kwargs):
+                def inner(*args: Any, **kwargs: Any) -> CSEVariable:
                     bounds = CSEProxy._bound_variable(name, *args, **kwargs)
 
                     value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
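Note: `typing.Self` (Python 3.11+) is the precise return type for an `__enter__` that returns the instance: subclasses get the narrowed type for free. A minimal sketch:

    import typing

    class Managed:
        def __enter__(self) -> typing.Self:
            # `with Sub() as m:` now types m as Sub, not Managed.
            return self

        def __exit__(
            self, exc_type: typing.Any, exc_val: typing.Any, exc_tb: typing.Any
        ) -> None:
            pass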
@@ -1891,7 +1904,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
 
                     output_idx = 0
 
-                    def do_cse(v):
+                    def do_cse(v: str) -> CSEVariable:
                         # cpp backend doesnt set current device - TODO: fix
                         if V.graph.current_device is not None:
                             device_str = V.graph.get_current_device_or_throw().type
@@ -1950,7 +1963,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 return inner
 
             @staticmethod
-            def _bound_variable(name, *args, **kwargs):
+            def _bound_variable(
+                name: str, *args: Any, **kwargs: Any
+            ) -> ValueRanges[Any]:
                 """
                 If the variable comes from an FX node, we forward the bound we have already computed
                 Else, if the variable when codegen'ing another op, we try to compute its bounds
@@ -1979,7 +1994,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 # If there is no FX bound but we know how to compute one we do so
                 assert not kwargs
 
-                def arg_to_bound(x):
+                def arg_to_bound(x: Any) -> Any:
                     if isinstance(x, CSEVariable):
                         return x.bounds
                     elif isinstance(x, sympy.Expr):
@@ -1996,8 +2011,8 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 var: CSEVariable,
                 size: Union[sympy.Expr, int],
                 check: bool = True,
-                wrap_neg=True,
-            ):
+                wrap_neg: bool = True,
+            ) -> sympy.Symbol:
                 if isinstance(size, int):
                     size = sympy.Integer(size)
                 assert isinstance(size, sympy.Expr), size
@@ -2045,7 +2060,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             @staticmethod
             def check_bounds(
                 expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
-            ):
+            ) -> None:
                 return self.check_bounds(expr, size, lower, upper)
 
             @staticmethod
@@ -2067,7 +2082,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 return out
 
             @staticmethod
-            def _update_store_cache(name: str, value: CSEVariable):
+            def _update_store_cache(name: str, value: CSEVariable) -> None:
                 value = cast(CSEVariableType, value)
                 self.cse.store_cache[name] = value
                 if self.current_node and name in V.graph.name_to_buffer:
@@ -2087,7 +2102,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                     return None  # type: ignore[return-value]
 
             @staticmethod
-            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
+            def store_reduction(
+                name: str, index: sympy.Expr, value: CSEVariable
+            ) -> None:
                 self.store_buffer_names.add(name)
                 CSEProxy._update_store_cache(name, value)
 
@@ -2215,7 +2232,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         self.exit_stack.enter_context(V.set_kernel_handler(self))
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         self.remove_kernel_local_buffers()
         super().__exit__(exc_type, exc_val, exc_tb)
 
@@ -2271,7 +2288,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             self.args.inplace_buffers[name] = REMOVED
         self.removed_buffers.add(name)
 
-    def rename_indexing(self, index) -> sympy.Expr:
+    def rename_indexing(
+        self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr]
+    ) -> sympy.Expr:
         # adds the necessary kernel args for index expressions
         # and renames variables in index expressions to kernel arg names
         if isinstance(index, (list, tuple)):
@@ -2292,7 +2311,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         }
         return sympy_subs(index, replacements)
 
-    def create_cse_var(self, *args, **kwargs):
+    def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable:
         return CSEVariable(*args, **kwargs)
 
     def arg_name(self, node: IRNode) -> Optional[str]:
@@ -2313,7 +2332,7 @@ class OptimizationContext:
 
 
 @functools.lru_cache(None)
-def jinja2_env():
+def jinja2_env() -> Any:
     try:
         import jinja2
 
@@ -2332,7 +2351,9 @@ class KernelTemplate:
     """
 
     @staticmethod
-    def indent_except_first(source: str, num_indents: int, indents_spacing=4):
+    def indent_except_first(
+        source: str, num_indents: int, indents_spacing: int = 4
+    ) -> str:
         lines = source.splitlines(True)
         if len(lines) > 1:
             lines[1:] = [
@@ -2341,64 +2362,74 @@ class KernelTemplate:
         return "".join(lines)
 
     @staticmethod
-    def _template_from_string(source):
+    def _template_from_string(source: str) -> Any:
         env = jinja2_env()
         if env is None:
             return None
         env.filters["indent_except_first"] = KernelTemplate.indent_except_first
         from jinja2 import TemplateSyntaxError
 
-        class DetailedTemplateSyntaxError(TemplateSyntaxError):
-            def __init__(self, original_error):
-                super().__init__(
-                    original_error.message,
-                    original_error.lineno,
-                    original_error.name,
-                    original_error.filename,
-                )
-                self.original_error = original_error
-
-            def __str__(self):
-                error_info = f"Error in template at line {self.lineno}\n"
-                error_info += f"Error message: {self.message}\n"
-                if hasattr(self.original_error, "source"):
-                    lines = self.original_error.source.split("\n")
-                    error_info += "Context:\n"
-                    start = max(0, self.lineno - 2)
-                    end = min(len(lines), self.lineno + 2)
-                    for i in range(start, end):
-                        if i == self.lineno - 1:
-                            error_info += f"{i + 1}: --> {lines[i]}\n"
-                            if hasattr(self.original_error, "column"):
-                                error_info += (
-                                    " "
-                                    + " " * (self.original_error.column - 1)
-                                    + "^\n"
-                                )
-                        else:
-                            error_info += f"{i + 1}: {lines[i]}\n"
-                return error_info
-
         try:
             return env.from_string(source)
         except TemplateSyntaxError as e:
+
+            class DetailedTemplateSyntaxError(TemplateSyntaxError):
+                def __init__(self, original_error: TemplateSyntaxError) -> None:
+                    super().__init__(
+                        original_error.message,
+                        original_error.lineno,
+                        original_error.name,
+                        original_error.filename,
+                    )
+                    self.original_error = original_error
+
+                def __str__(self) -> str:
+                    error_info = f"Error in template at line {self.lineno}\n"
+                    error_info += f"Error message: {self.message}\n"
+                    if hasattr(self.original_error, "source"):
+                        lines = self.original_error.source.split("\n")
+                        error_info += "Context:\n"
+                        start = max(0, self.lineno - 2)
+                        end = min(len(lines), self.lineno + 2)
+                        for i in range(start, end):
+                            if i == self.lineno - 1:
+                                error_info += f"{i + 1}: --> {lines[i]}\n"
+                                if hasattr(self.original_error, "column"):
+                                    error_info += (
+                                        " "
+                                        + " " * (self.original_error.column - 1)
+                                        + "^\n"
+                                    )
+                            else:
+                                error_info += f"{i + 1}: {lines[i]}\n"
+                    return error_info
+
             raise DetailedTemplateSyntaxError(e) from e
 
     @staticmethod
-    def _fake_get_dtype(fake_out):
+    def _fake_get_dtype(
+        fake_outs: Union[list[Buffer], Buffer]
+    ) -> Callable[[str], torch.dtype]:
         _get_dtype_real = V.graph.get_dtype
+        if isinstance(fake_outs, (list, tuple)):
+            lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs}
+        else:
+            lookup = {fake_outs.get_name(): fake_outs.get_dtype()}
 
-        def get_dtype(name):
-            if name == fake_out.get_name():
-                return fake_out.get_dtype()
+        def get_dtype(name: str) -> torch.dtype:
+            result = lookup.get(name)
+            if result is not None:
+                return result
             return _get_dtype_real(name)
 
         return get_dtype
 
-    def __init__(self, name: str):
+    def __init__(self, name: str) -> None:
         self.name = name
 
-    def maybe_append_choice(self, choices, **kwargs):
+    def maybe_append_choice(
+        self, choices: list[Any], **kwargs: Any
+    ) -> Optional[NotImplementedError]:
         """
         Maybe generates a new ChoiceCaller and appends it into existing choices.
         Returns None if success, otherwise returns the error.
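Note: the rewritten `get_dtype` closure replaces the old linear scan over outputs with a dict built once up front, so each lookup is O(1) and misses fall back to the graph's real `get_dtype`. The same fallback-closure pattern in isolation (names here are hypothetical):

    from collections.abc import Callable

    def make_lookup(
        overrides: dict[str, str], fallback: Callable[[str], str]
    ) -> Callable[[str], str]:
        def lookup(name: str) -> str:
            result = overrides.get(name)  # O(1) hit on the precomputed table
            if result is not None:
                return result
            return fallback(name)         # defer to the original source
        return lookup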
@@ -2413,7 +2444,7 @@ class KernelTemplate:
         except NotImplementedError as e:
             return e
 
-    def generate(self, **kwargs) -> torch._inductor.ir.ChoiceCaller:
+    def generate(self, **kwargs: Any) -> ChoiceCaller:
         """
         Generates a ChoiceCaller instance from the given arguments.
         """
@@ -188,18 +188,6 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
             for idx in range(gemm_grouped_num)
         ]
 
-    @staticmethod
-    def _fake_get_dtype(fake_outs: list[ir.Buffer]) -> Callable[[str], torch.dtype]:
-        _get_dtype_real = V.graph.get_dtype
-
-        def get_dtype(name: str) -> torch.dtype:
-            for fake_out in fake_outs:
-                if name == fake_out.get_name():
-                    return fake_out.get_dtype()
-            return _get_dtype_real(name)
-
-        return get_dtype
-
     @classmethod
     def add_choices(
         cls,
@@ -1,4 +1,6 @@
-# mypy: allow-untyped-defs
+from __future__ import annotations
+
 import itertools
 from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
 from typing_extensions import Protocol
@@ -29,7 +31,7 @@ ReductionType = Literal[
 ]
 
 
-def _arg_str(a) -> str:
+def _arg_str(a: object) -> str:
     if isinstance(a, sympy.Expr):
         return sympy_str(a)
     return str(a)
@@ -44,7 +46,7 @@ class OpsHandler(Protocol[T]):
     """
     Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
     as well as the contract for op handlers. The type T signifies the domain
-    of the abstract analysis AKA what all of the functions return / take as arguments
+    of the abstract analysis AKA what all the functions return / take as arguments
     anywhere compute occurs.
 
     While these operators are typically dtype polymorphic (e.g., you can use mul
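Note: `Protocol[T]` gives structural typing — any handler whose methods match these signatures satisfies `OpsHandler[T]` without inheriting from it, and `T` is the analysis domain (e.g. strings of generated code, or value ranges for a bounds analysis). A minimal sketch of a generic protocol with a structural implementation:

    from typing import Protocol, TypeVar

    T = TypeVar("T")

    class Ops(Protocol[T]):
        def add(self, a: T, b: T) -> T: ...
        def constant(self, value: float) -> T: ...

    class StringOps:
        # No inheritance needed: matching signatures satisfy Ops[str].
        def add(self, a: str, b: str) -> str:
            return f"({a} + {b})"

        def constant(self, value: float) -> str:
            return repr(value)

    def emit(ops: Ops[str]) -> str:
        return ops.add(ops.constant(1.0), ops.constant(2.0))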
@@ -247,7 +249,7 @@ class OpsHandler(Protocol[T]):
     # TODO: in practice, this seems to actually return None, but not returning
     # a T makes common __getattr__ idioms not type correctly. Figure out if
     # this should be returning something.
-    def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
+    def store_reduction(self, name: str, index: sympy.Expr, value: T) -> None:
         """
         Store the fully accumulated result of 'reduction' to the memory
         location 'name' offset by 'expr'.
@@ -1046,7 +1048,7 @@ class ExtractConstantsHandler(NoopHandler):
     def __init__(self, device):
         self.device = device
 
-    def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
+    def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant:
         from torch._inductor import ir
 
         return ir.Constant(value=value, dtype=dtype, device=self.device)