[inductor] Finish typing common.py (#146225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146225
Approved by: https://github.com/Skylion007
Author: Jason Ansel
Date: 2025-02-01 10:55:34 -08:00
Committed by: PyTorch MergeBot
Parent: dca5cc0255
Commit: 3a67c0e48d
3 changed files with 124 additions and 103 deletions

View File: torch/_inductor/codegen/common.py

@@ -1,4 +1,3 @@
-# mypy: allow-untyped-defs
 from __future__ import annotations
 
 import contextlib
@@ -32,7 +31,6 @@ import sympy
 import torch
 import torch.fx
 
-from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.utils import _pytree as pytree
 from torch.utils._ordered_set import OrderedSet
@@ -42,6 +40,7 @@ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
 
 from .. import config, metrics
+from ..dtype_propagation import DtypePropagationOpsHandler
 from ..utils import (
     boolean_ops,
     DeferredLineBase,
@@ -57,7 +56,9 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
 if TYPE_CHECKING:
-    from ..ir import FixedLayout, IRNode
+    from collections.abc import Sequence
+
+    from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
     from ..loop_body import LoopBody
     from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
     from .wrapper import PythonWrapperCodegen
@@ -121,7 +122,7 @@ class WorkspaceArg:
     dtype: torch.dtype = torch.uint8
 
     @staticmethod
-    def unique_name(prefix="workspace_") -> str:
+    def unique_name(prefix: str = "workspace_") -> str:
         return f"{prefix}{next(V.graph.workspace_id)}"
 
     @staticmethod
@@ -456,7 +457,9 @@ def init_backend_registration() -> None:
 
 def index_prevent_reordering(
-    index: list[sympy.Expr], index_vars, sizes
+    index: Sequence[sympy.Expr],
+    index_vars: Sequence[sympy.Expr],
+    sizes: Sequence[sympy.Expr],
 ) -> list[sympy.Expr]:
     from ..ir import FlexibleLayout
@@ -631,11 +634,11 @@ class DataTypePropagation:
         return self.propagate_graph(self.graphs["root"])
 
     @classmethod
-    def propagate_loopbody(cls, body) -> Optional[torch.dtype]:
+    def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]:
         return cls(body).propagate()
 
     @classmethod
-    def propagate_scheduler_node(cls, node) -> Optional[torch.dtype]:
+    def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
         from ..loop_body import LoopBody
         from ..scheduler import SchedulerNode
@@ -673,7 +676,7 @@ class OpDecompositions:
         return ops.mul(x, x)
 
     @staticmethod
-    def erfc(x: OpVarT):
+    def erfc(x: OpVarT) -> OpVarT:
         return ops.sub(ops.constant(1, torch.float32), ops.erf(x))
 
     @staticmethod
@@ -759,7 +762,7 @@ def _all_in_parens(string: str) -> bool:
 
 class OpOverrides(OpDecompositions):
-    def __init__(self, parent):
+    def __init__(self, parent: OpsHandler[OpVarT]) -> None:
         super().__init__()
         self._parent = parent
@@ -1109,12 +1112,12 @@ def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[OpVarT]:
 class DeferredLine(DeferredLineBase):
     """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
 
-    def __init__(self, name, line):
+    def __init__(self, name: str, line: str):
         super().__init__(line)
         self.name = name
         assert not isinstance(line, DeferredLineBase)
 
-    def __call__(self):
+    def __call__(self) -> Optional[str]:
         if all(
             self.name not in x
             for x in (
@@ -1127,14 +1130,14 @@ class DeferredLine(DeferredLineBase):
             return self.line
         return None
 
-    def _new_line(self, line):
+    def _new_line(self, line: str) -> DeferredLine:
         return DeferredLine(self.name, line)
 
 
 class BracesBuffer(IndentedBuffer):
-    def indent(self, offset=1) -> contextlib.AbstractContextManager[None]:
+    def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
         @contextlib.contextmanager
-        def ctx():
+        def ctx() -> Iterator[None]:
             for _ in range(offset):
                 self.writeline("{")
                 self._indent += 1
@@ -1163,7 +1166,7 @@ class ArgName:
     # is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
     is_constexpr: bool = False
 
-    def full_name(self):
+    def full_name(self) -> str:
         return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
@@ -1509,8 +1512,8 @@ class CSEVariable:
     def __hash__(self) -> int:
         return hash(self.name)
 
-    def __eq__(self, other) -> bool:
-        return type(other) == type(self) and other.name == self.name
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, CSEVariable) and other.name == self.name
 
     def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
         pass
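
Note on the `__eq__` hunk above: moving from `type(other) == type(self)` to `isinstance(other, CSEVariable)` relaxes equality so variables compare by name across the whole CSEVariable hierarchy. A standalone sketch of the difference (the `Var`/`TritonVar` names are illustrative, not inductor classes):

class Var:
    def __init__(self, name: str) -> None:
        self.name = name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        # isinstance-based: a subclass instance with the same name is equal
        return isinstance(other, Var) and other.name == self.name


class TritonVar(Var):  # hypothetical backend-specific subclass
    pass


assert Var("tmp0") == TritonVar("tmp0")  # True here; False with type() == type()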
@@ -1557,7 +1560,7 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
         self.invalidated_stores: OrderedSet[str] = OrderedSet()
         self.varname_map: dict[str, CSEVariableType] = varname_map or {}
 
-    def invalidate(self, keep_vars: OrderedSet[CSEVariable]):
+    def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
         for name, tmp in [*self.store_cache.items()]:
             if tmp not in keep_vars:
                 del self.store_cache[name]
@@ -1578,6 +1581,14 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
             reduction_cache=self.reduction_cache,
         )
 
+    def scoped_copy(self) -> typing.Self:
+        """Return a copy using ScopedDict so changes to *_cache aren't visible in self"""
+        new_cse = self.clone()
+        new_cse._cache = ScopedDict(self._cache)
+        new_cse.reduction_cache = ScopedDict(self.reduction_cache)
+        new_cse.store_cache = ScopedDict(self.store_cache)
+        return new_cse
+
     def augment_key(self, cache_key: str) -> AugmentedKeyT:
         "Override this method to augment cache key with backend specifics"
         return cast(AugmentedKeyT, cache_key)
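
The new `scoped_copy` relies on `ScopedDict` (imported from `..utils`) to layer writes over a shared base mapping. A minimal standalone sketch of that layering, assuming only the overlay semantics visible in this diff (`ScopedDictSketch` is an illustrative stand-in, not the real class):

class ScopedDictSketch:
    def __init__(self, base: dict) -> None:
        self.base = base   # shared mapping, never mutated from this scope
        self.local = {}    # writes made inside the scope land here

    def __getitem__(self, key):
        if key in self.local:
            return self.local[key]
        return self.base[key]

    def __setitem__(self, key, value) -> None:
        self.local[key] = value

    def __contains__(self, key) -> bool:
        return key in self.local or key in self.base


cache = {"x + y": "tmp0"}
scoped = ScopedDictSketch(cache)
scoped["x * y"] = "tmp1"       # visible through the scoped copy...
assert "x * y" in scoped
assert "x * y" not in cache    # ...while the parent cache stays untouched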
@@ -1730,7 +1741,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         self.kernel_name: Optional[str] = None
 
     @contextlib.contextmanager
-    def set_current_node(self, node):
+    def set_current_node(self, node: SchedulerNode) -> Iterator[None]:
         prior = self.current_node
         self.current_node = node
         self.node_to_bounds = node._body.bounds().get_bounds()
@@ -1740,16 +1751,16 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             self.current_node = prior
 
     @contextlib.contextmanager
-    def swap_buffers(self, lb, cb=None, sb=None):
-        def scope_cse(cse: CSE[CSEVariableType, Any]):
-            new_cse = cse.clone()
-            new_cse._cache = ScopedDict(cse._cache)
-            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
-            new_cse.store_cache = ScopedDict(cse.store_cache)
-            return new_cse
-
+    def swap_buffers(
+        self,
+        lb: IndentedBuffer,
+        cb: Optional[IndentedBuffer] = None,
+        sb: Optional[IndentedBuffer] = None,
+    ) -> Iterator[None]:
         if cb is None:
             cb = lb
+        if disallow_stores := sb is None:
+            sb = IndentedBuffer()
         loads = self.loads
         compute = self.compute
         stores = self.stores
@@ -1757,7 +1768,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         self.loads = lb
         self.compute = cb
         self.stores = sb
-        self.cse = scope_cse(cse)
+        self.cse = cse.scoped_copy()
         try:
             yield
         finally:
@@ -1765,11 +1776,13 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             self.compute = compute
             self.stores = stores
             self.cse = cse
+            if disallow_stores:
+                assert not sb, "unexpected store inside swap_buffers"
 
     def load(self, name: str, index: sympy.Expr) -> CSEVariable:
         raise NotImplementedError
 
-    def indirect_load(self, name: str, index: sympy.Expr):
+    def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable:
         """A load that depends on an index we have read"""
         prior = self.loads
         try:
@@ -1779,7 +1792,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         finally:
             self.loads = prior
 
-    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
+    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
         raise NotImplementedError
 
     def store(
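
Taken together, the `swap_buffers` hunks above replace the local `scope_cse` helper with `CSE.scoped_copy` and add a `disallow_stores` guard: when no store buffer is passed in, a throwaway `IndentedBuffer` is installed and asserted empty on exit. A hedged standalone sketch of that pattern (simplified stand-in classes, not the real `Kernel`):

import contextlib
from typing import Iterator, Optional


class BufferSketch:
    """Tiny stand-in for IndentedBuffer: truthy iff anything was written."""

    def __init__(self) -> None:
        self.lines: list[str] = []

    def writeline(self, line: str) -> None:
        self.lines.append(line)

    def __bool__(self) -> bool:
        return bool(self.lines)


class KernelSketch:
    def __init__(self) -> None:
        self.loads = BufferSketch()
        self.stores = BufferSketch()

    @contextlib.contextmanager
    def swap_buffers(
        self, lb: BufferSketch, sb: Optional[BufferSketch] = None
    ) -> Iterator[None]:
        # No store buffer supplied -> stores are disallowed inside this scope.
        if disallow_stores := sb is None:
            sb = BufferSketch()
        prior = (self.loads, self.stores)
        self.loads, self.stores = lb, sb
        try:
            yield
        finally:
            self.loads, self.stores = prior
            if disallow_stores:
                assert not sb, "unexpected store inside swap_buffers"


k = KernelSketch()
scratch = BufferSketch()
with k.swap_buffers(scratch):
    k.loads.writeline("tmp0 = load(buf0)")  # captured by scratch, not k
assert k.loads.lines == [] and scratch.lines == ["tmp0 = load(buf0)"]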
@@ -1815,7 +1828,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     ) -> tuple[CSEVariable, ...]:
         raise NotImplementedError
 
-    def var_ranges(self):
+    def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
         raise NotImplementedError
 
     def bucketize(
@@ -1869,21 +1882,21 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     def check_bounds(
         self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
-    ):
+    ) -> None:
         raise NotImplementedError
 
     def index_to_str(self, index: sympy.Expr) -> str:
         raise NotImplementedError
 
-    def __enter__(self):
+    def __enter__(self) -> typing.Self:
         # TODO: hoist this to top level
         class CSEProxy:
-            self.name = "CSEProxy"
+            name = "CSEProxy"
             vr_analysis = ValueRangeAnalysis()
 
             @staticmethod
             def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
-                def inner(*args, **kwargs):
+                def inner(*args: Any, **kwargs: Any) -> CSEVariable:
                     bounds = CSEProxy._bound_variable(name, *args, **kwargs)
                     value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
@@ -1891,7 +1904,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                     output_idx = 0
 
-                    def do_cse(v):
+                    def do_cse(v: str) -> CSEVariable:
                         # cpp backend doesn't set current device - TODO: fix
                         if V.graph.current_device is not None:
                             device_str = V.graph.get_current_device_or_throw().type
@@ -1950,7 +1963,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 return inner
 
             @staticmethod
-            def _bound_variable(name, *args, **kwargs):
+            def _bound_variable(
+                name: str, *args: Any, **kwargs: Any
+            ) -> ValueRanges[Any]:
                 """
                 If the variable comes from an FX node, we forward the bound we have already computed
                 Else, if the variable is created when codegen'ing another op, we try to compute its bounds
@@ -1979,7 +1994,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 # If there is no FX bound but we know how to compute one we do so
                 assert not kwargs
 
-                def arg_to_bound(x):
+                def arg_to_bound(x: Any) -> Any:
                     if isinstance(x, CSEVariable):
                         return x.bounds
                     elif isinstance(x, sympy.Expr):
@@ -1996,8 +2011,8 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 var: CSEVariable,
                 size: Union[sympy.Expr, int],
                 check: bool = True,
-                wrap_neg=True,
-            ):
+                wrap_neg: bool = True,
+            ) -> sympy.Symbol:
                 if isinstance(size, int):
                     size = sympy.Integer(size)
                 assert isinstance(size, sympy.Expr), size
@@ -2045,7 +2060,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             @staticmethod
             def check_bounds(
                 expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
-            ):
+            ) -> None:
                 return self.check_bounds(expr, size, lower, upper)
 
             @staticmethod
@@ -2067,7 +2082,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 return out
 
             @staticmethod
-            def _update_store_cache(name: str, value: CSEVariable):
+            def _update_store_cache(name: str, value: CSEVariable) -> None:
                 value = cast(CSEVariableType, value)
                 self.cse.store_cache[name] = value
                 if self.current_node and name in V.graph.name_to_buffer:
@@ -2087,7 +2102,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                 return None  # type: ignore[return-value]
 
             @staticmethod
-            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
+            def store_reduction(
+                name: str, index: sympy.Expr, value: CSEVariable
+            ) -> None:
                 self.store_buffer_names.add(name)
                 CSEProxy._update_store_cache(name, value)
@@ -2215,7 +2232,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         self.exit_stack.enter_context(V.set_kernel_handler(self))
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         self.remove_kernel_local_buffers()
         super().__exit__(exc_type, exc_val, exc_tb)
@@ -2271,7 +2288,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
             self.args.inplace_buffers[name] = REMOVED
         self.removed_buffers.add(name)
 
-    def rename_indexing(self, index) -> sympy.Expr:
+    def rename_indexing(
+        self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr]
+    ) -> sympy.Expr:
         # adds the necessary kernel args for index expressions
         # and renames variables in index expressions to kernel arg names
         if isinstance(index, (list, tuple)):
@@ -2292,7 +2311,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
         }
         return sympy_subs(index, replacements)
 
-    def create_cse_var(self, *args, **kwargs):
+    def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable:
         return CSEVariable(*args, **kwargs)
 
     def arg_name(self, node: IRNode) -> Optional[str]:
@functools.lru_cache(None) @functools.lru_cache(None)
def jinja2_env(): def jinja2_env() -> Any:
try: try:
import jinja2 import jinja2
@ -2332,7 +2351,9 @@ class KernelTemplate:
""" """
@staticmethod @staticmethod
def indent_except_first(source: str, num_indents: int, indents_spacing=4): def indent_except_first(
source: str, num_indents: int, indents_spacing: int = 4
) -> str:
lines = source.splitlines(True) lines = source.splitlines(True)
if len(lines) > 1: if len(lines) > 1:
lines[1:] = [ lines[1:] = [
@@ -2341,64 +2362,74 @@ class KernelTemplate:
         return "".join(lines)
 
     @staticmethod
-    def _template_from_string(source):
+    def _template_from_string(source: str) -> Any:
         env = jinja2_env()
         if env is None:
             return None
         env.filters["indent_except_first"] = KernelTemplate.indent_except_first
         from jinja2 import TemplateSyntaxError
 
-        class DetailedTemplateSyntaxError(TemplateSyntaxError):
-            def __init__(self, original_error):
-                super().__init__(
-                    original_error.message,
-                    original_error.lineno,
-                    original_error.name,
-                    original_error.filename,
-                )
-                self.original_error = original_error
-
-            def __str__(self):
-                error_info = f"Error in template at line {self.lineno}\n"
-                error_info += f"Error message: {self.message}\n"
-                if hasattr(self.original_error, "source"):
-                    lines = self.original_error.source.split("\n")
-                    error_info += "Context:\n"
-                    start = max(0, self.lineno - 2)
-                    end = min(len(lines), self.lineno + 2)
-                    for i in range(start, end):
-                        if i == self.lineno - 1:
-                            error_info += f"{i + 1}: --> {lines[i]}\n"
-                            if hasattr(self.original_error, "column"):
-                                error_info += (
-                                    " "
-                                    + " " * (self.original_error.column - 1)
-                                    + "^\n"
-                                )
-                        else:
-                            error_info += f"{i + 1}: {lines[i]}\n"
-                return error_info
-
         try:
             return env.from_string(source)
         except TemplateSyntaxError as e:
+
+            class DetailedTemplateSyntaxError(TemplateSyntaxError):
+                def __init__(self, original_error: TemplateSyntaxError) -> None:
+                    super().__init__(
+                        original_error.message,
+                        original_error.lineno,
+                        original_error.name,
+                        original_error.filename,
+                    )
+                    self.original_error = original_error
+
+                def __str__(self) -> str:
+                    error_info = f"Error in template at line {self.lineno}\n"
+                    error_info += f"Error message: {self.message}\n"
+                    if hasattr(self.original_error, "source"):
+                        lines = self.original_error.source.split("\n")
+                        error_info += "Context:\n"
+                        start = max(0, self.lineno - 2)
+                        end = min(len(lines), self.lineno + 2)
+                        for i in range(start, end):
+                            if i == self.lineno - 1:
+                                error_info += f"{i + 1}: --> {lines[i]}\n"
+                                if hasattr(self.original_error, "column"):
+                                    error_info += (
+                                        " "
+                                        + " " * (self.original_error.column - 1)
+                                        + "^\n"
+                                    )
+                            else:
+                                error_info += f"{i + 1}: {lines[i]}\n"
+                    return error_info
+
             raise DetailedTemplateSyntaxError(e) from e
 
     @staticmethod
-    def _fake_get_dtype(fake_out):
+    def _fake_get_dtype(
+        fake_outs: Union[list[Buffer], Buffer]
+    ) -> Callable[[str], torch.dtype]:
         _get_dtype_real = V.graph.get_dtype
+        if isinstance(fake_outs, (list, tuple)):
+            lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs}
+        else:
+            lookup = {fake_outs.get_name(): fake_outs.get_dtype()}
 
-        def get_dtype(name):
-            if name == fake_out.get_name():
-                return fake_out.get_dtype()
+        def get_dtype(name: str) -> torch.dtype:
+            result = lookup.get(name)
+            if result is not None:
+                return result
             return _get_dtype_real(name)
 
         return get_dtype
 
-    def __init__(self, name: str):
+    def __init__(self, name: str) -> None:
        self.name = name
 
-    def maybe_append_choice(self, choices, **kwargs):
+    def maybe_append_choice(
+        self, choices: list[Any], **kwargs: Any
+    ) -> Optional[NotImplementedError]:
         """
         Maybe generates a new ChoiceCaller and appends it into existing choices.
         Returns None if success, otherwise returns the error.
@@ -2413,7 +2444,7 @@ class KernelTemplate:
         except NotImplementedError as e:
             return e
 
-    def generate(self, **kwargs) -> torch._inductor.ir.ChoiceCaller:
+    def generate(self, **kwargs: Any) -> ChoiceCaller:
         """
         Generates a ChoiceCaller instance from the given arguments.
         """

View File: torch/_inductor/codegen/cpp_grouped_gemm_template.py

@@ -188,18 +188,6 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
             for idx in range(gemm_grouped_num)
         ]
 
-    @staticmethod
-    def _fake_get_dtype(fake_outs: list[ir.Buffer]) -> Callable[[str], torch.dtype]:
-        _get_dtype_real = V.graph.get_dtype
-
-        def get_dtype(name: str) -> torch.dtype:
-            for fake_out in fake_outs:
-                if name == fake_out.get_name():
-                    return fake_out.get_dtype()
-            return _get_dtype_real(name)
-
-        return get_dtype
-
     @classmethod
     def add_choices(
         cls,
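
This override became dead code once the base `KernelTemplate._fake_get_dtype` (first file above) learned to accept a list of buffers, and the linear scan was replaced by a dict lookup. A standalone sketch of the new lookup-with-fallback closure (the buffer stub and dtype strings are illustrative):

from typing import Callable, Union


class BufStub:
    """Illustrative stand-in for ir.Buffer."""

    def __init__(self, name: str, dtype: str) -> None:
        self._name, self._dtype = name, dtype

    def get_name(self) -> str:
        return self._name

    def get_dtype(self) -> str:
        return self._dtype


def fake_get_dtype(
    fake_outs: Union[list, BufStub], real_get_dtype: Callable[[str], str]
) -> Callable[[str], str]:
    # Build the name -> dtype table once, whether given one buffer or many.
    if isinstance(fake_outs, (list, tuple)):
        lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs}
    else:
        lookup = {fake_outs.get_name(): fake_outs.get_dtype()}

    def get_dtype(name: str) -> str:
        result = lookup.get(name)
        if result is not None:
            return result
        return real_get_dtype(name)  # fall through to the graph's real mapping

    return get_dtype


lookup_fn = fake_get_dtype([BufStub("buf0", "float16")], lambda name: "float32")
assert lookup_fn("buf0") == "float16" and lookup_fn("buf1") == "float32"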

View File: torch/_inductor/ops_handler.py

@@ -1,4 +1,6 @@
 # mypy: allow-untyped-defs
+from __future__ import annotations
+
 import itertools
 from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
 from typing_extensions import Protocol
@@ -29,7 +31,7 @@ ReductionType = Literal[
 ]
 
 
-def _arg_str(a) -> str:
+def _arg_str(a: object) -> str:
     if isinstance(a, sympy.Expr):
         return sympy_str(a)
     return str(a)
@@ -44,7 +46,7 @@ class OpsHandler(Protocol[T]):
     """
     Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
     as well as the contract for op handlers. The type T signifies the domain
-    of the abstract analysis AKA what all of the functions return / take as arguments
+    of the abstract analysis AKA what all the functions return / take as arguments
     anywhere compute occurs.
 
     While these operators are typically dtype polymorphic (e.g., you can use mul
@@ -247,7 +249,7 @@ class OpsHandler(Protocol[T]):
     # TODO: in practice, this seems to actually return None, but not returning
     # a T makes common __getattr__ idioms not type correctly. Figure out if
    # this should be returning something.
-    def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
+    def store_reduction(self, name: str, index: sympy.Expr, value: T) -> None:
         """
         Store the fully accumulated result of 'reduction' to the memory
         location 'name' offset by 'expr'.
@@ -1046,7 +1048,7 @@ class ExtractConstantsHandler(NoopHandler):
     def __init__(self, device):
         self.device = device
 
-    def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
+    def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant:
         from torch._inductor import ir
 
         return ir.Constant(value=value, dtype=dtype, device=self.device)
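
Changing `store_reduction` to return `None` aligns the protocol with what handlers actually do. For context, a minimal sketch of the `Protocol[T]` idiom `OpsHandler` uses, where `T` is the analysis domain and structural subtyping means handlers never inherit from the protocol (names below are illustrative, not the real interface):

from typing import Protocol, TypeVar

T = TypeVar("T")


class MiniOpsHandler(Protocol[T]):
    def mul(self, a: T, b: T) -> T: ...

    def store_reduction(self, name: str, value: T) -> None: ...


class StrHandler:
    """Operates on the 'string expression' domain, i.e. T = str."""

    def mul(self, a: str, b: str) -> str:
        return f"({a} * {b})"

    def store_reduction(self, name: str, value: str) -> None:
        print(f"{name} <- {value}")


h: MiniOpsHandler[str] = StrHandler()  # checked structurally, no inheritance
h.store_reduction("buf0", h.mul("x", "y"))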