diff --git a/test/inductor/test_op_completeness.py b/test/inductor/test_op_completeness.py index 04fac4870fd7..23d59a789418 100644 --- a/test/inductor/test_op_completeness.py +++ b/test/inductor/test_op_completeness.py @@ -5,19 +5,23 @@ from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides from torch._inductor.codegen.halide import HalideOverrides from torch._inductor.codegen.mps import MetalOverrides from torch._inductor.codegen.triton import TritonKernelOverrides -from torch._inductor.ops_handler import list_ops, OP_NAMES +from torch._inductor.ops_handler import list_ops, OP_NAMES, OpsHandler from torch._inductor.test_case import TestCase class TestOpCompleteness(TestCase): def verify_ops_handler_completeness(self, handler): - op_names = list_ops(handler) - if OP_NAMES == op_names: - return - print(f"Missing ops: {OP_NAMES - op_names}") - print(f"Extra ops: {op_names - OP_NAMES}") - self.assertEqual(", ".join(OP_NAMES - op_names), "") - self.assertEqual(", ".join(op_names - OP_NAMES), "") + for op in OP_NAMES: + self.assertIsNot( + getattr(handler, op), + getattr(OpsHandler, op), + msg=f"{handler} must implement {op}", + ) + extra_ops = list_ops(handler) - OP_NAMES + if extra_ops: + raise AssertionError( + f"{handler} has an extra ops: {extra_ops}, add them to OpHandler class or prefix with `_`" + ) def test_triton_overrides(self): self.verify_ops_handler_completeness(TritonKernelOverrides) diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 5cd930274173..dddb73c2851d 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -34,7 +34,8 @@ from torch.utils._sympy.reference import ( ) from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve -from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import ValueRanges +from torch._inductor.bounds import ValueRangeAnalysis UNARY_OPS = [ diff --git a/torch/_inductor/analyze_preserves_zero_mask.py b/torch/_inductor/analyze_preserves_zero_mask.py index abdf1320bc2d..974960b9589f 100644 --- a/torch/_inductor/analyze_preserves_zero_mask.py +++ b/torch/_inductor/analyze_preserves_zero_mask.py @@ -9,7 +9,7 @@ from torch._inductor import config from torch._inductor.dtype_propagation import DtypePropagationOpsHandler from torch._inductor.index_propagation import SymPyOps, TypedExpr -from .ops_handler import DefaultHandler, OpsHandler +from .ops_handler import DefaultHandler from .virtualized import StoreMode, V @@ -32,27 +32,22 @@ class PreservesZeros(SymPyOps, DefaultHandler): self.store_preserves_zeros: Optional[bool] = None self.dtype_prop = DtypePropagationOpsHandler() - @staticmethod - def load(name: str, index: sympy.Expr) -> TypedExpr: + def load(self, name: str, index: sympy.Expr) -> TypedExpr: # In prologue fusion, all loads get broadcasted - dtype = V.get_ops_handler().dtype_prop.load(name, index) + dtype = self.dtype_prop.load(name, index) return TypedExpr( sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype ) - @staticmethod def store( - name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None + self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None ) -> None: - self = V.get_ops_handler() assert isinstance(self, PreservesZeros) # should only have a single store in prologue assert self.store_preserves_zeros is None self.store_preserves_zeros = value.is_constant() and value.expr == 0 - @staticmethod - def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: - self = V.get_ops_handler() + def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr: return construct_symbol(next(self.count), torch.int32) def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: @@ -65,12 +60,6 @@ class PreservesZeros(SymPyOps, DefaultHandler): return TypedExpr(construct_symbol(next(self.count), dtype), dtype) -if TYPE_CHECKING: - - class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]): - pass - - def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool: """ Does this prologue preserve zero masks @@ -100,9 +89,8 @@ class RecordLowPrecisionOps(DefaultHandler): "constant", ) - @staticmethod - def load(name: str, index: sympy.Expr) -> DTypeContainer: - return DTypeContainer(V.get_ops_handler().dtype_prop.load(name, index)) + def load(self, name: str, index: sympy.Expr) -> DTypeContainer: + return DTypeContainer(self.dtype_prop.load(name, index)) @staticmethod def store( @@ -131,12 +119,6 @@ class RecordLowPrecisionOps(DefaultHandler): return out -if TYPE_CHECKING: - - class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]): - pass - - def low_prec_float(dtype: torch.dtype) -> bool: return dtype.is_floating_point and dtype.itemsize < 4 diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 3df87ada0dda..69c331646f81 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,14 +1,22 @@ import logging import operator from functools import partial -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union +import sympy from sympy import Expr import torch -from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import ( + bound_sympy, + SymPyValueRangeAnalysis, + ValueRanges, +) +from ..utils._sympy.functions import PowByNatural +from ..utils._sympy.numbers import int_oo from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock +from .ops_handler import DefaultHandler, ReductionType, StoreMode from .utils import cache_on_self, dominated_nodes from .virtualized import V @@ -139,3 +147,113 @@ class BoundVars: # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) self.replacement_vals[name] = bound return bound + + +class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler): + def __init__(self) -> None: + self.name = "ValueRangeAnalysis" + boolean_operators = ( + "xor", + "logical_and", + "logical_or", + "logical_not", + ) + for op in boolean_operators: + setattr(self, op, self.bool_handler) + + @staticmethod + def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]: + # just assuming bools can have both values + return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type] + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # many ops are unlikely to show up in optimizable indexing compute, + # so we dont have full coverage + return ValueRanges.unknown() + + def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]: + return ValueRanges.unknown() + + def store( + self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None + ) -> None: + return + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Any, + ) -> ValueRanges[Any]: + return ValueRanges.unknown() + + @classmethod + def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]: + assert isinstance(index, ValueRanges) + return cls.to_dtype(index, dtype) + + @staticmethod + def to_dtype( + x: Any, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> ValueRanges[Any]: + x = ValueRanges.wrap(x) + + if dtype == torch.bool: + if x.is_singleton(): + return ValueRanges.wrap(x.lower != 0) + elif x.is_bool: + return x + elif 0 not in x: + return ValueRanges.wrap(sympy.true) + else: + return ValueRanges(sympy.false, sympy.true) + + def cast(x: Any, dtype: torch.dtype) -> sympy.Expr: + # dtype is int or float + if dtype.is_floating_point: + return sympy.Float(x) + else: + if x in (int_oo, -int_oo): + return x + try: + return sympy.Integer(x) + except TypeError: + # inf cannot be cast to Integer + return x + + if x.is_bool: + if x.is_singleton(): + val = 1 if x.lower else 0 + return ValueRanges.wrap(cast(val, dtype)) + else: + return ValueRanges(cast(0, dtype), cast(1, dtype)) + else: + # int to float or float to int + return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) + + @staticmethod + def square(x: Any) -> ValueRanges[Any]: + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) + + @staticmethod + def neg(x: Any) -> ValueRanges[Any]: + return ValueRanges.decreasing_map(x, operator.neg) + + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds + @classmethod + def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]: + x = cls.truediv(a, b) + if x == ValueRanges.unknown(): + return x + + return cls.trunc(x) + + @classmethod + def sub(cls, a: Any, b: Any) -> ValueRanges[Any]: + return cls.add(a, cls.neg(b)) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index ab5280e1d5b0..dbd02188665a 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -37,11 +37,11 @@ from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT -from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges from .. import config, metrics from ..dtype_propagation import DtypePropagationOpsHandler -from ..ops_handler import BasicMathOps, DefaultHandler +from ..ops_handler import BasicMathOpsMixin, DefaultHandler from ..utils import ( boolean_ops, DeferredLineBase, @@ -764,7 +764,7 @@ def _all_in_parens(string: str) -> bool: return True -class OpOverrides(BasicMathOps, OpDecompositions): +class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]): @staticmethod def paren(string: OpVarT) -> OpVarT: if ( @@ -1235,12 +1235,6 @@ pointwise_overrides_data: dict[str, OverridesData] = dict( ) -if TYPE_CHECKING: - - class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class DeferredLine(DeferredLineBase): """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" @@ -2268,6 +2262,8 @@ class CSEProxy(DefaultHandler): def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): super().__init__() + from ..bounds import ValueRangeAnalysis + self.vr_analysis = ValueRangeAnalysis() self.kernel = kernel self.parent_handler = parent_handler @@ -2338,6 +2334,7 @@ class CSEProxy(DefaultHandler): If the variable comes from an FX node, we forward the bound we have already computed Else, if the variable when codegen'ing another op, we try to compute its bounds """ + from ..bounds import ValueRangeAnalysis from ..select_algorithm import TritonTemplateKernel if isinstance(V.kernel, TritonTemplateKernel): @@ -2575,9 +2572,3 @@ class CSEProxy(DefaultHandler): sorter, sorter_indices, ) - - -if TYPE_CHECKING: - - class _typecheck_CSEProxy(CSEProxy, OpsHandler[CSEVariable]): - pass diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 560e75c648f3..f2bdebf3c1b8 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -33,7 +33,7 @@ from ..utils import ( sympy_index_symbol, sympy_subs, ) -from ..virtualized import _ops as ops, OpsHandler, V +from ..virtualized import _ops as ops, V from .common import ( BackendFeature, CSEVariable, @@ -563,12 +563,6 @@ class HalideOverrides(OpOverrides): HalideOverrides._initialize_pointwise_overrides("halide") -if TYPE_CHECKING: - - class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class HalideCSEVariable(CSEVariable): undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 86cbb6f5361d..dd3ff699e8a5 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: import sympy - from ..ops_handler import OpsHandler, ReductionType, StoreMode + from ..ops_handler import ReductionType, StoreMode from ..scheduler import Scheduler, SchedulerNode from .common import OpVarT @@ -367,12 +367,6 @@ class MetalOverrides(OpOverrides): MetalOverrides._initialize_pointwise_overrides("mps") -if TYPE_CHECKING: - - class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong - - class MetalKernel(SIMDKernel): overrides = MetalOverrides # type: ignore[assignment] suffix = ";" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 7b123f590fa4..e0c1f9884799 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -61,7 +61,7 @@ from ..utils import ( triton_version_uses_attrs_dict, upcast_compute_type, ) -from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V +from ..virtualized import _ops as ops, ReductionType, StoreMode, V from ..wrapper_benchmark import get_kernel_category_by_source_code from .block_analysis import BlockPatternMatcher from .common import ( @@ -1428,12 +1428,6 @@ class TritonKernelOverrides(TritonOverrides): return (mantissa, exponent) -if TYPE_CHECKING: - - class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class HelperFunctions: """An ordered set of helper functions.""" diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 820e737414d0..36000a50cb84 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -4,17 +4,7 @@ import itertools import logging import re from collections.abc import Sequence -from typing import ( - Any, - Callable, - Iterable, - List, - Optional, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union from unittest.mock import patch import sympy @@ -34,7 +24,7 @@ from .utils import ( sympy_subs, VarRanges, ) -from .virtualized import OpsHandler, ReductionType, V +from .virtualized import ReductionType, V T = TypeVar("T") @@ -799,14 +789,6 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler): body() -if TYPE_CHECKING: - - class _typecheck_FreeUnbackedSymbolsOpsHandler( - FreeUnbackedSymbolsOpsHandler, OpsHandler[None] - ): - pass - - def extract_free_unbacked_symbols( fn: Callable[..., Any], index: Sequence[sympy.Expr], diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 741864e41e1a..2e5640413408 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -23,7 +23,7 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing. import itertools from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union +from typing import Any, Literal, Optional, overload, Union from typing_extensions import TypeAlias import sympy @@ -33,7 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges -from .ops_handler import DefaultHandler, OpsHandler +from .ops_handler import DefaultHandler from .sizevars import evaluate_expr from .utils import generate_assert from .virtualized import V @@ -370,9 +370,3 @@ class IndexPropagation(DefaultHandler): "indirect_indexing", (index, size, check, wrap_neg), {} ).value return indirect_var - - -if TYPE_CHECKING: - - class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]): - pass diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index feb88a09e750..afee89882536 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -17,7 +17,7 @@ from torch.utils._sympy.symbol import SymT from . import config, dependencies from .codegen.common import index_prevent_reordering -from .ops_handler import DefaultHandler +from .ops_handler import DefaultHandler, OpsHandler from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs from .virtualized import ops, V @@ -655,7 +655,7 @@ class LoopBodyBlock: class CountOps(DefaultHandler): - def __init__(self, inner: Any, counts: collections.Counter[str]): + def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]): self._inner = inner self._counts = counts diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 22ce7154c6b2..5338372f6afb 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -4,17 +4,7 @@ from __future__ import annotations import itertools import re import warnings -from typing import ( - Any, - Callable, - Literal, - NamedTuple, - Optional, - TYPE_CHECKING, - TypeVar, - Union, -) -from typing_extensions import Protocol +from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union from unittest.mock import patch import sympy @@ -48,12 +38,8 @@ def _arg_str(a: object) -> str: return str(a) -# NB: This is not done as a parent class, because our ops handlers -# implementations make heavy use of __getattr__ magic, and pre-existing -# stubs for methods would interfere with this mechanism. -# # See OpDecompositions for superclass that desugars operations like reciprocal/square. -class OpsHandler(Protocol[T]): +class OpsHandler(Generic[T]): """ Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, as well as the contract for op handlers. The type T signifies the domain @@ -77,49 +63,30 @@ class OpsHandler(Protocol[T]): ops handlers. Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides), - which means you will get type errors if you subclass OpsHandler since mypy doesn't know - about the methods added via metaprogramming and thinks the class is still abstract. - Instead, you should add a block like: - - if TYPE_CHECKING: - - class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - Which will check the signatures of non-meta-programmed methods and gives decent error messages. - - Some older parts of the code use a pattern like: - - def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: - return h - - This pattern only works if the class defines a __getattr__ method, which we are moving away from. - Additionally, this pattern generates horrible error messages if the signatures are wrong. - It gives zero information about what the problem is, which makes the pattern harmful. - - Instead of that, we have tests in test/inductor/test_op_completeness.py which check that all - operators are implemented after all the metaprogramming has run. + which means you will not get type errors for those methods. We have tests in + test/inductor/test_op_completeness.py which check that all operators are implemented after + all the metaprogramming has run. """ def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: """Produces a scalar constant of type dtype.""" - ... + raise NotImplementedError def load_seed(self, name: str, offset: T) -> T: """Computes inductor_prims.lookup_seed.""" - ... + raise NotImplementedError def rand(self, seed: T, offset: T) -> T: """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" - ... + raise NotImplementedError def randn(self, seed: T, offset: T) -> T: """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" - ... + raise NotImplementedError def randint64(self, seed: T, offset: T, low: T, high: T) -> T: """Computes inductor_prims.randint. offset has dtype int32.""" - ... + raise NotImplementedError def masked(self, mask: T, body: Callable[[], T], other: T) -> T: """ @@ -133,13 +100,13 @@ class OpsHandler(Protocol[T]): Contrast this with ops.where, which can multiplex between two values that have been unconditionally computed. """ - ... + raise NotImplementedError def where(self, condition: T, input: T, other: T) -> T: """ Computes torch.where: when condition is true, return input; otherwise return other. """ - ... + raise NotImplementedError def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T: """ @@ -147,7 +114,7 @@ class OpsHandler(Protocol[T]): an indexing expression, thus the name; however, it can also be used in non-indexing situations. """ - ... + raise NotImplementedError def to_dtype( self, @@ -160,7 +127,7 @@ class OpsHandler(Protocol[T]): Convert x to dtype. src_dtype can be optionally set to specify what the original dtype of x was, which can improve code generation (used by torch to(dtype=dtype)). """ - ... + raise NotImplementedError def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: """ @@ -174,38 +141,38 @@ class OpsHandler(Protocol[T]): int64 depending on if we've shown that all the indexing operations can be done in int32. """ - ... + raise NotImplementedError def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with ceiling semantics. See also trunc_to_int. """ - ... + raise NotImplementedError def floor_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with ceiling semantics. See also trunc_to_int. """ - ... + raise NotImplementedError def round_to_int(self, x: T, dtype: torch.dtype) -> T: """ Convert x to dtype with round-to-even semantics. See also trunc_to_int. """ - ... + raise NotImplementedError def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) src_dtype must be the original type of x. """ - ... + raise NotImplementedError def identity(self, x: T) -> T: """ Returns x as is. This is used to trigger CSE. """ - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These operations are only available in a "kernel" context. Check @@ -227,13 +194,13 @@ class OpsHandler(Protocol[T]): NB: This is typically mandatory to implement for any analysis, because you MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol). """ - ... + raise NotImplementedError def load(self, name: str, index: sympy.Expr) -> T: """ Load from the memory location 'name', offset by some indexing expression 'index'. """ - ... + raise NotImplementedError def store( self, @@ -246,7 +213,7 @@ class OpsHandler(Protocol[T]): Store 'value' to the memory location 'name' offset by 'expr'. If specified, 'mode' can require the store to be an atomic addition. """ - ... + raise NotImplementedError # TODO: Better explain how the "collective" semantics of these ops; # remember that the input value is a scalar, you can't reduce on it in the @@ -268,7 +235,7 @@ class OpsHandler(Protocol[T]): function returns multiple outputs; consult reduction_num_outputs to determine the amount in metaprogramming applications. """ - ... + raise NotImplementedError # TODO: in practice, this seems to actually return None, but not returning # a T makes common __getattr__ idioms not type correctly. Figure out if @@ -278,7 +245,7 @@ class OpsHandler(Protocol[T]): Store the fully accumulated result of 'reduction' to the memory location 'name' offset by 'expr'. """ - ... + raise NotImplementedError def scan( self, @@ -290,7 +257,7 @@ class OpsHandler(Protocol[T]): Perform an associative scan on 'value'. """ # TODO: Improve the description with some pseudocode - ... + raise NotImplementedError def sort( self, @@ -302,7 +269,7 @@ class OpsHandler(Protocol[T]): """ Sort values along the reduction dimension. """ - ... + raise NotImplementedError def bucketize( self, @@ -315,231 +282,231 @@ class OpsHandler(Protocol[T]): sorter_indices: Optional[T] = None, ) -> T: # See [Note: Inductor bucketize op] - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # The following ops have semantics that correspond exactly to the torch # operation with the same corresponding name. def abs(self, x0: T) -> T: - ... + raise NotImplementedError def exp(self, x0: T) -> T: - ... + raise NotImplementedError def exp2(self, x0: T) -> T: - ... + raise NotImplementedError def expm1(self, x0: T) -> T: - ... + raise NotImplementedError def sqrt(self, x0: T) -> T: - ... + raise NotImplementedError def relu(self, x0: T) -> T: - ... + raise NotImplementedError def minimum(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def maximum(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def cos(self, x0: T) -> T: - ... + raise NotImplementedError def sin(self, x0: T) -> T: - ... + raise NotImplementedError def lgamma(self, x0: T) -> T: - ... + raise NotImplementedError def erf(self, x0: T) -> T: - ... + raise NotImplementedError def cosh(self, x0: T) -> T: - ... + raise NotImplementedError def sinh(self, x0: T) -> T: - ... + raise NotImplementedError def acos(self, x0: T) -> T: - ... + raise NotImplementedError def acosh(self, x0: T) -> T: - ... + raise NotImplementedError def asin(self, x0: T) -> T: - ... + raise NotImplementedError def asinh(self, x0: T) -> T: - ... + raise NotImplementedError def atan2(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def atan(self, x0: T) -> T: - ... + raise NotImplementedError def atanh(self, x0: T) -> T: - ... + raise NotImplementedError def copysign(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def erfc(self, x0: T) -> T: - ... + raise NotImplementedError def erfinv(self, x0: T) -> T: - ... + raise NotImplementedError def frexp(self, x0: T): - ... + raise NotImplementedError def hypot(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def log10(self, x0: T) -> T: - ... + raise NotImplementedError def log2(self, x0: T) -> T: - ... + raise NotImplementedError def nextafter(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_and(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_not(self, x0: T) -> T: - ... + raise NotImplementedError def logical_or(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def logical_xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_and(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_not(self, x0: T) -> T: - ... + raise NotImplementedError def bitwise_or(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_left_shift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def bitwise_right_shift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def rsqrt(self, x0: T) -> T: - ... + raise NotImplementedError def log1p(self, x0: T) -> T: - ... + raise NotImplementedError def tan(self, x0: T) -> T: - ... + raise NotImplementedError def tanh(self, x0: T) -> T: - ... + raise NotImplementedError def sigmoid(self, x0: T) -> T: - ... + raise NotImplementedError def signbit(self, x0: T) -> T: - ... + raise NotImplementedError def fmod(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def log(self, x0: T) -> T: - ... + raise NotImplementedError def isinf(self, x0: T) -> T: - ... + raise NotImplementedError def isnan(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation # This rounds half to even to break ties def round(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: - ... + raise NotImplementedError def sign(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def trunc(self, x0: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: - ... + raise NotImplementedError def neg(self, x0: T) -> T: - ... + raise NotImplementedError def reciprocal(self, x0: T) -> T: - ... + raise NotImplementedError def eq(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def ne(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def lt(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def gt(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def le(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def ge(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def add(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def sub(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def mul(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def and_(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def or_(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def xor(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # These are metaprogrammed by MockHandler._init_cls def lshift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError def rshift(self, x0: T, x1: T) -> T: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These are "special" operators. These only exist if the target @@ -547,124 +514,124 @@ class OpsHandler(Protocol[T]): # pointwise_overrides_data. def airy_ai(self, x: T) -> T: - ... + raise NotImplementedError def bessel_j0(self, x: T) -> T: - ... + raise NotImplementedError def bessel_j1(self, x: T) -> T: - ... + raise NotImplementedError def bessel_y0(self, x: T) -> T: - ... + raise NotImplementedError def bessel_y1(self, x: T) -> T: - ... + raise NotImplementedError def digamma(self, x: T) -> T: - ... + raise NotImplementedError def erfcx(self, x: T) -> T: - ... + raise NotImplementedError def fma(self, x: T, y: T, z: T) -> T: - ... + raise NotImplementedError def igamma(self, x: T, y: T) -> T: - ... + raise NotImplementedError def igammac(self, x: T, y: T) -> T: - ... + raise NotImplementedError def gammainc(self, x: T, y: T) -> T: - ... + raise NotImplementedError def gammaincc(self, x: T, y: T) -> T: - ... + raise NotImplementedError def i0(self, x: T) -> T: - ... + raise NotImplementedError def i0e(self, x: T) -> T: - ... + raise NotImplementedError def i1(self, x: T) -> T: - ... + raise NotImplementedError def i1e(self, x: T) -> T: - ... + raise NotImplementedError def log_ndtr(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_i0(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_i1(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_k0(self, x: T) -> T: - ... + raise NotImplementedError def modified_bessel_k1(self, x: T) -> T: - ... + raise NotImplementedError def ndtr(self, x: T) -> T: - ... + raise NotImplementedError def ndtri(self, x: T) -> T: - ... + raise NotImplementedError def polygamma(self, x: T, y: T) -> T: - ... + raise NotImplementedError def scaled_modified_bessel_k0(self, x: T) -> T: - ... + raise NotImplementedError def scaled_modified_bessel_k1(self, x: T) -> T: - ... + raise NotImplementedError def spherical_bessel_j0(self, x: T) -> T: - ... + raise NotImplementedError def zeta(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_t(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_u(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_v(self, x: T, y: T) -> T: - ... + raise NotImplementedError def chebyshev_polynomial_w(self, x: T, y: T) -> T: - ... + raise NotImplementedError def legendre_polynomial_p(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T: - ... + raise NotImplementedError def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T: - ... + raise NotImplementedError def hermite_polynomial_h(self, x: T, y: T) -> T: - ... + raise NotImplementedError def hermite_polynomial_he(self, x: T, y: T) -> T: - ... + raise NotImplementedError def laguerre_polynomial_l(self, x: T, y: T) -> T: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These operators are a bit special, because they are conventionally @@ -675,42 +642,42 @@ class OpsHandler(Protocol[T]): """C-style trunc division between integers only. Computes the true division of two numbers and rounds the result to zero. """ - ... + raise NotImplementedError def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the true division of two numbers and floors the result. If you want floor division for floats, do regular truediv and floor the result. """ - ... + raise NotImplementedError def truediv(self, x0: T, x1: T) -> T: """True division between floats. Integer inputs are NOT valid. To do Python-style (int, int) -> float division, use int_truediv""" - ... + raise NotImplementedError def int_truediv(self, x0: T, x1: T) -> T: """True division between integers. This is NOT the same as promoting to float and doing integer division, there is a bespoke algorithm for doing the division in higher precision than the above. """ - ... + raise NotImplementedError def mod(self, x0: T, x1: T) -> T: """C-style modulus, take sign from LHS (x0).""" - ... + raise NotImplementedError def remainder(self, x0: T, x1: T) -> T: """Python-style modulus, take sign from RHS (x1).""" - ... + raise NotImplementedError def square(self, x0: T) -> T: - ... + raise NotImplementedError def check_bounds( self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool ) -> None: - ... + raise NotImplementedError # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are @@ -726,25 +693,25 @@ class OpsHandler(Protocol[T]): # for many analyses it's not conveniently available.) def libdevice_abs(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_exp(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sqrt(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_cos(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sin(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_sigmoid(self, x0: T) -> T: - ... + raise NotImplementedError def libdevice_log(self, x0: T) -> T: - ... + raise NotImplementedError # halide-only def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T: @@ -760,15 +727,15 @@ class OpsHandler(Protocol[T]): is_pure: bool = True, pack: int = 1, ) -> T: - ... + raise NotImplementedError def output(self, x0: T) -> None: """This is a fake op used in analysis but not codegen""" - ... + raise NotImplementedError def placeholder(self, index: int) -> T: """This is a fake op used in analysis but not codegen""" - ... + raise NotImplementedError _ignore_op_re = re.compile(r"_.*|paren").fullmatch @@ -781,7 +748,7 @@ def list_ops(cls: type[Any]): OP_NAMES = list_ops(OpsHandler) -class DefaultHandler: +class DefaultHandler(OpsHandler[Any]): def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: """ Default implementation for all ops. Override in a subclass to @@ -850,13 +817,7 @@ class NoopHandler(DefaultHandler): return sympy.S.Zero -if TYPE_CHECKING: - - class _typecheck_NoopHandler(NoopHandler, OpsHandler[None]): - pass # mypy will error if we got any of the signatures wrong - - -class BasicMathOps: +class BasicMathOpsMixin: @staticmethod def add(a, b): return f"{a} + {b}" @@ -935,7 +896,7 @@ class BasicMathOps: return f"-{a}" -class MockHandler(BasicMathOps, DefaultHandler): +class MockHandler(BasicMathOpsMixin, DefaultHandler): name = "MockHandler" def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: @@ -971,14 +932,8 @@ class MockHandler(BasicMathOps, DefaultHandler): return sympy_index_symbol(str(index_var)) -if TYPE_CHECKING: - - class _typecheck_MockHandler(MockHandler, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class KernelFormatterHandler(DefaultHandler): - def __init__(self, parent_handler): + def __init__(self, parent_handler: OpsHandler[Any]): self.parent_handler = parent_handler self._output = IndentedBuffer(1) self.var_counter = itertools.count() @@ -1042,14 +997,8 @@ class KernelFormatterHandler(DefaultHandler): return self._output.getvalue() -if TYPE_CHECKING: - - class _typecheck_KernelFormatterHandler(KernelFormatterHandler, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class WrapperHandler(DefaultHandler): - def __init__(self, inner: Any): + def __init__(self, inner: OpsHandler[Any]): self._inner = inner def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: @@ -1074,11 +1023,11 @@ class OpCountResult(NamedTuple): class OpCounterCSE(DefaultHandler): """Shim to count how many ops are used""" - def __init__(self, inner): + def __init__(self, inner: OpsHandler[Any]): super().__init__() self.parent_handler = inner self.op_count = 0 - self.var_names = {} + self.var_names: dict[str, str] = {} self._used_ops: OrderedSet[str] = OrderedSet() self._read_names: list[str] = [] self._nontrivial_read_count = 0 @@ -1152,26 +1101,16 @@ class OpCounterCSE(DefaultHandler): ) -if TYPE_CHECKING: - - class _typecheck_OpCounterCSE(OpCounterCSE, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - class ExtractConstantsHandler(NoopHandler): - def __init__(self, device): + def __init__(self, device: Optional[torch.device]): self.device = device def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant: from torch._inductor import ir - return ir.Constant(value=value, dtype=dtype, device=self.device) - - -if TYPE_CHECKING: - - class _typecheck_ExtractConstantsHandler(ExtractConstantsHandler, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong + return ir.Constant( + value=value, dtype=dtype, device=self.device or torch.get_default_device() + ) class SimpleCSEHandler(WrapperHandler): @@ -1204,9 +1143,3 @@ class SimpleCSEHandler(WrapperHandler): val = getattr(self._inner, name)(*args, **kwargs) self.cse_cache[key] = val return val - - -if TYPE_CHECKING: - - class _typecheck_SimpleCSEHandler(SimpleCSEHandler, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 393e282d03cc..66f60b12f16a 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -564,10 +564,6 @@ class CompiledFxGraph(OutputCode): return artifact_path -def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode: - return h - - @dataclasses.dataclass class CompiledAOTI(OutputCode): """ @@ -591,10 +587,6 @@ class CompiledAOTI(OutputCode): pass -def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode: - return h - - @dataclasses.dataclass class MockFXGraphCacheOutput(OutputCode): gm: Any = None diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index 0166534e5fb5..d79923857359 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -149,8 +149,8 @@ class TracingOpsHandler(WrapperHandler): def placeholder(self, idx: int) -> torch.fx.Proxy: return self.placeholders[idx] - def output(self, *args: tuple[object]) -> torch.fx.Node: - return self.tracer.create_node( + def output(self, *args: tuple[object]) -> None: + self.tracer.create_node( "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {} ) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index f82c84afdad1..1ee1ef5a7440 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -306,9 +306,7 @@ class OpsWrapper(DefaultHandler): return _ops.indirect_indexing(index, size, check, wrap_neg) -# we lie about the type of ops so the rest of the codebase typecheck properly -# DefaultHandler implements the OpsHandler protocol via metaprogramming -ops = cast(OpsHandler[Any], OpsWrapper()) +ops: OpsHandler[Any] = OpsWrapper() class _V: @@ -316,8 +314,10 @@ class _V: KernelFormatterHandler = KernelFormatterHandler WrapperHandler = WrapperHandler - set_ops_handler: Callable[[Any], Any] = _ops._set_handler - get_ops_handler: Callable[[], Any] = _ops._get_handler + set_ops_handler: Callable[ + [OpsHandler[Any]], AbstractContextManager[None] + ] = _ops._set_handler + get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler get_real_inputs: Callable[[], Any] = _real_inputs._get_handler diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index eb85b6798ea2..784f9e7ba051 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -49,7 +49,7 @@ from .numbers import int_oo, IntInfinity, NegativeIntInfinity log = logging.getLogger(__name__) -__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"] +__all__ = ["ValueRanges", "bound_sympy"] _T = TypeVar("_T", sympy.Expr, SympyBoolean) @@ -1004,108 +1004,6 @@ class SymPyValueRangeAnalysis: return ValueRanges.increasing_map(x, TruncToFloat) -class ValueRangeAnalysis(SymPyValueRangeAnalysis): - def __init__(self) -> None: - self.name = "ValueRangeAnalysis" - boolean_operators = ( - "xor", - "logical_and", - "logical_or", - "logical_not", - ) - for op in boolean_operators: - setattr(self, op, self.bool_handler) - - @staticmethod - def bool_handler(*args, **kwargs): - # just assuming bools can have both values - return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type] - - @staticmethod - def default_handler(*args, **kwargs): - # many ops are unlikely to show up in optimizable indexing compute, - # so we dont have full coverage - return ValueRanges.unknown() - - def load(self, name: str, index: sympy.Expr): - return ValueRanges.unknown() - - def store(self, name, index, value, mode=None): - return - - def reduction(self, name, dtype, src_dtype, reduction_type, index, value): - return ValueRanges.unknown() - - @classmethod - def index_expr(cls, index, dtype): - assert isinstance(index, ValueRanges) - return cls.to_dtype(index, dtype) - - @staticmethod - def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): - x = ValueRanges.wrap(x) - - if dtype == torch.bool: - if x.is_singleton(): - return ValueRanges.wrap(x.lower != 0) - elif x.is_bool: - return x - elif 0 not in x: - return ValueRanges.wrap(sympy.true) - else: - return ValueRanges(sympy.false, sympy.true) - - def cast(x, dtype): - # dtype is int or float - if dtype.is_floating_point: - return sympy.Float(x) - else: - if x in (int_oo, -int_oo): - return x - try: - return sympy.Integer(x) - except TypeError: - # inf cannot be cast to Integer - return x - - if x.is_bool: - if x.is_singleton(): - val = 1 if x.lower else 0 - return ValueRanges.wrap(cast(val, dtype)) - else: - return ValueRanges(cast(0, dtype), cast(1, dtype)) - else: - # int to float or float to int - return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) - - @staticmethod - def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) - - @staticmethod - def neg(x): - return ValueRanges.decreasing_map(x, operator.neg) - - # TODO: this is slightly inaccurate because truncdiv operates at integer - # precision, but we're going through float truediv which means we can - # potentially lose precision on the bounds - @classmethod - def truncdiv(cls, a, b): - x = cls.truediv(a, b) - if x == ValueRanges.unknown(): - return x - - return cls.trunc(x) - - @classmethod - def sub(cls, a, b): - return cls.add(a, cls.neg(b)) - - def __getattr__(self, name): - log.debug("unhandled ValueRange op %s", name) - return self.default_handler - - def bound_sympy( expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: