[inductor] Refactor op handlers part 4 (#146255)

This replaces the `__getattr__()` pattern used in the remaining op handlers with the `DefaultHandler` class defined in part 2.
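
For context, here is a minimal before/after sketch of the pattern change. This is illustrative only: `ToyDefaultHandler` and the `CountOpsOld`/`CountOpsNew` classes below are simplified, hypothetical stand-ins (loosely modeled on the `CountOps` handler touched in this diff); the real `DefaultHandler` from part 2 lives in `torch/_inductor/ops_handler.py` and is more involved.

```
# Illustrative sketch only -- not the actual torch._inductor implementation.
from typing import Any, Callable


class ToyDefaultHandler:
    """Toy stand-in: route any op call without an explicit method to _default()."""

    def _default(
        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> Any:
        raise NotImplementedError

    def __getattr__(self, name: str) -> Callable[..., Any]:
        # Only reached for attributes that don't exist; wrap them as op calls.
        def inner(*args: Any, **kwargs: Any) -> Any:
            return self._default(name, args, kwargs)

        return inner


# Before: each handler rolled its own __getattr__ returning a fresh closure.
class CountOpsOld:
    def __init__(self) -> None:
        self.counts: dict[str, int] = {}

    def __getattr__(self, name: str) -> Callable[..., Any]:
        def inner(*args: Any, **kwargs: Any) -> None:
            self.counts[name] = self.counts.get(name, 0) + 1

        return inner


# After: subclass the shared base and override _default() once.
class CountOpsNew(ToyDefaultHandler):
    def __init__(self) -> None:
        self.counts: dict[str, int] = {}

    def _default(
        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> Any:
        self.counts[name] = self.counts.get(name, 0) + 1


h = CountOpsNew()
h.add(1, 2)  # resolved via __getattr__ -> _default("add", (1, 2), {})
assert h.counts == {"add": 1}
```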

There are some compile-time wins from this as well:
```
2025-02-02T19:46:32.2036607Z WIN: benchmark ('add_loop_inductor', 'compile_time_instruction_count') failed, actual result 29633182927 is -1.71% lower than expected 30150000000 ±1.50% please update the expected results.
2025-02-02T19:46:32.2037907Z please update all results that changed significantly, and not only the failed ones
2025-02-02T19:46:32.2039291Z PASS: benchmark ('add_loop_inductor_dynamic_gpu', 'compile_time_instruction_count') pass, actual result 43986879172 -1.02% is within expected 44440000000 ±2.50%
2025-02-02T19:46:32.2041180Z WIN: benchmark ('add_loop_inductor_gpu', 'compile_time_instruction_count') failed, actual result 26246225695 is -1.85% lower than expected 26740000000 ±1.50% please update the expected results.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146255
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254
Author: Jason Ansel
Date: 2025-02-07 13:32:54 -08:00
Committed by: PyTorch MergeBot
Parent: 0e31e5932b
Commit: 403db2faee
8 changed files with 164 additions and 147 deletions

@@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
add_loop_inductor,compile_time_instruction_count,30150000000,0.015
add_loop_inductor,compile_time_instruction_count,29630000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015
@@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015
@@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000,
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015

@@ -1,6 +1,6 @@
import dataclasses
import itertools
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import sympy
@@ -9,6 +9,7 @@ from torch._inductor import config
from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
from torch._inductor.index_propagation import SymPyOps, TypedExpr
from .ops_handler import DefaultHandler, OpsHandler
from .virtualized import StoreMode, V
@@ -20,7 +21,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol:
return sympy.Symbol(f"unknown_{count}")
class PreservesZeros(SymPyOps):
class PreservesZeros(SymPyOps, DefaultHandler):
"""
For prologue kernels where the loads are masked, does the final store of this kernel preserve
the zeros.
@@ -54,18 +55,20 @@ class PreservesZeros(SymPyOps):
self = V.get_ops_handler()
return construct_symbol(next(self.count), torch.int32)
def __getattr__(self, name: str) -> Callable[..., Any]:
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
from torch._inductor.codegen.common import OpDecompositions
def inner(*args: Any, **kwargs: Any) -> TypedExpr:
if hasattr(OpDecompositions, name):
return getattr(OpDecompositions, name)(*args, **kwargs).value
if hasattr(OpDecompositions, name):
return getattr(OpDecompositions, name)(*args, **kwargs).value
nonlocal self
dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
return inner
if TYPE_CHECKING:
class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]):
pass
def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool:
@@ -88,7 +91,7 @@ class DTypeContainer:
is_scalar: bool = False
class RecordLowPrecisionOps:
class RecordLowPrecisionOps(DefaultHandler):
def __init__(self) -> None:
self.low_precision_numeric_op = False
self.dtype_prop = DtypePropagationOpsHandler()
@@ -111,28 +114,31 @@ class RecordLowPrecisionOps:
def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
return sympy.S.Zero
def __getattr__(self, name: str) -> Callable[..., Any]:
def low_prec_float(dtype: torch.dtype) -> bool:
return dtype.is_floating_point and dtype.itemsize < 4
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
if name == "constant":
out = DTypeContainer(torch.float, is_scalar=True)
def inner(*args: Any, **kwargs: Any) -> DTypeContainer:
out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
if name == "constant":
out = DTypeContainer(torch.float, is_scalar=True)
uses_low_prec = any(
isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype)
for dtype_cont in itertools.chain((out,), args, kwargs.values())
)
uses_low_prec = any(
isinstance(dtype_cont, DTypeContainer)
and low_prec_float(dtype_cont.dtype)
for dtype_cont in itertools.chain((out,), args, kwargs.values())
)
if uses_low_prec and name not in self.non_numeric_ops:
self.low_precision_numeric_op = True
if uses_low_prec and name not in self.non_numeric_ops:
self.low_precision_numeric_op = True
return out
return out
return inner
if TYPE_CHECKING:
class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]):
pass
def low_prec_float(dtype: torch.dtype) -> bool:
return dtype.is_floating_point and dtype.itemsize < 4
def can_codegen_without_upcasts(

@@ -41,7 +41,7 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val
from .. import config, metrics
from ..dtype_propagation import DtypePropagationOpsHandler
from ..ops_handler import BasicMathOps
from ..ops_handler import BasicMathOps, DefaultHandler
from ..utils import (
boolean_ops,
DeferredLineBase,
@@ -2263,7 +2263,7 @@ class KernelTemplate:
raise NotImplementedError
class CSEProxy:
class CSEProxy(DefaultHandler):
name = "CSEProxy"
def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
@@ -2272,69 +2272,66 @@
self.kernel = kernel
self.parent_handler = parent_handler
def __getattr__(self, name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
def inner(*args: Any, **kwargs: Any) -> CSEVariable:
bounds = self._bound_variable(name, *args, **kwargs)
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
bounds = self._bound_variable(name, *args, **kwargs)
value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
dtype_handler = DtypePropagationOpsHandler()
value = getattr(self.parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
dtype_handler = DtypePropagationOpsHandler()
output_idx = 0
output_idx = 0
def do_cse(v: str) -> CSEVariable:
# cpp backend doesnt set current device - TODO: fix
if V.graph.current_device is not None:
device_str = V.graph.get_current_device_or_throw().type
triton_backend = (
config.cpu_backend == "triton"
if device_str == "cpu"
else config.cuda_backend == "triton"
if device_str != "mps"
else False
)
else:
triton_backend = False
# only triton backend tracks dtype currently
if triton_backend:
if name == "masked":
output_dtype = value.dtype
else:
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
else:
# cpp backend doesnt track dtype yet
output_dtype = None
csevar = V.kernel.cse.generate(
V.kernel.compute,
v,
bounds=bounds,
dtype=output_dtype,
def do_cse(v: str) -> CSEVariable:
# cpp backend doesnt set current device - TODO: fix
if V.graph.current_device is not None:
device_str = V.graph.get_current_device_or_throw().type
triton_backend = (
config.cpu_backend == "triton"
if device_str == "cpu"
else config.cuda_backend == "triton"
if device_str != "mps"
else False
)
else:
triton_backend = False
nonlocal output_idx
if config.test_configs.runtime_triton_dtype_assert and triton_backend:
from torch._inductor.codegen.triton import triton_type
# only triton backend tracks dtype currently
if triton_backend:
if name == "masked":
output_dtype = value.dtype
else:
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
else:
# cpp backend doesnt track dtype yet
output_dtype = None
# we tree_map over the output, so we need to fetch corresponding dtype
if isinstance(output_dtype, (list, tuple)):
output_dtype = output_dtype[output_idx]
csevar = V.kernel.cse.generate(
V.kernel.compute,
v,
bounds=bounds,
dtype=output_dtype,
)
V.kernel.compute.writeline(
f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
)
output_idx += 1
nonlocal output_idx
if config.test_configs.runtime_triton_dtype_assert and triton_backend:
from torch._inductor.codegen.triton import triton_type
csevar.update_on_args(name, args, kwargs)
# we tree_map over the output, so we need to fetch corresponding dtype
if isinstance(output_dtype, (list, tuple)):
output_dtype = output_dtype[output_idx]
return csevar
V.kernel.compute.writeline(
f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
)
output_idx += 1
return pytree.tree_map(do_cse, value)
csevar.update_on_args(name, args, kwargs)
return inner
return csevar
return pytree.tree_map(do_cse, value)
def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
"""

@@ -31,6 +31,7 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..codecache import code_hash, get_path, PyCodeCache
from ..ops_handler import DefaultHandler
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import (
AutotuneHint,
@@ -2872,24 +2873,23 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
dtype_handler = DtypePropagationOpsHandler()
class CSEProxy:
def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
def inner(*args, **kwargs):
nonlocal helper_name
helper_name += f"_{name}"
class CSEProxy(DefaultHandler):
def _default(
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Any:
nonlocal helper_name
helper_name += f"_{name}"
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
output_dtype = getattr(
dtype_handler,
name,
)(*args, **kwargs)
return cse.generate(
helper,
getattr(overrides, name)(*args, **kwargs),
dtype=output_dtype,
)
return inner
return cse.generate(
helper,
getattr(overrides, name)(*args, **kwargs),
dtype=output_dtype,
)
with helper.indent(), V.set_ops_handler(CSEProxy()):
outputs = fn(*args)

@@ -4,7 +4,17 @@ import itertools
import logging
import re
from collections.abc import Sequence
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Iterable,
List,
Optional,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from unittest.mock import patch
import sympy
@@ -15,6 +25,7 @@ from torch.utils._ordered_set import OrderedSet
from ..utils._sympy.symbol import make_symbol, SymT
from .codegen.common import index_prevent_reordering
from .ops_handler import DefaultHandler
from .utils import (
get_dtype_size,
reduction_num_outputs,
@@ -737,19 +748,16 @@ def canonicalization_prefix() -> str:
# ops handler which computes all the free unbacked symbols for an IR
class FreeUnbackedSymbolsOpsHandler:
class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
symbols: OrderedSet[sympy.Symbol]
def __init__(self) -> None:
self.symbols = OrderedSet()
def __getattr__(self, name: str) -> Callable[..., Any]:
def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None:
for a in itertools.chain(args, kwargs.values()):
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
self.symbols |= free_unbacked_symbols(a)
return inner
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
for a in itertools.chain(args, kwargs.values()):
if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
self.symbols |= free_unbacked_symbols(a)
def indirect_indexing(
self,
@@ -791,10 +799,12 @@ class FreeUnbackedSymbolsOpsHandler:
body()
def _typecheck_FreeUnbackedSymbolsOpsHandler(
h: FreeUnbackedSymbolsOpsHandler,
) -> OpsHandler[None]:
return h
if TYPE_CHECKING:
class _typecheck_FreeUnbackedSymbolsOpsHandler(
FreeUnbackedSymbolsOpsHandler, OpsHandler[None]
):
pass
def extract_free_unbacked_symbols(

@@ -21,8 +21,9 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
"""
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, overload, Union
from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
import sympy
@@ -32,6 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .ops_handler import DefaultHandler, OpsHandler
from .sizevars import evaluate_expr
from .utils import generate_assert
from .virtualized import V
@@ -185,7 +187,7 @@ class IndexPropVar:
IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
class IndexPropagation:
class IndexPropagation(DefaultHandler):
"""Ops wrapper that tries to propagate constant and index_expr values through the computation.
This aims to maximize the compile time simplification possible, and convert
@@ -247,19 +249,19 @@ class IndexPropagation:
def fallback(
self,
name: Literal["indirect_indexing"],
args: tuple[Any, ...],
args: Sequence[Any],
kwargs: dict[str, Any],
) -> IndexPropVar:
...
@overload
def fallback(
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
...
def fallback(
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
# Fallback to the wrapped handler
new_args = [self.unwrap(a) for a in args]
@@ -267,7 +269,7 @@
return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
def propagate_sympy(
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
# Build a new SymPy expression from this ops call
def unwrap(a: Union[Any, IndexPropVar]) -> Any:
@@ -288,22 +290,19 @@
return self.fallback(name, args, kwargs)
return IndexPropVar.new_symbolic(new_expr)
def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
if not hasattr(SymPyOps, name):
return self.fallback(name, args, kwargs)
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
if not hasattr(SymPyOps, name):
return self.fallback(name, args, kwargs)
var_arguments = [
a
for a in itertools.chain(args, kwargs.values())
if isinstance(a, IndexPropVar)
]
if not all(v.is_symbolic for v in var_arguments):
return self.fallback(name, args, kwargs)
var_arguments = [
a
for a in itertools.chain(args, kwargs.values())
if isinstance(a, IndexPropVar)
]
if not all(v.is_symbolic for v in var_arguments):
return self.fallback(name, args, kwargs)
return self.propagate_sympy(name, args, kwargs)
return inner
return self.propagate_sympy(name, args, kwargs)
def statically_true(self, e):
"""
@@ -371,3 +370,9 @@ class IndexPropagation:
"indirect_indexing", (index, size, check, wrap_neg), {}
).value
return indirect_var
if TYPE_CHECKING:
class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]):
pass

@@ -17,6 +17,7 @@ from torch.utils._sympy.symbol import SymT
from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .ops_handler import DefaultHandler
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V
@@ -653,11 +654,11 @@
return copy
class CountOps:
class CountOps(DefaultHandler):
def __init__(self, inner: Any, counts: collections.Counter[str]):
self._inner = inner
self._counts = counts
def __getattr__(self, name):
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
self._counts[name] += 1
return getattr(self._inner, name)
return getattr(self._inner, name)(*args, **kwargs)

@@ -64,6 +64,7 @@ from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union
from torch.utils._ordered_set import OrderedSet
from .ops_handler import ( # noqa: F401
DefaultHandler,
KernelFormatterHandler,
MockHandler,
OpsHandler,
@@ -274,18 +275,15 @@ class OpsValue:
return ops.bitwise_left_shift(self, n)
class OpsWrapper:
class OpsWrapper(DefaultHandler):
"""This wraps any returned IR values into an `OpsValue` instance, so that we
can overload the magic methods for writing mathematical expressions fluently.
"""
def __getattr__(self, name):
def inner(*args, **kwargs):
new_args = [OpsWrapper._unwrap(a) for a in args]
new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
return inner
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
new_args = [OpsWrapper._unwrap(a) for a in args]
new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
@staticmethod
def _unwrap(x):