[inductor] Finish typing common.py (#146225)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146225
Approved by: https://github.com/Skylion007
This commit is contained in:
Jason Ansel
2025-02-01 10:55:34 -08:00
committed by PyTorch MergeBot
parent dca5cc0255
commit 3a67c0e48d
3 changed files with 124 additions and 103 deletions

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
@ -32,7 +31,6 @@ import sympy
import torch
import torch.fx
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
@ -42,6 +40,7 @@ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from .. import config, metrics
from ..dtype_propagation import DtypePropagationOpsHandler
from ..utils import (
boolean_ops,
DeferredLineBase,
@ -57,7 +56,9 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
from ..ir import FixedLayout, IRNode
from collections.abc import Sequence
from ..ir import Buffer, ChoiceCaller, FixedLayout, IRNode
from ..loop_body import LoopBody
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
from .wrapper import PythonWrapperCodegen
@ -121,7 +122,7 @@ class WorkspaceArg:
dtype: torch.dtype = torch.uint8
@staticmethod
def unique_name(prefix="workspace_") -> str:
def unique_name(prefix: str = "workspace_") -> str:
return f"{prefix}{next(V.graph.workspace_id)}"
@staticmethod
@ -456,7 +457,9 @@ def init_backend_registration() -> None:
def index_prevent_reordering(
index: list[sympy.Expr], index_vars, sizes
index: Sequence[sympy.Expr],
index_vars: Sequence[sympy.Expr],
sizes: Sequence[sympy.Expr],
) -> list[sympy.Expr]:
from ..ir import FlexibleLayout
@ -631,11 +634,11 @@ class DataTypePropagation:
return self.propagate_graph(self.graphs["root"])
@classmethod
def propagate_loopbody(cls, body) -> Optional[torch.dtype]:
def propagate_loopbody(cls, body: LoopBody) -> Optional[torch.dtype]:
return cls(body).propagate()
@classmethod
def propagate_scheduler_node(cls, node) -> Optional[torch.dtype]:
def propagate_scheduler_node(cls, node: SchedulerNode) -> Optional[torch.dtype]:
from ..loop_body import LoopBody
from ..scheduler import SchedulerNode
@ -673,7 +676,7 @@ class OpDecompositions:
return ops.mul(x, x)
@staticmethod
def erfc(x: OpVarT):
def erfc(x: OpVarT) -> OpVarT:
return ops.sub(ops.constant(1, torch.float32), ops.erf(x))
@staticmethod
@ -759,7 +762,7 @@ def _all_in_parens(string: str) -> bool:
class OpOverrides(OpDecompositions):
def __init__(self, parent):
def __init__(self, parent: OpsHandler[OpVarT]) -> None:
super().__init__()
self._parent = parent
@ -1109,12 +1112,12 @@ def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[OpVarT]:
class DeferredLine(DeferredLineBase):
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
def __init__(self, name, line):
def __init__(self, name: str, line: str):
super().__init__(line)
self.name = name
assert not isinstance(line, DeferredLineBase)
def __call__(self):
def __call__(self) -> Optional[str]:
if all(
self.name not in x
for x in (
@ -1127,14 +1130,14 @@ class DeferredLine(DeferredLineBase):
return self.line
return None
def _new_line(self, line):
def _new_line(self, line: str) -> DeferredLine:
return DeferredLine(self.name, line)
class BracesBuffer(IndentedBuffer):
def indent(self, offset=1) -> contextlib.AbstractContextManager[None]:
def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
@contextlib.contextmanager
def ctx():
def ctx() -> Iterator[None]:
for _ in range(offset):
self.writeline("{")
self._indent += 1
@ -1163,7 +1166,7 @@ class ArgName:
# is_constexpr=True is used to attach a " : tl.constexpr" into the argument list
is_constexpr: bool = False
def full_name(self):
def full_name(self) -> str:
return f"{self.name}{' : tl.constexpr' if self.is_constexpr else ''}"
@ -1509,8 +1512,8 @@ class CSEVariable:
def __hash__(self) -> int:
return hash(self.name)
def __eq__(self, other) -> bool:
return type(other) == type(self) and other.name == self.name
def __eq__(self, other: object) -> bool:
return isinstance(other, CSEVariable) and other.name == self.name
def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
pass
@ -1557,7 +1560,7 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
self.invalidated_stores: OrderedSet[str] = OrderedSet()
self.varname_map: dict[str, CSEVariableType] = varname_map or {}
def invalidate(self, keep_vars: OrderedSet[CSEVariable]):
def invalidate(self, keep_vars: OrderedSet[CSEVariable]) -> None:
for name, tmp in [*self.store_cache.items()]:
if tmp not in keep_vars:
del self.store_cache[name]
@ -1578,6 +1581,14 @@ class CSE(Generic[CSEVariableType, AugmentedKeyT]):
reduction_cache=self.reduction_cache,
)
def scoped_copy(self) -> typing.Self:
    """
    Clone this CSE and layer a ScopedDict over each of its caches.

    The clone's ``_cache``, ``reduction_cache`` and ``store_cache`` are each
    replaced by a ScopedDict wrapping the matching cache of ``self``, so that
    changes made to the clone's caches aren't visible in ``self``.
    """
    scoped = self.clone()
    scoped._cache = ScopedDict(self._cache)
    scoped.reduction_cache = ScopedDict(self.reduction_cache)
    scoped.store_cache = ScopedDict(self.store_cache)
    return scoped
def augment_key(self, cache_key: str) -> AugmentedKeyT:
"Override this method to augment cache key with backend specifics"
return cast(AugmentedKeyT, cache_key)
@ -1730,7 +1741,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.kernel_name: Optional[str] = None
@contextlib.contextmanager
def set_current_node(self, node):
def set_current_node(self, node: SchedulerNode) -> Iterator[None]:
prior = self.current_node
self.current_node = node
self.node_to_bounds = node._body.bounds().get_bounds()
@ -1740,16 +1751,16 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.current_node = prior
@contextlib.contextmanager
def swap_buffers(self, lb, cb=None, sb=None):
def scope_cse(cse: CSE[CSEVariableType, Any]):
new_cse = cse.clone()
new_cse._cache = ScopedDict(cse._cache)
new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
new_cse.store_cache = ScopedDict(cse.store_cache)
return new_cse
def swap_buffers(
self,
lb: IndentedBuffer,
cb: Optional[IndentedBuffer] = None,
sb: Optional[IndentedBuffer] = None,
) -> Iterator[None]:
if cb is None:
cb = lb
if disallow_stores := sb is None:
sb = IndentedBuffer()
loads = self.loads
compute = self.compute
stores = self.stores
@ -1757,7 +1768,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.loads = lb
self.compute = cb
self.stores = sb
self.cse = scope_cse(cse)
self.cse = cse.scoped_copy()
try:
yield
finally:
@ -1765,11 +1776,13 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.compute = compute
self.stores = stores
self.cse = cse
if disallow_stores:
assert not sb, "unexpected store inside swap_buffers"
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
raise NotImplementedError
def indirect_load(self, name: str, index: sympy.Expr):
def indirect_load(self, name: str, index: sympy.Expr) -> CSEVariable:
"""A load the depends on an index we have read"""
prior = self.loads
try:
@ -1779,7 +1792,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
finally:
self.loads = prior
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
raise NotImplementedError
def store(
@ -1815,7 +1828,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
) -> tuple[CSEVariable, ...]:
raise NotImplementedError
def var_ranges(self):
def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
raise NotImplementedError
def bucketize(
@ -1869,21 +1882,21 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
def check_bounds(
self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
):
) -> None:
raise NotImplementedError
def index_to_str(self, index: sympy.Expr) -> str:
raise NotImplementedError
def __enter__(self):
def __enter__(self) -> typing.Self:
# TODO: hoist this to top level
class CSEProxy:
self.name = "CSEProxy"
name = "CSEProxy"
vr_analysis = ValueRangeAnalysis()
@staticmethod
def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
def inner(*args, **kwargs):
def inner(*args: Any, **kwargs: Any) -> CSEVariable:
bounds = CSEProxy._bound_variable(name, *args, **kwargs)
value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
@ -1891,7 +1904,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
output_idx = 0
def do_cse(v):
def do_cse(v: str) -> CSEVariable:
# cpp backend doesn't set current device - TODO: fix
if V.graph.current_device is not None:
device_str = V.graph.get_current_device_or_throw().type
@ -1950,7 +1963,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
return inner
@staticmethod
def _bound_variable(name, *args, **kwargs):
def _bound_variable(
name: str, *args: Any, **kwargs: Any
) -> ValueRanges[Any]:
"""
If the variable comes from an FX node, we forward the bound we have already computed
Else, if the variable is created when codegen'ing another op, we try to compute its bounds
@ -1979,7 +1994,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
# If there is no FX bound but we know how to compute one we do so
assert not kwargs
def arg_to_bound(x):
def arg_to_bound(x: Any) -> Any:
if isinstance(x, CSEVariable):
return x.bounds
elif isinstance(x, sympy.Expr):
@ -1996,8 +2011,8 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
var: CSEVariable,
size: Union[sympy.Expr, int],
check: bool = True,
wrap_neg=True,
):
wrap_neg: bool = True,
) -> sympy.Symbol:
if isinstance(size, int):
size = sympy.Integer(size)
assert isinstance(size, sympy.Expr), size
@ -2045,7 +2060,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
@staticmethod
def check_bounds(
expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
):
) -> None:
return self.check_bounds(expr, size, lower, upper)
@staticmethod
@ -2067,7 +2082,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
return out
@staticmethod
def _update_store_cache(name: str, value: CSEVariable):
def _update_store_cache(name: str, value: CSEVariable) -> None:
value = cast(CSEVariableType, value)
self.cse.store_cache[name] = value
if self.current_node and name in V.graph.name_to_buffer:
@ -2087,7 +2102,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
return None # type: ignore[return-value]
@staticmethod
def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
def store_reduction(
name: str, index: sympy.Expr, value: CSEVariable
) -> None:
self.store_buffer_names.add(name)
CSEProxy._update_store_cache(name, value)
@ -2215,7 +2232,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.exit_stack.enter_context(V.set_kernel_handler(self))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.remove_kernel_local_buffers()
super().__exit__(exc_type, exc_val, exc_tb)
@ -2271,7 +2288,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
self.args.inplace_buffers[name] = REMOVED
self.removed_buffers.add(name)
def rename_indexing(self, index) -> sympy.Expr:
def rename_indexing(
self, index: Union[list[sympy.Expr], tuple[sympy.Expr, ...], sympy.Expr]
) -> sympy.Expr:
# adds the necessary kernel args for index expressions
# and renames variables in index expressions to kernel arg names
if isinstance(index, (list, tuple)):
@ -2292,7 +2311,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
}
return sympy_subs(index, replacements)
def create_cse_var(self, *args, **kwargs):
def create_cse_var(self, *args: Any, **kwargs: Any) -> CSEVariable:
return CSEVariable(*args, **kwargs)
def arg_name(self, node: IRNode) -> Optional[str]:
@ -2313,7 +2332,7 @@ class OptimizationContext:
@functools.lru_cache(None)
def jinja2_env():
def jinja2_env() -> Any:
try:
import jinja2
@ -2332,7 +2351,9 @@ class KernelTemplate:
"""
@staticmethod
def indent_except_first(source: str, num_indents: int, indents_spacing=4):
def indent_except_first(
source: str, num_indents: int, indents_spacing: int = 4
) -> str:
lines = source.splitlines(True)
if len(lines) > 1:
lines[1:] = [
@ -2341,64 +2362,74 @@ class KernelTemplate:
return "".join(lines)
@staticmethod
def _template_from_string(source):
def _template_from_string(source: str) -> Any:
env = jinja2_env()
if env is None:
return None
env.filters["indent_except_first"] = KernelTemplate.indent_except_first
from jinja2 import TemplateSyntaxError
class DetailedTemplateSyntaxError(TemplateSyntaxError):
def __init__(self, original_error):
super().__init__(
original_error.message,
original_error.lineno,
original_error.name,
original_error.filename,
)
self.original_error = original_error
def __str__(self):
error_info = f"Error in template at line {self.lineno}\n"
error_info += f"Error message: {self.message}\n"
if hasattr(self.original_error, "source"):
lines = self.original_error.source.split("\n")
error_info += "Context:\n"
start = max(0, self.lineno - 2)
end = min(len(lines), self.lineno + 2)
for i in range(start, end):
if i == self.lineno - 1:
error_info += f"{i + 1}: --> {lines[i]}\n"
if hasattr(self.original_error, "column"):
error_info += (
" "
+ " " * (self.original_error.column - 1)
+ "^\n"
)
else:
error_info += f"{i + 1}: {lines[i]}\n"
return error_info
try:
return env.from_string(source)
except TemplateSyntaxError as e:
class DetailedTemplateSyntaxError(TemplateSyntaxError):
def __init__(self, original_error: TemplateSyntaxError) -> None:
super().__init__(
original_error.message,
original_error.lineno,
original_error.name,
original_error.filename,
)
self.original_error = original_error
def __str__(self) -> str:
error_info = f"Error in template at line {self.lineno}\n"
error_info += f"Error message: {self.message}\n"
if hasattr(self.original_error, "source"):
lines = self.original_error.source.split("\n")
error_info += "Context:\n"
start = max(0, self.lineno - 2)
end = min(len(lines), self.lineno + 2)
for i in range(start, end):
if i == self.lineno - 1:
error_info += f"{i + 1}: --> {lines[i]}\n"
if hasattr(self.original_error, "column"):
error_info += (
" "
+ " " * (self.original_error.column - 1)
+ "^\n"
)
else:
error_info += f"{i + 1}: {lines[i]}\n"
return error_info
raise DetailedTemplateSyntaxError(e) from e
@staticmethod
def _fake_get_dtype(fake_out):
def _fake_get_dtype(
fake_outs: Union[list[Buffer], Buffer]
) -> Callable[[str], torch.dtype]:
_get_dtype_real = V.graph.get_dtype
if isinstance(fake_outs, (list, tuple)):
lookup = {buf.get_name(): buf.get_dtype() for buf in fake_outs}
else:
lookup = {fake_outs.get_name(): fake_outs.get_dtype()}
def get_dtype(name):
if name == fake_out.get_name():
return fake_out.get_dtype()
def get_dtype(name: str) -> torch.dtype:
result = lookup.get(name)
if result is not None:
return result
return _get_dtype_real(name)
return get_dtype
def __init__(self, name: str):
def __init__(self, name: str) -> None:
self.name = name
def maybe_append_choice(self, choices, **kwargs):
def maybe_append_choice(
self, choices: list[Any], **kwargs: Any
) -> Optional[NotImplementedError]:
"""
Maybe generates a new ChoiceCaller and appends it into existing choices.
Returns None if success, otherwise returns the error.
@ -2413,7 +2444,7 @@ class KernelTemplate:
except NotImplementedError as e:
return e
def generate(self, **kwargs) -> torch._inductor.ir.ChoiceCaller:
def generate(self, **kwargs: Any) -> ChoiceCaller:
"""
Generates a ChoiceCaller instance from the given arguments.
"""

View File

@ -188,18 +188,6 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
for idx in range(gemm_grouped_num)
]
@staticmethod
def _fake_get_dtype(fake_outs: list[ir.Buffer]) -> Callable[[str], torch.dtype]:
_get_dtype_real = V.graph.get_dtype
def get_dtype(name: str) -> torch.dtype:
for fake_out in fake_outs:
if name == fake_out.get_name():
return fake_out.get_dtype()
return _get_dtype_real(name)
return get_dtype
@classmethod
def add_choices(
cls,

View File

@ -1,4 +1,6 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import itertools
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
from typing_extensions import Protocol
@ -29,7 +31,7 @@ ReductionType = Literal[
]
def _arg_str(a) -> str:
def _arg_str(a: object) -> str:
if isinstance(a, sympy.Expr):
return sympy_str(a)
return str(a)
@ -44,7 +46,7 @@ class OpsHandler(Protocol[T]):
"""
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
as well as the contract for op handlers. The type T signifies the domain
of the abstract analysis AKA what all of the functions return / take as arguments
of the abstract analysis AKA what all the functions return / take as arguments
anywhere compute occurs.
While these operators are typically dtype polymorphic (e.g., you can use mul
@ -247,7 +249,7 @@ class OpsHandler(Protocol[T]):
# TODO: in practice, this seems to actually return None, but not returning
# a T makes common __getattr__ idioms not type correctly. Figure out if
# this should be returning something.
def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
def store_reduction(self, name: str, index: sympy.Expr, value: T) -> None:
"""
Store the fully accumulated result of 'reduction' to the memory
location 'name' offset by 'expr'.
@ -1046,7 +1048,7 @@ class ExtractConstantsHandler(NoopHandler):
def __init__(self, device):
self.device = device
def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant:
from torch._inductor import ir
return ir.Constant(value=value, dtype=dtype, device=self.device)