[inductor] Refactor op handlers part 5 (#146257)

This makes OpHandler just a normal class using inheritance, and removes typing workarounds needed because it wasn't

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146257
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254, #146255
This commit is contained in:
Jason Ansel
2025-02-07 13:32:54 -08:00
committed by PyTorch MergeBot
parent 403db2faee
commit 06604c4ec1
16 changed files with 332 additions and 455 deletions

View File

@ -5,19 +5,23 @@ from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides
from torch._inductor.codegen.halide import HalideOverrides
from torch._inductor.codegen.mps import MetalOverrides
from torch._inductor.codegen.triton import TritonKernelOverrides
from torch._inductor.ops_handler import list_ops, OP_NAMES
from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler
from torch._inductor.test_case import TestCase
class TestOpCompleteness(TestCase):
def verify_ops_handler_completeness(self, handler):
op_names = list_ops(handler)
if OP_NAMES == op_names:
return
print(f"Missing ops: {OP_NAMES - op_names}")
print(f"Extra ops: {op_names - OP_NAMES}")
self.assertEqual(", ".join(OP_NAMES - op_names), "")
self.assertEqual(", ".join(op_names - OP_NAMES), "")
for op in OP_NAMES:
self.assertIsNot(
getattr(handler, op),
getattr(OpsHandler, op),
msg=f"{handler} must implement {op}",
)
extra_ops = list_ops(handler) - OP_NAMES
if extra_ops:
raise AssertionError(
f"{handler} has an extra ops: {extra_ops}, add them to OpHandler class or prefix with `_`"
)
def test_triton_overrides(self):
self.verify_ops_handler_completeness(TritonKernelOverrides)

View File

@ -34,7 +34,8 @@ from torch.utils._sympy.reference import (
)
from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.value_ranges import ValueRanges
from torch._inductor.bounds import ValueRangeAnalysis
UNARY_OPS = [

View File

@ -9,7 +9,7 @@ from torch._inductor import config
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
from torch._inductor.index_propagation import SymPyOps, TypedExpr
from .ops_handler import DefaultHandler, OpsHandler
from .ops_handler import DefaultHandler
from .virtualized import StoreMode, V
@ -32,27 +32,22 @@ class PreservesZeros(SymPyOps, DefaultHandler):
self.store_preserves_zeros: Optional[bool] = None
self.dtype_prop = DtypePropagationOpsHandler()
@staticmethod
def load(name: str, index: sympy.Expr) -> TypedExpr:
def load(self, name: str, index: sympy.Expr) -> TypedExpr:
# In prologue fusion, all loads get broadcasted
dtype = V.get_ops_handler().dtype_prop.load(name, index)
dtype = self.dtype_prop.load(name, index)
return TypedExpr(
sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype
)
@staticmethod
def store(
name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None
self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None
) -> None:
self = V.get_ops_handler()
assert isinstance(self, PreservesZeros)
# should only have a single store in prologue
assert self.store_preserves_zeros is None
self.store_preserves_zeros = value.is_constant() and value.expr == 0
@staticmethod
def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
self = V.get_ops_handler()
def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr:
return construct_symbol(next(self.count), torch.int32)
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
@ -65,12 +60,6 @@ class PreservesZeros(SymPyOps, DefaultHandler):
return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
if TYPE_CHECKING:
class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]):
pass
def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool:
"""
Does this prologue preserve zero masks
@ -100,9 +89,8 @@ class RecordLowPrecisionOps(DefaultHandler):
"constant",
)
@staticmethod
def load(name: str, index: sympy.Expr) -> DTypeContainer:
return DTypeContainer(V.get_ops_handler().dtype_prop.load(name, index))
def load(self, name: str, index: sympy.Expr) -> DTypeContainer:
return DTypeContainer(self.dtype_prop.load(name, index))
@staticmethod
def store(
@ -131,12 +119,6 @@ class RecordLowPrecisionOps(DefaultHandler):
return out
if TYPE_CHECKING:
class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]):
pass
def low_prec_float(dtype: torch.dtype) -> bool:
return dtype.is_floating_point and dtype.itemsize < 4

View File

@ -1,14 +1,22 @@
import logging
import operator
from functools import partial
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union
import sympy
from sympy import Expr
import torch
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.value_ranges import (
bound_sympy,
SymPyValueRangeAnalysis,
ValueRanges,
)
from ..utils._sympy.functions import PowByNatural
from ..utils._sympy.numbers import int_oo
from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
from .ops_handler import DefaultHandler, ReductionType, StoreMode
from .utils import cache_on_self, dominated_nodes
from .virtualized import V
@ -139,3 +147,113 @@ class BoundVars:
# assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
self.replacement_vals[name] = bound
return bound
class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler):
def __init__(self) -> None:
self.name = "ValueRangeAnalysis"
boolean_operators = (
"xor",
"logical_and",
"logical_or",
"logical_not",
)
for op in boolean_operators:
setattr(self, op, self.bool_handler)
@staticmethod
def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]:
# just assuming bools can have both values
return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
# many ops are unlikely to show up in optimizable indexing compute,
# so we dont have full coverage
return ValueRanges.unknown()
def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]:
return ValueRanges.unknown()
def store(
self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None
) -> None:
return
def reduction(
self,
dtype: torch.dtype,
src_dtype: torch.dtype,
reduction_type: ReductionType,
value: Any,
) -> ValueRanges[Any]:
return ValueRanges.unknown()
@classmethod
def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]:
assert isinstance(index, ValueRanges)
return cls.to_dtype(index, dtype)
@staticmethod
def to_dtype(
x: Any,
dtype: torch.dtype,
src_dtype: Optional[torch.dtype] = None,
use_compute_types: bool = True,
) -> ValueRanges[Any]:
x = ValueRanges.wrap(x)
if dtype == torch.bool:
if x.is_singleton():
return ValueRanges.wrap(x.lower != 0)
elif x.is_bool:
return x
elif 0 not in x:
return ValueRanges.wrap(sympy.true)
else:
return ValueRanges(sympy.false, sympy.true)
def cast(x: Any, dtype: torch.dtype) -> sympy.Expr:
# dtype is int or float
if dtype.is_floating_point:
return sympy.Float(x)
else:
if x in (int_oo, -int_oo):
return x
try:
return sympy.Integer(x)
except TypeError:
# inf cannot be cast to Integer
return x
if x.is_bool:
if x.is_singleton():
val = 1 if x.lower else 0
return ValueRanges.wrap(cast(val, dtype))
else:
return ValueRanges(cast(0, dtype), cast(1, dtype))
else:
# int to float or float to int
return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
@staticmethod
def square(x: Any) -> ValueRanges[Any]:
return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
@staticmethod
def neg(x: Any) -> ValueRanges[Any]:
return ValueRanges.decreasing_map(x, operator.neg)
# TODO: this is slightly inaccurate because truncdiv operates at integer
# precision, but we're going through float truediv which means we can
# potentially lose precision on the bounds
@classmethod
def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]:
x = cls.truediv(a, b)
if x == ValueRanges.unknown():
return x
return cls.trunc(x)
@classmethod
def sub(cls, a: Any, b: Any) -> ValueRanges[Any]:
return cls.add(a, cls.neg(b))

View File

@ -37,11 +37,11 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .. import config, metrics
from ..dtype_propagation import DtypePropagationOpsHandler
from ..ops_handler import BasicMathOps, DefaultHandler
from ..ops_handler import BasicMathOpsMixin, DefaultHandler
from ..utils import (
boolean_ops,
DeferredLineBase,
@ -764,7 +764,7 @@ def _all_in_parens(string: str) -> bool:
return True
class OpOverrides(BasicMathOps, OpDecompositions):
class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
@staticmethod
def paren(string: OpVarT) -> OpVarT:
if (
@ -1235,12 +1235,6 @@ pointwise_overrides_data: dict[str, OverridesData] = dict(
)
if TYPE_CHECKING:
class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class DeferredLine(DeferredLineBase):
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
@ -2268,6 +2262,8 @@ class CSEProxy(DefaultHandler):
def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
super().__init__()
from ..bounds import ValueRangeAnalysis
self.vr_analysis = ValueRangeAnalysis()
self.kernel = kernel
self.parent_handler = parent_handler
@ -2338,6 +2334,7 @@ class CSEProxy(DefaultHandler):
If the variable comes from an FX node, we forward the bound we have already computed
Else, if the variable when codegen'ing another op, we try to compute its bounds
"""
from ..bounds import ValueRangeAnalysis
from ..select_algorithm import TritonTemplateKernel
if isinstance(V.kernel, TritonTemplateKernel):
@ -2575,9 +2572,3 @@ class CSEProxy(DefaultHandler):
sorter,
sorter_indices,
)
if TYPE_CHECKING:
class _typecheck_CSEProxy(CSEProxy, OpsHandler[CSEVariable]):
pass

View File

@ -33,7 +33,7 @@ from ..utils import (
sympy_index_symbol,
sympy_subs,
)
from ..virtualized import _ops as ops, OpsHandler, V
from ..virtualized import _ops as ops, V
from .common import (
BackendFeature,
CSEVariable,
@ -563,12 +563,6 @@ class HalideOverrides(OpOverrides):
HalideOverrides._initialize_pointwise_overrides("halide")
if TYPE_CHECKING:
class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class HalideCSEVariable(CSEVariable):
undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")

View File

@ -29,7 +29,7 @@ if TYPE_CHECKING:
import sympy
from ..ops_handler import OpsHandler, ReductionType, StoreMode
from ..ops_handler import ReductionType, StoreMode
from ..scheduler import Scheduler, SchedulerNode
from .common import OpVarT
@ -367,12 +367,6 @@ class MetalOverrides(OpOverrides):
MetalOverrides._initialize_pointwise_overrides("mps")
if TYPE_CHECKING:
class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]):
pass # mypy will error if we got any of the signatures wrong
class MetalKernel(SIMDKernel):
overrides = MetalOverrides # type: ignore[assignment]
suffix = ";"

View File

@ -61,7 +61,7 @@ from ..utils import (
triton_version_uses_attrs_dict,
upcast_compute_type,
)
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
from ..virtualized import _ops as ops, ReductionType, StoreMode, V
from ..wrapper_benchmark import get_kernel_category_by_source_code
from .block_analysis import BlockPatternMatcher
from .common import (
@ -1428,12 +1428,6 @@ class TritonKernelOverrides(TritonOverrides):
return (mantissa, exponent)
if TYPE_CHECKING:
class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class HelperFunctions:
"""An ordered set of helper functions."""

View File

@ -4,17 +4,7 @@ import itertools
import logging
import re
from collections.abc import Sequence
from typing import (
Any,
Callable,
Iterable,
List,
Optional,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
from unittest.mock import patch
import sympy
@ -34,7 +24,7 @@ from .utils import (
sympy_subs,
VarRanges,
)
from .virtualized import OpsHandler, ReductionType, V
from .virtualized import ReductionType, V
T = TypeVar("T")
@ -799,14 +789,6 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
body()
if TYPE_CHECKING:
class _typecheck_FreeUnbackedSymbolsOpsHandler(
FreeUnbackedSymbolsOpsHandler, OpsHandler[None]
):
pass
def extract_free_unbacked_symbols(
fn: Callable[..., Any],
index: Sequence[sympy.Expr],

View File

@ -23,7 +23,7 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union
from typing import Any, Literal, Optional, overload, Union
from typing_extensions import TypeAlias
import sympy
@ -33,7 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .ops_handler import DefaultHandler, OpsHandler
from .ops_handler import DefaultHandler
from .sizevars import evaluate_expr
from .utils import generate_assert
from .virtualized import V
@ -370,9 +370,3 @@ class IndexPropagation(DefaultHandler):
"indirect_indexing", (index, size, check, wrap_neg), {}
).value
return indirect_var
if TYPE_CHECKING:
class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]):
pass

View File

@ -17,7 +17,7 @@ from torch.utils._sympy.symbol import SymT
from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .ops_handler import DefaultHandler
from .ops_handler import DefaultHandler, OpsHandler
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V
@ -655,7 +655,7 @@ class LoopBodyBlock:
class CountOps(DefaultHandler):
def __init__(self, inner: Any, counts: collections.Counter[str]):
def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]):
self._inner = inner
self._counts = counts

View File

@ -4,17 +4,7 @@ from __future__ import annotations
import itertools
import re
import warnings
from typing import (
Any,
Callable,
Literal,
NamedTuple,
Optional,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import Protocol
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
from unittest.mock import patch
import sympy
@ -48,12 +38,8 @@ def _arg_str(a: object) -> str:
return str(a)
# NB: This is not done as a parent class, because our ops handlers
# implementations make heavy use of __getattr__ magic, and pre-existing
# stubs for methods would interfere with this mechanism.
#
# See OpDecompositions for superclass that desugars operations like reciprocal/square.
class OpsHandler(Protocol[T]):
class OpsHandler(Generic[T]):
"""
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
as well as the contract for op handlers. The type T signifies the domain
@ -77,49 +63,30 @@ class OpsHandler(Protocol[T]):
ops handlers.
Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides),
which means you will get type errors if you subclass OpsHandler since mypy doesn't know
about the methods added via metaprogramming and thinks the class is still abstract.
Instead, you should add a block like:
if TYPE_CHECKING:
class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
Which will check the signatures of non-meta-programmed methods and gives decent error messages.
Some older parts of the code use a pattern like:
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
return h
This pattern only works if the class defines a __getattr__ method, which we are moving away from.
Additionally, this pattern generates horrible error messages if the signatures are wrong.
It gives zero information about what the problem is, which makes the pattern harmful.
Instead of that, we have tests in test/inductor/test_op_completeness.py which check that all
operators are implemented after all the metaprogramming has run.
which means you will not get type errors for those methods. We have tests in
test/inductor/test_op_completeness.py which check that all operators are implemented after
all the metaprogramming has run.
"""
def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
"""Produces a scalar constant of type dtype."""
...
raise NotImplementedError
def load_seed(self, name: str, offset: T) -> T:
"""Computes inductor_prims.lookup_seed."""
...
raise NotImplementedError
def rand(self, seed: T, offset: T) -> T:
"""Computes inductor_prims.random with mode="rand". offset has dtype int32."""
...
raise NotImplementedError
def randn(self, seed: T, offset: T) -> T:
"""Computes inductor_prims.random with mode="randn". offset has dtype int32."""
...
raise NotImplementedError
def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
"""Computes inductor_prims.randint. offset has dtype int32."""
...
raise NotImplementedError
def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
"""
@ -133,13 +100,13 @@ class OpsHandler(Protocol[T]):
Contrast this with ops.where, which can multiplex between two values
that have been unconditionally computed.
"""
...
raise NotImplementedError
def where(self, condition: T, input: T, other: T) -> T:
"""
Computes torch.where: when condition is true, return input; otherwise return other.
"""
...
raise NotImplementedError
def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
"""
@ -147,7 +114,7 @@ class OpsHandler(Protocol[T]):
an indexing expression, thus the name; however, it can also be used in
non-indexing situations.
"""
...
raise NotImplementedError
def to_dtype(
self,
@ -160,7 +127,7 @@ class OpsHandler(Protocol[T]):
Convert x to dtype. src_dtype can be optionally set to specify what the original
dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
"""
...
raise NotImplementedError
def trunc_to_int(self, x: T, dtype: torch.dtype) -> T:
"""
@ -174,38 +141,38 @@ class OpsHandler(Protocol[T]):
int64 depending on if we've shown that all the indexing operations can
be done in int32.
"""
...
raise NotImplementedError
def ceil_to_int(self, x: T, dtype: torch.dtype) -> T:
"""
Convert x to dtype with ceiling semantics. See also trunc_to_int.
"""
...
raise NotImplementedError
def floor_to_int(self, x: T, dtype: torch.dtype) -> T:
"""
Convert x to dtype with ceiling semantics. See also trunc_to_int.
"""
...
raise NotImplementedError
def round_to_int(self, x: T, dtype: torch.dtype) -> T:
"""
Convert x to dtype with round-to-even semantics. See also trunc_to_int.
"""
...
raise NotImplementedError
def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
"""
Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
src_dtype must be the original type of x.
"""
...
raise NotImplementedError
def identity(self, x: T) -> T:
"""
Returns x as is. This is used to trigger CSE.
"""
...
raise NotImplementedError
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# These operations are only available in a "kernel" context. Check
@ -227,13 +194,13 @@ class OpsHandler(Protocol[T]):
NB: This is typically mandatory to implement for any analysis, because you
MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
"""
...
raise NotImplementedError
def load(self, name: str, index: sympy.Expr) -> T:
"""
Load from the memory location 'name', offset by some indexing expression 'index'.
"""
...
raise NotImplementedError
def store(
self,
@ -246,7 +213,7 @@ class OpsHandler(Protocol[T]):
Store 'value' to the memory location 'name' offset by 'expr'. If
specified, 'mode' can require the store to be an atomic addition.
"""
...
raise NotImplementedError
# TODO: Better explain how the "collective" semantics of these ops;
# remember that the input value is a scalar, you can't reduce on it in the
@ -268,7 +235,7 @@ class OpsHandler(Protocol[T]):
function returns multiple outputs; consult reduction_num_outputs to
determine the amount in metaprogramming applications.
"""
...
raise NotImplementedError
# TODO: in practice, this seems to actually return None, but not returning
# a T makes common __getattr__ idioms not type correctly. Figure out if
@ -278,7 +245,7 @@ class OpsHandler(Protocol[T]):
Store the fully accumulated result of 'reduction' to the memory
location 'name' offset by 'expr'.
"""
...
raise NotImplementedError
def scan(
self,
@ -290,7 +257,7 @@ class OpsHandler(Protocol[T]):
Perform an associative scan on 'value'.
"""
# TODO: Improve the description with some pseudocode
...
raise NotImplementedError
def sort(
self,
@ -302,7 +269,7 @@ class OpsHandler(Protocol[T]):
"""
Sort values along the reduction dimension.
"""
...
raise NotImplementedError
def bucketize(
self,
@ -315,231 +282,231 @@ class OpsHandler(Protocol[T]):
sorter_indices: Optional[T] = None,
) -> T:
# See [Note: Inductor bucketize op]
...
raise NotImplementedError
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The following ops have semantics that correspond exactly to the torch
# operation with the same corresponding name.
def abs(self, x0: T) -> T:
...
raise NotImplementedError
def exp(self, x0: T) -> T:
...
raise NotImplementedError
def exp2(self, x0: T) -> T:
...
raise NotImplementedError
def expm1(self, x0: T) -> T:
...
raise NotImplementedError
def sqrt(self, x0: T) -> T:
...
raise NotImplementedError
def relu(self, x0: T) -> T:
...
raise NotImplementedError
def minimum(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def maximum(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def cos(self, x0: T) -> T:
...
raise NotImplementedError
def sin(self, x0: T) -> T:
...
raise NotImplementedError
def lgamma(self, x0: T) -> T:
...
raise NotImplementedError
def erf(self, x0: T) -> T:
...
raise NotImplementedError
def cosh(self, x0: T) -> T:
...
raise NotImplementedError
def sinh(self, x0: T) -> T:
...
raise NotImplementedError
def acos(self, x0: T) -> T:
...
raise NotImplementedError
def acosh(self, x0: T) -> T:
...
raise NotImplementedError
def asin(self, x0: T) -> T:
...
raise NotImplementedError
def asinh(self, x0: T) -> T:
...
raise NotImplementedError
def atan2(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def atan(self, x0: T) -> T:
...
raise NotImplementedError
def atanh(self, x0: T) -> T:
...
raise NotImplementedError
def copysign(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def erfc(self, x0: T) -> T:
...
raise NotImplementedError
def erfinv(self, x0: T) -> T:
...
raise NotImplementedError
def frexp(self, x0: T):
...
raise NotImplementedError
def hypot(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def log10(self, x0: T) -> T:
...
raise NotImplementedError
def log2(self, x0: T) -> T:
...
raise NotImplementedError
def nextafter(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def logical_and(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def logical_not(self, x0: T) -> T:
...
raise NotImplementedError
def logical_or(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def logical_xor(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def bitwise_and(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def bitwise_not(self, x0: T) -> T:
...
raise NotImplementedError
def bitwise_or(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def bitwise_xor(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def bitwise_left_shift(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def bitwise_right_shift(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def rsqrt(self, x0: T) -> T:
...
raise NotImplementedError
def log1p(self, x0: T) -> T:
...
raise NotImplementedError
def tan(self, x0: T) -> T:
...
raise NotImplementedError
def tanh(self, x0: T) -> T:
...
raise NotImplementedError
def sigmoid(self, x0: T) -> T:
...
raise NotImplementedError
def signbit(self, x0: T) -> T:
...
raise NotImplementedError
def fmod(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def log(self, x0: T) -> T:
...
raise NotImplementedError
def isinf(self, x0: T) -> T:
...
raise NotImplementedError
def isnan(self, x0: T) -> T:
...
raise NotImplementedError
# NB: this returns a float, like the torch operation
# This rounds half to even to break ties
def round(self, x0: T) -> T:
...
raise NotImplementedError
# NB: this returns a float, like the torch operation
def floor(self, x0: T) -> T:
...
raise NotImplementedError
def sign(self, x0: T) -> T:
...
raise NotImplementedError
# NB: this returns a float, like the torch operation
def trunc(self, x0: T) -> T:
...
raise NotImplementedError
# NB: this returns a float, like the torch operation
def ceil(self, x0: T) -> T:
...
raise NotImplementedError
def neg(self, x0: T) -> T:
...
raise NotImplementedError
def reciprocal(self, x0: T) -> T:
...
raise NotImplementedError
def eq(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def ne(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def lt(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def gt(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def le(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def ge(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def add(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def sub(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def mul(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
# NB: this returns a float, like the torch operation
def pow(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def and_(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def or_(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def xor(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
# These are metaprogrammed by MockHandler._init_cls
def lshift(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
def rshift(self, x0: T, x1: T) -> T:
...
raise NotImplementedError
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# These are "special" operators. These only exist if the target
@ -547,124 +514,124 @@ class OpsHandler(Protocol[T]):
# pointwise_overrides_data.
def airy_ai(self, x: T) -> T:
...
raise NotImplementedError
def bessel_j0(self, x: T) -> T:
...
raise NotImplementedError
def bessel_j1(self, x: T) -> T:
...
raise NotImplementedError
def bessel_y0(self, x: T) -> T:
...
raise NotImplementedError
def bessel_y1(self, x: T) -> T:
...
raise NotImplementedError
def digamma(self, x: T) -> T:
...
raise NotImplementedError
def erfcx(self, x: T) -> T:
...
raise NotImplementedError
def fma(self, x: T, y: T, z: T) -> T:
...
raise NotImplementedError
def igamma(self, x: T, y: T) -> T:
...
raise NotImplementedError
def igammac(self, x: T, y: T) -> T:
...
raise NotImplementedError
def gammainc(self, x: T, y: T) -> T:
...
raise NotImplementedError
def gammaincc(self, x: T, y: T) -> T:
...
raise NotImplementedError
def i0(self, x: T) -> T:
...
raise NotImplementedError
def i0e(self, x: T) -> T:
...
raise NotImplementedError
def i1(self, x: T) -> T:
...
raise NotImplementedError
def i1e(self, x: T) -> T:
...
raise NotImplementedError
def log_ndtr(self, x: T) -> T:
...
raise NotImplementedError
def modified_bessel_i0(self, x: T) -> T:
...
raise NotImplementedError
def modified_bessel_i1(self, x: T) -> T:
...
raise NotImplementedError
def modified_bessel_k0(self, x: T) -> T:
...
raise NotImplementedError
def modified_bessel_k1(self, x: T) -> T:
...
raise NotImplementedError
def ndtr(self, x: T) -> T:
...
raise NotImplementedError
def ndtri(self, x: T) -> T:
...
raise NotImplementedError
def polygamma(self, x: T, y: T) -> T:
...
raise NotImplementedError
def scaled_modified_bessel_k0(self, x: T) -> T:
...
raise NotImplementedError
def scaled_modified_bessel_k1(self, x: T) -> T:
...
raise NotImplementedError
def spherical_bessel_j0(self, x: T) -> T:
...
raise NotImplementedError
def zeta(self, x: T, y: T) -> T:
...
raise NotImplementedError
def chebyshev_polynomial_t(self, x: T, y: T) -> T:
...
raise NotImplementedError
def chebyshev_polynomial_u(self, x: T, y: T) -> T:
...
raise NotImplementedError
def chebyshev_polynomial_v(self, x: T, y: T) -> T:
...
raise NotImplementedError
def chebyshev_polynomial_w(self, x: T, y: T) -> T:
...
raise NotImplementedError
def legendre_polynomial_p(self, x: T, y: T) -> T:
...
raise NotImplementedError
def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T:
...
raise NotImplementedError
def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T:
...
raise NotImplementedError
def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T:
...
raise NotImplementedError
def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T:
...
raise NotImplementedError
def hermite_polynomial_h(self, x: T, y: T) -> T:
...
raise NotImplementedError
def hermite_polynomial_he(self, x: T, y: T) -> T:
...
raise NotImplementedError
def laguerre_polynomial_l(self, x: T, y: T) -> T:
...
raise NotImplementedError
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# These operators are a bit special, because they are conventionally
@ -675,42 +642,42 @@ class OpsHandler(Protocol[T]):
"""C-style trunc division between integers only. Computes the true
division of two numbers and rounds the result to zero.
"""
...
raise NotImplementedError
def floordiv(self, x0: T, x1: T) -> T:
"""Python-style floor division between integers only. Computes the
true division of two numbers and floors the result. If you want
floor division for floats, do regular truediv and floor the result.
"""
...
raise NotImplementedError
def truediv(self, x0: T, x1: T) -> T:
"""True division between floats. Integer inputs are NOT valid. To
do Python-style (int, int) -> float division, use int_truediv"""
...
raise NotImplementedError
def int_truediv(self, x0: T, x1: T) -> T:
"""True division between integers. This is NOT the same as promoting
to float and doing integer division, there is a bespoke algorithm for
doing the division in higher precision than the above.
"""
...
raise NotImplementedError
def mod(self, x0: T, x1: T) -> T:
"""C-style modulus, take sign from LHS (x0)."""
...
raise NotImplementedError
def remainder(self, x0: T, x1: T) -> T:
"""Python-style modulus, take sign from RHS (x1)."""
...
raise NotImplementedError
def square(self, x0: T) -> T:
...
raise NotImplementedError
def check_bounds(
self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
) -> None:
...
raise NotImplementedError
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# In CUDA, optimized implementations of other mathematical operations are
@ -726,25 +693,25 @@ class OpsHandler(Protocol[T]):
# for many analyses it's not conveniently available.)
def libdevice_abs(self, x0: T) -> T:
...
raise NotImplementedError
def libdevice_exp(self, x0: T) -> T:
...
raise NotImplementedError
def libdevice_sqrt(self, x0: T) -> T:
...
raise NotImplementedError
def libdevice_cos(self, x0: T) -> T:
...
raise NotImplementedError
def libdevice_sin(self, x0: T) -> T:
...
raise NotImplementedError
def libdevice_sigmoid(self, x0: T) -> T:
...
raise NotImplementedError
def libdevice_log(self, x0: T) -> T:
...
raise NotImplementedError
# halide-only
def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T:
@ -760,15 +727,15 @@ class OpsHandler(Protocol[T]):
is_pure: bool = True,
pack: int = 1,
) -> T:
...
raise NotImplementedError
def output(self, x0: T) -> None:
"""This is a fake op used in analysis but not codegen"""
...
raise NotImplementedError
def placeholder(self, index: int) -> T:
"""This is a fake op used in analysis but not codegen"""
...
raise NotImplementedError
_ignore_op_re = re.compile(r"_.*|paren").fullmatch
@ -781,7 +748,7 @@ def list_ops(cls: type[Any]):
OP_NAMES = list_ops(OpsHandler)
class DefaultHandler:
class DefaultHandler(OpsHandler[Any]):
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
"""
Default implementation for all ops. Override in a subclass to
@ -850,13 +817,7 @@ class NoopHandler(DefaultHandler):
return sympy.S.Zero
if TYPE_CHECKING:
class _typecheck_NoopHandler(NoopHandler, OpsHandler[None]):
pass # mypy will error if we got any of the signatures wrong
class BasicMathOps:
class BasicMathOpsMixin:
@staticmethod
def add(a, b):
return f"{a} + {b}"
@ -935,7 +896,7 @@ class BasicMathOps:
return f"-{a}"
class MockHandler(BasicMathOps, DefaultHandler):
class MockHandler(BasicMathOpsMixin, DefaultHandler):
name = "MockHandler"
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
@ -971,14 +932,8 @@ class MockHandler(BasicMathOps, DefaultHandler):
return sympy_index_symbol(str(index_var))
if TYPE_CHECKING:
class _typecheck_MockHandler(MockHandler, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class KernelFormatterHandler(DefaultHandler):
def __init__(self, parent_handler):
def __init__(self, parent_handler: OpsHandler[Any]):
self.parent_handler = parent_handler
self._output = IndentedBuffer(1)
self.var_counter = itertools.count()
@ -1042,14 +997,8 @@ class KernelFormatterHandler(DefaultHandler):
return self._output.getvalue()
if TYPE_CHECKING:
class _typecheck_KernelFormatterHandler(KernelFormatterHandler, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class WrapperHandler(DefaultHandler):
def __init__(self, inner: Any):
def __init__(self, inner: OpsHandler[Any]):
self._inner = inner
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
@ -1074,11 +1023,11 @@ class OpCountResult(NamedTuple):
class OpCounterCSE(DefaultHandler):
"""Shim to count how many ops are used"""
def __init__(self, inner):
def __init__(self, inner: OpsHandler[Any]):
super().__init__()
self.parent_handler = inner
self.op_count = 0
self.var_names = {}
self.var_names: dict[str, str] = {}
self._used_ops: OrderedSet[str] = OrderedSet()
self._read_names: list[str] = []
self._nontrivial_read_count = 0
@ -1152,26 +1101,16 @@ class OpCounterCSE(DefaultHandler):
)
if TYPE_CHECKING:
class _typecheck_OpCounterCSE(OpCounterCSE, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class ExtractConstantsHandler(NoopHandler):
def __init__(self, device):
def __init__(self, device: Optional[torch.device]):
self.device = device
def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant:
from torch._inductor import ir
return ir.Constant(value=value, dtype=dtype, device=self.device)
if TYPE_CHECKING:
class _typecheck_ExtractConstantsHandler(ExtractConstantsHandler, OpsHandler[Any]):
pass # mypy will error if we got any of the signatures wrong
return ir.Constant(
value=value, dtype=dtype, device=self.device or torch.get_default_device()
)
class SimpleCSEHandler(WrapperHandler):
@ -1204,9 +1143,3 @@ class SimpleCSEHandler(WrapperHandler):
val = getattr(self._inner, name)(*args, **kwargs)
self.cse_cache[key] = val
return val
if TYPE_CHECKING:
class _typecheck_SimpleCSEHandler(SimpleCSEHandler, OpsHandler[Any]):
pass # mypy will error if we got any of the signatures wrong

View File

@ -564,10 +564,6 @@ class CompiledFxGraph(OutputCode):
return artifact_path
def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
return h
@dataclasses.dataclass
class CompiledAOTI(OutputCode):
"""
@ -591,10 +587,6 @@ class CompiledAOTI(OutputCode):
pass
def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
return h
@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
gm: Any = None

View File

@ -149,8 +149,8 @@ class TracingOpsHandler(WrapperHandler):
def placeholder(self, idx: int) -> torch.fx.Proxy:
return self.placeholders[idx]
def output(self, *args: tuple[object]) -> torch.fx.Node:
return self.tracer.create_node(
def output(self, *args: tuple[object]) -> None:
self.tracer.create_node(
"output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {}
)

View File

@ -306,9 +306,7 @@ class OpsWrapper(DefaultHandler):
return _ops.indirect_indexing(index, size, check, wrap_neg)
# we lie about the type of ops so the rest of the codebase typecheck properly
# DefaultHandler implements the OpsHandler protocol via metaprogramming
ops = cast(OpsHandler[Any], OpsWrapper())
ops: OpsHandler[Any] = OpsWrapper()
class _V:
@ -316,8 +314,10 @@ class _V:
KernelFormatterHandler = KernelFormatterHandler
WrapperHandler = WrapperHandler
set_ops_handler: Callable[[Any], Any] = _ops._set_handler
get_ops_handler: Callable[[], Any] = _ops._get_handler
set_ops_handler: Callable[
[OpsHandler[Any]], AbstractContextManager[None]
] = _ops._set_handler
get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
get_real_inputs: Callable[[], Any] = _real_inputs._get_handler

View File

@ -49,7 +49,7 @@ from .numbers import int_oo, IntInfinity, NegativeIntInfinity
log = logging.getLogger(__name__)
__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"]
__all__ = ["ValueRanges", "bound_sympy"]
_T = TypeVar("_T", sympy.Expr, SympyBoolean)
@ -1004,108 +1004,6 @@ class SymPyValueRangeAnalysis:
return ValueRanges.increasing_map(x, TruncToFloat)
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def __init__(self) -> None:
self.name = "ValueRangeAnalysis"
boolean_operators = (
"xor",
"logical_and",
"logical_or",
"logical_not",
)
for op in boolean_operators:
setattr(self, op, self.bool_handler)
@staticmethod
def bool_handler(*args, **kwargs):
# just assuming bools can have both values
return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
@staticmethod
def default_handler(*args, **kwargs):
# many ops are unlikely to show up in optimizable indexing compute,
# so we dont have full coverage
return ValueRanges.unknown()
def load(self, name: str, index: sympy.Expr):
return ValueRanges.unknown()
def store(self, name, index, value, mode=None):
return
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
return ValueRanges.unknown()
@classmethod
def index_expr(cls, index, dtype):
assert isinstance(index, ValueRanges)
return cls.to_dtype(index, dtype)
@staticmethod
def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
x = ValueRanges.wrap(x)
if dtype == torch.bool:
if x.is_singleton():
return ValueRanges.wrap(x.lower != 0)
elif x.is_bool:
return x
elif 0 not in x:
return ValueRanges.wrap(sympy.true)
else:
return ValueRanges(sympy.false, sympy.true)
def cast(x, dtype):
# dtype is int or float
if dtype.is_floating_point:
return sympy.Float(x)
else:
if x in (int_oo, -int_oo):
return x
try:
return sympy.Integer(x)
except TypeError:
# inf cannot be cast to Integer
return x
if x.is_bool:
if x.is_singleton():
val = 1 if x.lower else 0
return ValueRanges.wrap(cast(val, dtype))
else:
return ValueRanges(cast(0, dtype), cast(1, dtype))
else:
# int to float or float to int
return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
@staticmethod
def square(x):
return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
@staticmethod
def neg(x):
return ValueRanges.decreasing_map(x, operator.neg)
# TODO: this is slightly inaccurate because truncdiv operates at integer
# precision, but we're going through float truediv which means we can
# potentially lose precision on the bounds
@classmethod
def truncdiv(cls, a, b):
x = cls.truediv(a, b)
if x == ValueRanges.unknown():
return x
return cls.trunc(x)
@classmethod
def sub(cls, a, b):
return cls.add(a, cls.neg(b))
def __getattr__(self, name):
log.debug("unhandled ValueRange op %s", name)
return self.default_handler
def bound_sympy(
expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
) -> ValueRanges: