mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[inductor] Refactor op handlers part 5 (#146257)
This makes OpHandler just a normal class using inheritance, and removes typing workarounds needed because it wasn't Pull Request resolved: https://github.com/pytorch/pytorch/pull/146257 Approved by: https://github.com/shunting314 ghstack dependencies: #146252, #146254, #146255
This commit is contained in:
committed by
PyTorch MergeBot
parent
403db2faee
commit
06604c4ec1
@ -5,19 +5,23 @@ from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides
|
||||
from torch._inductor.codegen.halide import HalideOverrides
|
||||
from torch._inductor.codegen.mps import MetalOverrides
|
||||
from torch._inductor.codegen.triton import TritonKernelOverrides
|
||||
from torch._inductor.ops_handler import list_ops, OP_NAMES
|
||||
from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler
|
||||
from torch._inductor.test_case import TestCase
|
||||
|
||||
|
||||
class TestOpCompleteness(TestCase):
|
||||
def verify_ops_handler_completeness(self, handler):
|
||||
op_names = list_ops(handler)
|
||||
if OP_NAMES == op_names:
|
||||
return
|
||||
print(f"Missing ops: {OP_NAMES - op_names}")
|
||||
print(f"Extra ops: {op_names - OP_NAMES}")
|
||||
self.assertEqual(", ".join(OP_NAMES - op_names), "")
|
||||
self.assertEqual(", ".join(op_names - OP_NAMES), "")
|
||||
for op in OP_NAMES:
|
||||
self.assertIsNot(
|
||||
getattr(handler, op),
|
||||
getattr(OpsHandler, op),
|
||||
msg=f"{handler} must implement {op}",
|
||||
)
|
||||
extra_ops = list_ops(handler) - OP_NAMES
|
||||
if extra_ops:
|
||||
raise AssertionError(
|
||||
f"{handler} has an extra ops: {extra_ops}, add them to OpHandler class or prefix with `_`"
|
||||
)
|
||||
|
||||
def test_triton_overrides(self):
|
||||
self.verify_ops_handler_completeness(TritonKernelOverrides)
|
||||
|
@ -34,7 +34,8 @@ from torch.utils._sympy.reference import (
|
||||
)
|
||||
from torch.utils._sympy.singleton_int import SingletonInt
|
||||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
from torch._inductor.bounds import ValueRangeAnalysis
|
||||
|
||||
|
||||
UNARY_OPS = [
|
||||
|
@ -9,7 +9,7 @@ from torch._inductor import config
|
||||
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
|
||||
from torch._inductor.index_propagation import SymPyOps, TypedExpr
|
||||
|
||||
from .ops_handler import DefaultHandler, OpsHandler
|
||||
from .ops_handler import DefaultHandler
|
||||
from .virtualized import StoreMode, V
|
||||
|
||||
|
||||
@ -32,27 +32,22 @@ class PreservesZeros(SymPyOps, DefaultHandler):
|
||||
self.store_preserves_zeros: Optional[bool] = None
|
||||
self.dtype_prop = DtypePropagationOpsHandler()
|
||||
|
||||
@staticmethod
|
||||
def load(name: str, index: sympy.Expr) -> TypedExpr:
|
||||
def load(self, name: str, index: sympy.Expr) -> TypedExpr:
|
||||
# In prologue fusion, all loads get broadcasted
|
||||
dtype = V.get_ops_handler().dtype_prop.load(name, index)
|
||||
dtype = self.dtype_prop.load(name, index)
|
||||
return TypedExpr(
|
||||
sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def store(
|
||||
name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None
|
||||
self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None
|
||||
) -> None:
|
||||
self = V.get_ops_handler()
|
||||
assert isinstance(self, PreservesZeros)
|
||||
# should only have a single store in prologue
|
||||
assert self.store_preserves_zeros is None
|
||||
self.store_preserves_zeros = value.is_constant() and value.expr == 0
|
||||
|
||||
@staticmethod
|
||||
def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
|
||||
self = V.get_ops_handler()
|
||||
def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr:
|
||||
return construct_symbol(next(self.count), torch.int32)
|
||||
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
@ -65,12 +60,6 @@ class PreservesZeros(SymPyOps, DefaultHandler):
|
||||
return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]):
|
||||
pass
|
||||
|
||||
|
||||
def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool:
|
||||
"""
|
||||
Does this prologue preserve zero masks
|
||||
@ -100,9 +89,8 @@ class RecordLowPrecisionOps(DefaultHandler):
|
||||
"constant",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load(name: str, index: sympy.Expr) -> DTypeContainer:
|
||||
return DTypeContainer(V.get_ops_handler().dtype_prop.load(name, index))
|
||||
def load(self, name: str, index: sympy.Expr) -> DTypeContainer:
|
||||
return DTypeContainer(self.dtype_prop.load(name, index))
|
||||
|
||||
@staticmethod
|
||||
def store(
|
||||
@ -131,12 +119,6 @@ class RecordLowPrecisionOps(DefaultHandler):
|
||||
return out
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]):
|
||||
pass
|
||||
|
||||
|
||||
def low_prec_float(dtype: torch.dtype) -> bool:
|
||||
return dtype.is_floating_point and dtype.itemsize < 4
|
||||
|
||||
|
@ -1,14 +1,22 @@
|
||||
import logging
|
||||
import operator
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import sympy
|
||||
from sympy import Expr
|
||||
|
||||
import torch
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.value_ranges import (
|
||||
bound_sympy,
|
||||
SymPyValueRangeAnalysis,
|
||||
ValueRanges,
|
||||
)
|
||||
|
||||
from ..utils._sympy.functions import PowByNatural
|
||||
from ..utils._sympy.numbers import int_oo
|
||||
from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
|
||||
from .ops_handler import DefaultHandler, ReductionType, StoreMode
|
||||
from .utils import cache_on_self, dominated_nodes
|
||||
from .virtualized import V
|
||||
|
||||
@ -139,3 +147,113 @@ class BoundVars:
|
||||
# assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
|
||||
self.replacement_vals[name] = bound
|
||||
return bound
|
||||
|
||||
|
||||
class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler):
|
||||
def __init__(self) -> None:
|
||||
self.name = "ValueRangeAnalysis"
|
||||
boolean_operators = (
|
||||
"xor",
|
||||
"logical_and",
|
||||
"logical_or",
|
||||
"logical_not",
|
||||
)
|
||||
for op in boolean_operators:
|
||||
setattr(self, op, self.bool_handler)
|
||||
|
||||
@staticmethod
|
||||
def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]:
|
||||
# just assuming bools can have both values
|
||||
return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
|
||||
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
# many ops are unlikely to show up in optimizable indexing compute,
|
||||
# so we dont have full coverage
|
||||
return ValueRanges.unknown()
|
||||
|
||||
def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]:
|
||||
return ValueRanges.unknown()
|
||||
|
||||
def store(
|
||||
self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None
|
||||
) -> None:
|
||||
return
|
||||
|
||||
def reduction(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
src_dtype: torch.dtype,
|
||||
reduction_type: ReductionType,
|
||||
value: Any,
|
||||
) -> ValueRanges[Any]:
|
||||
return ValueRanges.unknown()
|
||||
|
||||
@classmethod
|
||||
def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]:
|
||||
assert isinstance(index, ValueRanges)
|
||||
return cls.to_dtype(index, dtype)
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(
|
||||
x: Any,
|
||||
dtype: torch.dtype,
|
||||
src_dtype: Optional[torch.dtype] = None,
|
||||
use_compute_types: bool = True,
|
||||
) -> ValueRanges[Any]:
|
||||
x = ValueRanges.wrap(x)
|
||||
|
||||
if dtype == torch.bool:
|
||||
if x.is_singleton():
|
||||
return ValueRanges.wrap(x.lower != 0)
|
||||
elif x.is_bool:
|
||||
return x
|
||||
elif 0 not in x:
|
||||
return ValueRanges.wrap(sympy.true)
|
||||
else:
|
||||
return ValueRanges(sympy.false, sympy.true)
|
||||
|
||||
def cast(x: Any, dtype: torch.dtype) -> sympy.Expr:
|
||||
# dtype is int or float
|
||||
if dtype.is_floating_point:
|
||||
return sympy.Float(x)
|
||||
else:
|
||||
if x in (int_oo, -int_oo):
|
||||
return x
|
||||
try:
|
||||
return sympy.Integer(x)
|
||||
except TypeError:
|
||||
# inf cannot be cast to Integer
|
||||
return x
|
||||
|
||||
if x.is_bool:
|
||||
if x.is_singleton():
|
||||
val = 1 if x.lower else 0
|
||||
return ValueRanges.wrap(cast(val, dtype))
|
||||
else:
|
||||
return ValueRanges(cast(0, dtype), cast(1, dtype))
|
||||
else:
|
||||
# int to float or float to int
|
||||
return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
|
||||
|
||||
@staticmethod
|
||||
def square(x: Any) -> ValueRanges[Any]:
|
||||
return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
|
||||
|
||||
@staticmethod
|
||||
def neg(x: Any) -> ValueRanges[Any]:
|
||||
return ValueRanges.decreasing_map(x, operator.neg)
|
||||
|
||||
# TODO: this is slightly inaccurate because truncdiv operates at integer
|
||||
# precision, but we're going through float truediv which means we can
|
||||
# potentially lose precision on the bounds
|
||||
@classmethod
|
||||
def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]:
|
||||
x = cls.truediv(a, b)
|
||||
if x == ValueRanges.unknown():
|
||||
return x
|
||||
|
||||
return cls.trunc(x)
|
||||
|
||||
@classmethod
|
||||
def sub(cls, a: Any, b: Any) -> ValueRanges[Any]:
|
||||
return cls.add(a, cls.neg(b))
|
||||
|
@ -37,11 +37,11 @@ from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
|
||||
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
|
||||
|
||||
from .. import config, metrics
|
||||
from ..dtype_propagation import DtypePropagationOpsHandler
|
||||
from ..ops_handler import BasicMathOps, DefaultHandler
|
||||
from ..ops_handler import BasicMathOpsMixin, DefaultHandler
|
||||
from ..utils import (
|
||||
boolean_ops,
|
||||
DeferredLineBase,
|
||||
@ -764,7 +764,7 @@ def _all_in_parens(string: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class OpOverrides(BasicMathOps, OpDecompositions):
|
||||
class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
|
||||
@staticmethod
|
||||
def paren(string: OpVarT) -> OpVarT:
|
||||
if (
|
||||
@ -1235,12 +1235,6 @@ pointwise_overrides_data: dict[str, OverridesData] = dict(
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class DeferredLine(DeferredLineBase):
|
||||
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
|
||||
|
||||
@ -2268,6 +2262,8 @@ class CSEProxy(DefaultHandler):
|
||||
|
||||
def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
|
||||
super().__init__()
|
||||
from ..bounds import ValueRangeAnalysis
|
||||
|
||||
self.vr_analysis = ValueRangeAnalysis()
|
||||
self.kernel = kernel
|
||||
self.parent_handler = parent_handler
|
||||
@ -2338,6 +2334,7 @@ class CSEProxy(DefaultHandler):
|
||||
If the variable comes from an FX node, we forward the bound we have already computed
|
||||
Else, if the variable when codegen'ing another op, we try to compute its bounds
|
||||
"""
|
||||
from ..bounds import ValueRangeAnalysis
|
||||
from ..select_algorithm import TritonTemplateKernel
|
||||
|
||||
if isinstance(V.kernel, TritonTemplateKernel):
|
||||
@ -2575,9 +2572,3 @@ class CSEProxy(DefaultHandler):
|
||||
sorter,
|
||||
sorter_indices,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_CSEProxy(CSEProxy, OpsHandler[CSEVariable]):
|
||||
pass
|
||||
|
@ -33,7 +33,7 @@ from ..utils import (
|
||||
sympy_index_symbol,
|
||||
sympy_subs,
|
||||
)
|
||||
from ..virtualized import _ops as ops, OpsHandler, V
|
||||
from ..virtualized import _ops as ops, V
|
||||
from .common import (
|
||||
BackendFeature,
|
||||
CSEVariable,
|
||||
@ -563,12 +563,6 @@ class HalideOverrides(OpOverrides):
|
||||
HalideOverrides._initialize_pointwise_overrides("halide")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class HalideCSEVariable(CSEVariable):
|
||||
undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")
|
||||
|
||||
|
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
|
||||
import sympy
|
||||
|
||||
from ..ops_handler import OpsHandler, ReductionType, StoreMode
|
||||
from ..ops_handler import ReductionType, StoreMode
|
||||
from ..scheduler import Scheduler, SchedulerNode
|
||||
from .common import OpVarT
|
||||
|
||||
@ -367,12 +367,6 @@ class MetalOverrides(OpOverrides):
|
||||
MetalOverrides._initialize_pointwise_overrides("mps")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class MetalKernel(SIMDKernel):
|
||||
overrides = MetalOverrides # type: ignore[assignment]
|
||||
suffix = ";"
|
||||
|
@ -61,7 +61,7 @@ from ..utils import (
|
||||
triton_version_uses_attrs_dict,
|
||||
upcast_compute_type,
|
||||
)
|
||||
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
|
||||
from ..virtualized import _ops as ops, ReductionType, StoreMode, V
|
||||
from ..wrapper_benchmark import get_kernel_category_by_source_code
|
||||
from .block_analysis import BlockPatternMatcher
|
||||
from .common import (
|
||||
@ -1428,12 +1428,6 @@ class TritonKernelOverrides(TritonOverrides):
|
||||
return (mantissa, exponent)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class HelperFunctions:
|
||||
"""An ordered set of helper functions."""
|
||||
|
||||
|
@ -4,17 +4,7 @@ import itertools
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
@ -34,7 +24,7 @@ from .utils import (
|
||||
sympy_subs,
|
||||
VarRanges,
|
||||
)
|
||||
from .virtualized import OpsHandler, ReductionType, V
|
||||
from .virtualized import ReductionType, V
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -799,14 +789,6 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
|
||||
body()
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_FreeUnbackedSymbolsOpsHandler(
|
||||
FreeUnbackedSymbolsOpsHandler, OpsHandler[None]
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def extract_free_unbacked_symbols(
|
||||
fn: Callable[..., Any],
|
||||
index: Sequence[sympy.Expr],
|
||||
|
@ -23,7 +23,7 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union
|
||||
from typing import Any, Literal, Optional, overload, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import sympy
|
||||
@ -33,7 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype
|
||||
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
|
||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
|
||||
|
||||
from .ops_handler import DefaultHandler, OpsHandler
|
||||
from .ops_handler import DefaultHandler
|
||||
from .sizevars import evaluate_expr
|
||||
from .utils import generate_assert
|
||||
from .virtualized import V
|
||||
@ -370,9 +370,3 @@ class IndexPropagation(DefaultHandler):
|
||||
"indirect_indexing", (index, size, check, wrap_neg), {}
|
||||
).value
|
||||
return indirect_var
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]):
|
||||
pass
|
||||
|
@ -17,7 +17,7 @@ from torch.utils._sympy.symbol import SymT
|
||||
|
||||
from . import config, dependencies
|
||||
from .codegen.common import index_prevent_reordering
|
||||
from .ops_handler import DefaultHandler
|
||||
from .ops_handler import DefaultHandler, OpsHandler
|
||||
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
|
||||
from .virtualized import ops, V
|
||||
|
||||
@ -655,7 +655,7 @@ class LoopBodyBlock:
|
||||
|
||||
|
||||
class CountOps(DefaultHandler):
|
||||
def __init__(self, inner: Any, counts: collections.Counter[str]):
|
||||
def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]):
|
||||
self._inner = inner
|
||||
self._counts = counts
|
||||
|
||||
|
@ -4,17 +4,7 @@ from __future__ import annotations
|
||||
import itertools
|
||||
import re
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
@ -48,12 +38,8 @@ def _arg_str(a: object) -> str:
|
||||
return str(a)
|
||||
|
||||
|
||||
# NB: This is not done as a parent class, because our ops handlers
|
||||
# implementations make heavy use of __getattr__ magic, and pre-existing
|
||||
# stubs for methods would interfere with this mechanism.
|
||||
#
|
||||
# See OpDecompositions for superclass that desugars operations like reciprocal/square.
|
||||
class OpsHandler(Protocol[T]):
|
||||
class OpsHandler(Generic[T]):
|
||||
"""
|
||||
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
|
||||
as well as the contract for op handlers. The type T signifies the domain
|
||||
@ -77,49 +63,30 @@ class OpsHandler(Protocol[T]):
|
||||
ops handlers.
|
||||
|
||||
Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides),
|
||||
which means you will get type errors if you subclass OpsHandler since mypy doesn't know
|
||||
about the methods added via metaprogramming and thinks the class is still abstract.
|
||||
Instead, you should add a block like:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
Which will check the signatures of non-meta-programmed methods and gives decent error messages.
|
||||
|
||||
Some older parts of the code use a pattern like:
|
||||
|
||||
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
|
||||
return h
|
||||
|
||||
This pattern only works if the class defines a __getattr__ method, which we are moving away from.
|
||||
Additionally, this pattern generates horrible error messages if the signatures are wrong.
|
||||
It gives zero information about what the problem is, which makes the pattern harmful.
|
||||
|
||||
Instead of that, we have tests in test/inductor/test_op_completeness.py which check that all
|
||||
operators are implemented after all the metaprogramming has run.
|
||||
which means you will not get type errors for those methods. We have tests in
|
||||
test/inductor/test_op_completeness.py which check that all operators are implemented after
|
||||
all the metaprogramming has run.
|
||||
"""
|
||||
|
||||
def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
|
||||
"""Produces a scalar constant of type dtype."""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def load_seed(self, name: str, offset: T) -> T:
|
||||
"""Computes inductor_prims.lookup_seed."""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def rand(self, seed: T, offset: T) -> T:
|
||||
"""Computes inductor_prims.random with mode="rand". offset has dtype int32."""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def randn(self, seed: T, offset: T) -> T:
|
||||
"""Computes inductor_prims.random with mode="randn". offset has dtype int32."""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
|
||||
"""Computes inductor_prims.randint. offset has dtype int32."""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
|
||||
"""
|
||||
@ -133,13 +100,13 @@ class OpsHandler(Protocol[T]):
|
||||
Contrast this with ops.where, which can multiplex between two values
|
||||
that have been unconditionally computed.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def where(self, condition: T, input: T, other: T) -> T:
|
||||
"""
|
||||
Computes torch.where: when condition is true, return input; otherwise return other.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
|
||||
"""
|
||||
@ -147,7 +114,7 @@ class OpsHandler(Protocol[T]):
|
||||
an indexing expression, thus the name; however, it can also be used in
|
||||
non-indexing situations.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def to_dtype(
|
||||
self,
|
||||
@ -160,7 +127,7 @@ class OpsHandler(Protocol[T]):
|
||||
Convert x to dtype. src_dtype can be optionally set to specify what the original
|
||||
dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def trunc_to_int(self, x: T, dtype: torch.dtype) -> T:
|
||||
"""
|
||||
@ -174,38 +141,38 @@ class OpsHandler(Protocol[T]):
|
||||
int64 depending on if we've shown that all the indexing operations can
|
||||
be done in int32.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def ceil_to_int(self, x: T, dtype: torch.dtype) -> T:
|
||||
"""
|
||||
Convert x to dtype with ceiling semantics. See also trunc_to_int.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def floor_to_int(self, x: T, dtype: torch.dtype) -> T:
|
||||
"""
|
||||
Convert x to dtype with ceiling semantics. See also trunc_to_int.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def round_to_int(self, x: T, dtype: torch.dtype) -> T:
|
||||
"""
|
||||
Convert x to dtype with round-to-even semantics. See also trunc_to_int.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
|
||||
"""
|
||||
Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
|
||||
src_dtype must be the original type of x.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def identity(self, x: T) -> T:
|
||||
"""
|
||||
Returns x as is. This is used to trigger CSE.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# These operations are only available in a "kernel" context. Check
|
||||
@ -227,13 +194,13 @@ class OpsHandler(Protocol[T]):
|
||||
NB: This is typically mandatory to implement for any analysis, because you
|
||||
MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def load(self, name: str, index: sympy.Expr) -> T:
|
||||
"""
|
||||
Load from the memory location 'name', offset by some indexing expression 'index'.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def store(
|
||||
self,
|
||||
@ -246,7 +213,7 @@ class OpsHandler(Protocol[T]):
|
||||
Store 'value' to the memory location 'name' offset by 'expr'. If
|
||||
specified, 'mode' can require the store to be an atomic addition.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO: Better explain how the "collective" semantics of these ops;
|
||||
# remember that the input value is a scalar, you can't reduce on it in the
|
||||
@ -268,7 +235,7 @@ class OpsHandler(Protocol[T]):
|
||||
function returns multiple outputs; consult reduction_num_outputs to
|
||||
determine the amount in metaprogramming applications.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO: in practice, this seems to actually return None, but not returning
|
||||
# a T makes common __getattr__ idioms not type correctly. Figure out if
|
||||
@ -278,7 +245,7 @@ class OpsHandler(Protocol[T]):
|
||||
Store the fully accumulated result of 'reduction' to the memory
|
||||
location 'name' offset by 'expr'.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def scan(
|
||||
self,
|
||||
@ -290,7 +257,7 @@ class OpsHandler(Protocol[T]):
|
||||
Perform an associative scan on 'value'.
|
||||
"""
|
||||
# TODO: Improve the description with some pseudocode
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def sort(
|
||||
self,
|
||||
@ -302,7 +269,7 @@ class OpsHandler(Protocol[T]):
|
||||
"""
|
||||
Sort values along the reduction dimension.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bucketize(
|
||||
self,
|
||||
@ -315,231 +282,231 @@ class OpsHandler(Protocol[T]):
|
||||
sorter_indices: Optional[T] = None,
|
||||
) -> T:
|
||||
# See [Note: Inductor bucketize op]
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# The following ops have semantics that correspond exactly to the torch
|
||||
# operation with the same corresponding name.
|
||||
|
||||
def abs(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def exp(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def exp2(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def expm1(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def sqrt(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def relu(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def minimum(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def maximum(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def cos(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def sin(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def lgamma(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def erf(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def cosh(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def sinh(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def acos(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def acosh(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def asin(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def asinh(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def atan2(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def atan(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def atanh(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def copysign(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def erfc(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def erfinv(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def frexp(self, x0: T):
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def hypot(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def log10(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def log2(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def nextafter(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def logical_and(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def logical_not(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def logical_or(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def logical_xor(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bitwise_and(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bitwise_not(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bitwise_or(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bitwise_xor(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bitwise_left_shift(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bitwise_right_shift(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def rsqrt(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def log1p(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def tan(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def tanh(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def sigmoid(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def signbit(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def fmod(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def log(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def isinf(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def isnan(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# NB: this returns a float, like the torch operation
|
||||
# This rounds half to even to break ties
|
||||
def round(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# NB: this returns a float, like the torch operation
|
||||
def floor(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def sign(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# NB: this returns a float, like the torch operation
|
||||
def trunc(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# NB: this returns a float, like the torch operation
|
||||
def ceil(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def neg(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def reciprocal(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def eq(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def ne(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def lt(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def gt(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def le(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def ge(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def sub(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def mul(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# NB: this returns a float, like the torch operation
|
||||
def pow(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def and_(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def or_(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def xor(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# These are metaprogrammed by MockHandler._init_cls
|
||||
def lshift(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def rshift(self, x0: T, x1: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# These are "special" operators. These only exist if the target
|
||||
@ -547,124 +514,124 @@ class OpsHandler(Protocol[T]):
|
||||
# pointwise_overrides_data.
|
||||
|
||||
def airy_ai(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bessel_j0(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bessel_j1(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bessel_y0(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def bessel_y1(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def digamma(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def erfcx(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def fma(self, x: T, y: T, z: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def igamma(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def igammac(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def gammainc(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def gammaincc(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def i0(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def i0e(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def i1(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def i1e(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def log_ndtr(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def modified_bessel_i0(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def modified_bessel_i1(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def modified_bessel_k0(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def modified_bessel_k1(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def ndtr(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def ndtri(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def polygamma(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def scaled_modified_bessel_k0(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def scaled_modified_bessel_k1(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def spherical_bessel_j0(self, x: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def zeta(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def chebyshev_polynomial_t(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def chebyshev_polynomial_u(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def chebyshev_polynomial_v(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def chebyshev_polynomial_w(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def legendre_polynomial_p(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def hermite_polynomial_h(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def hermite_polynomial_he(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def laguerre_polynomial_l(self, x: T, y: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# These operators are a bit special, because they are conventionally
|
||||
@ -675,42 +642,42 @@ class OpsHandler(Protocol[T]):
|
||||
"""C-style trunc division between integers only. Computes the true
|
||||
division of two numbers and rounds the result to zero.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def floordiv(self, x0: T, x1: T) -> T:
|
||||
"""Python-style floor division between integers only. Computes the
|
||||
true division of two numbers and floors the result. If you want
|
||||
floor division for floats, do regular truediv and floor the result.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def truediv(self, x0: T, x1: T) -> T:
|
||||
"""True division between floats. Integer inputs are NOT valid. To
|
||||
do Python-style (int, int) -> float division, use int_truediv"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def int_truediv(self, x0: T, x1: T) -> T:
|
||||
"""True division between integers. This is NOT the same as promoting
|
||||
to float and doing integer division, there is a bespoke algorithm for
|
||||
doing the division in higher precision than the above.
|
||||
"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def mod(self, x0: T, x1: T) -> T:
|
||||
"""C-style modulus, take sign from LHS (x0)."""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def remainder(self, x0: T, x1: T) -> T:
|
||||
"""Python-style modulus, take sign from RHS (x1)."""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def square(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def check_bounds(
|
||||
self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
|
||||
) -> None:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# In CUDA, optimized implementations of other mathematical operations are
|
||||
@ -726,25 +693,25 @@ class OpsHandler(Protocol[T]):
|
||||
# for many analyses it's not conveniently available.)
|
||||
|
||||
def libdevice_abs(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def libdevice_exp(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def libdevice_sqrt(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def libdevice_cos(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def libdevice_sin(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def libdevice_sigmoid(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def libdevice_log(self, x0: T) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
# halide-only
|
||||
def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T:
|
||||
@ -760,15 +727,15 @@ class OpsHandler(Protocol[T]):
|
||||
is_pure: bool = True,
|
||||
pack: int = 1,
|
||||
) -> T:
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def output(self, x0: T) -> None:
|
||||
"""This is a fake op used in analysis but not codegen"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
def placeholder(self, index: int) -> T:
|
||||
"""This is a fake op used in analysis but not codegen"""
|
||||
...
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
_ignore_op_re = re.compile(r"_.*|paren").fullmatch
|
||||
@ -781,7 +748,7 @@ def list_ops(cls: type[Any]):
|
||||
OP_NAMES = list_ops(OpsHandler)
|
||||
|
||||
|
||||
class DefaultHandler:
|
||||
class DefaultHandler(OpsHandler[Any]):
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
"""
|
||||
Default implementation for all ops. Override in a subclass to
|
||||
@ -850,13 +817,7 @@ class NoopHandler(DefaultHandler):
|
||||
return sympy.S.Zero
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_NoopHandler(NoopHandler, OpsHandler[None]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class BasicMathOps:
|
||||
class BasicMathOpsMixin:
|
||||
@staticmethod
|
||||
def add(a, b):
|
||||
return f"{a} + {b}"
|
||||
@ -935,7 +896,7 @@ class BasicMathOps:
|
||||
return f"-{a}"
|
||||
|
||||
|
||||
class MockHandler(BasicMathOps, DefaultHandler):
|
||||
class MockHandler(BasicMathOpsMixin, DefaultHandler):
|
||||
name = "MockHandler"
|
||||
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
@ -971,14 +932,8 @@ class MockHandler(BasicMathOps, DefaultHandler):
|
||||
return sympy_index_symbol(str(index_var))
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_MockHandler(MockHandler, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class KernelFormatterHandler(DefaultHandler):
|
||||
def __init__(self, parent_handler):
|
||||
def __init__(self, parent_handler: OpsHandler[Any]):
|
||||
self.parent_handler = parent_handler
|
||||
self._output = IndentedBuffer(1)
|
||||
self.var_counter = itertools.count()
|
||||
@ -1042,14 +997,8 @@ class KernelFormatterHandler(DefaultHandler):
|
||||
return self._output.getvalue()
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_KernelFormatterHandler(KernelFormatterHandler, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class WrapperHandler(DefaultHandler):
|
||||
def __init__(self, inner: Any):
|
||||
def __init__(self, inner: OpsHandler[Any]):
|
||||
self._inner = inner
|
||||
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
@ -1074,11 +1023,11 @@ class OpCountResult(NamedTuple):
|
||||
class OpCounterCSE(DefaultHandler):
|
||||
"""Shim to count how many ops are used"""
|
||||
|
||||
def __init__(self, inner):
|
||||
def __init__(self, inner: OpsHandler[Any]):
|
||||
super().__init__()
|
||||
self.parent_handler = inner
|
||||
self.op_count = 0
|
||||
self.var_names = {}
|
||||
self.var_names: dict[str, str] = {}
|
||||
self._used_ops: OrderedSet[str] = OrderedSet()
|
||||
self._read_names: list[str] = []
|
||||
self._nontrivial_read_count = 0
|
||||
@ -1152,26 +1101,16 @@ class OpCounterCSE(DefaultHandler):
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_OpCounterCSE(OpCounterCSE, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class ExtractConstantsHandler(NoopHandler):
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Optional[torch.device]):
|
||||
self.device = device
|
||||
|
||||
def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant:
|
||||
from torch._inductor import ir
|
||||
|
||||
return ir.Constant(value=value, dtype=dtype, device=self.device)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_ExtractConstantsHandler(ExtractConstantsHandler, OpsHandler[Any]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
return ir.Constant(
|
||||
value=value, dtype=dtype, device=self.device or torch.get_default_device()
|
||||
)
|
||||
|
||||
|
||||
class SimpleCSEHandler(WrapperHandler):
|
||||
@ -1204,9 +1143,3 @@ class SimpleCSEHandler(WrapperHandler):
|
||||
val = getattr(self._inner, name)(*args, **kwargs)
|
||||
self.cse_cache[key] = val
|
||||
return val
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_SimpleCSEHandler(SimpleCSEHandler, OpsHandler[Any]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
@ -564,10 +564,6 @@ class CompiledFxGraph(OutputCode):
|
||||
return artifact_path
|
||||
|
||||
|
||||
def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
|
||||
return h
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompiledAOTI(OutputCode):
|
||||
"""
|
||||
@ -591,10 +587,6 @@ class CompiledAOTI(OutputCode):
|
||||
pass
|
||||
|
||||
|
||||
def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
|
||||
return h
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MockFXGraphCacheOutput(OutputCode):
|
||||
gm: Any = None
|
||||
|
@ -149,8 +149,8 @@ class TracingOpsHandler(WrapperHandler):
|
||||
def placeholder(self, idx: int) -> torch.fx.Proxy:
|
||||
return self.placeholders[idx]
|
||||
|
||||
def output(self, *args: tuple[object]) -> torch.fx.Node:
|
||||
return self.tracer.create_node(
|
||||
def output(self, *args: tuple[object]) -> None:
|
||||
self.tracer.create_node(
|
||||
"output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {}
|
||||
)
|
||||
|
||||
|
@ -306,9 +306,7 @@ class OpsWrapper(DefaultHandler):
|
||||
return _ops.indirect_indexing(index, size, check, wrap_neg)
|
||||
|
||||
|
||||
# we lie about the type of ops so the rest of the codebase typecheck properly
|
||||
# DefaultHandler implements the OpsHandler protocol via metaprogramming
|
||||
ops = cast(OpsHandler[Any], OpsWrapper())
|
||||
ops: OpsHandler[Any] = OpsWrapper()
|
||||
|
||||
|
||||
class _V:
|
||||
@ -316,8 +314,10 @@ class _V:
|
||||
KernelFormatterHandler = KernelFormatterHandler
|
||||
WrapperHandler = WrapperHandler
|
||||
|
||||
set_ops_handler: Callable[[Any], Any] = _ops._set_handler
|
||||
get_ops_handler: Callable[[], Any] = _ops._get_handler
|
||||
set_ops_handler: Callable[
|
||||
[OpsHandler[Any]], AbstractContextManager[None]
|
||||
] = _ops._set_handler
|
||||
get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
|
||||
set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
|
||||
set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
|
||||
get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
|
||||
|
@ -49,7 +49,7 @@ from .numbers import int_oo, IntInfinity, NegativeIntInfinity
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"]
|
||||
__all__ = ["ValueRanges", "bound_sympy"]
|
||||
|
||||
_T = TypeVar("_T", sympy.Expr, SympyBoolean)
|
||||
|
||||
@ -1004,108 +1004,6 @@ class SymPyValueRangeAnalysis:
|
||||
return ValueRanges.increasing_map(x, TruncToFloat)
|
||||
|
||||
|
||||
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
||||
def __init__(self) -> None:
|
||||
self.name = "ValueRangeAnalysis"
|
||||
boolean_operators = (
|
||||
"xor",
|
||||
"logical_and",
|
||||
"logical_or",
|
||||
"logical_not",
|
||||
)
|
||||
for op in boolean_operators:
|
||||
setattr(self, op, self.bool_handler)
|
||||
|
||||
@staticmethod
|
||||
def bool_handler(*args, **kwargs):
|
||||
# just assuming bools can have both values
|
||||
return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
|
||||
|
||||
@staticmethod
|
||||
def default_handler(*args, **kwargs):
|
||||
# many ops are unlikely to show up in optimizable indexing compute,
|
||||
# so we dont have full coverage
|
||||
return ValueRanges.unknown()
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
return ValueRanges.unknown()
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
return
|
||||
|
||||
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
|
||||
return ValueRanges.unknown()
|
||||
|
||||
@classmethod
|
||||
def index_expr(cls, index, dtype):
|
||||
assert isinstance(index, ValueRanges)
|
||||
return cls.to_dtype(index, dtype)
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
|
||||
x = ValueRanges.wrap(x)
|
||||
|
||||
if dtype == torch.bool:
|
||||
if x.is_singleton():
|
||||
return ValueRanges.wrap(x.lower != 0)
|
||||
elif x.is_bool:
|
||||
return x
|
||||
elif 0 not in x:
|
||||
return ValueRanges.wrap(sympy.true)
|
||||
else:
|
||||
return ValueRanges(sympy.false, sympy.true)
|
||||
|
||||
def cast(x, dtype):
|
||||
# dtype is int or float
|
||||
if dtype.is_floating_point:
|
||||
return sympy.Float(x)
|
||||
else:
|
||||
if x in (int_oo, -int_oo):
|
||||
return x
|
||||
try:
|
||||
return sympy.Integer(x)
|
||||
except TypeError:
|
||||
# inf cannot be cast to Integer
|
||||
return x
|
||||
|
||||
if x.is_bool:
|
||||
if x.is_singleton():
|
||||
val = 1 if x.lower else 0
|
||||
return ValueRanges.wrap(cast(val, dtype))
|
||||
else:
|
||||
return ValueRanges(cast(0, dtype), cast(1, dtype))
|
||||
else:
|
||||
# int to float or float to int
|
||||
return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
|
||||
|
||||
@staticmethod
|
||||
def square(x):
|
||||
return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
|
||||
|
||||
@staticmethod
|
||||
def neg(x):
|
||||
return ValueRanges.decreasing_map(x, operator.neg)
|
||||
|
||||
# TODO: this is slightly inaccurate because truncdiv operates at integer
|
||||
# precision, but we're going through float truediv which means we can
|
||||
# potentially lose precision on the bounds
|
||||
@classmethod
|
||||
def truncdiv(cls, a, b):
|
||||
x = cls.truediv(a, b)
|
||||
if x == ValueRanges.unknown():
|
||||
return x
|
||||
|
||||
return cls.trunc(x)
|
||||
|
||||
@classmethod
|
||||
def sub(cls, a, b):
|
||||
return cls.add(a, cls.neg(b))
|
||||
|
||||
def __getattr__(self, name):
|
||||
log.debug("unhandled ValueRange op %s", name)
|
||||
return self.default_handler
|
||||
|
||||
|
||||
def bound_sympy(
|
||||
expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
|
||||
) -> ValueRanges:
|
||||
|
Reference in New Issue
Block a user