Revert "[inductor] Refactor op handlers part 4 (#146255)"

This reverts commit 7aced455c542f629ffcd4f79c6af259bb966add8.

Reverted https://github.com/pytorch/pytorch/pull/146255 on behalf of https://github.com/atalman due to Sorry need to revert https://github.com/pytorch/pytorch/pull/146252 ([comment](https://github.com/pytorch/pytorch/pull/146255#issuecomment-2638258089))
PyTorch MergeBot
2025-02-05 23:24:20 +00:00
parent 49effa0deb
commit 68304dba7a
8 changed files with 148 additions and 165 deletions
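For context on what the diff below restores: PR #146255 had replaced the `__getattr__`-based dynamic dispatch in the inductor ops handlers with an explicit `DefaultHandler._default(name, args, kwargs)` hook; this revert puts the `__getattr__` closures back. The following is a minimal, self-contained sketch of the two dispatch styles. The class names and string payloads are illustrative stand-ins, not the real torch._inductor handlers.

# Toy illustration of the two dispatch styles involved in this revert.
# These classes are stand-ins, not the real torch._inductor handlers.
from typing import Any, Callable


class GetattrStyleHandler:
    """Pattern restored by the revert: unknown ops resolve via __getattr__."""

    def __getattr__(self, name: str) -> Callable[..., Any]:
        def inner(*args: Any, **kwargs: Any) -> str:
            # every op call is funneled through one generic closure
            return f"{name}({args}, {kwargs})"

        return inner


class DefaultStyleHandler:
    """Pattern being reverted: a base class routes unknown ops to _default."""

    def __getattr__(self, name: str) -> Callable[..., Any]:
        # the real DefaultHandler generates these wrappers up front;
        # here the routing is emulated with __getattr__ for brevity
        return lambda *args, **kwargs: self._default(name, args, kwargs)

    def _default(self, name: str, args: tuple, kwargs: dict) -> Any:
        return f"{name}({args}, {kwargs})"


if __name__ == "__main__":
    print(GetattrStyleHandler().add(1, 2))    # add((1, 2), {})
    print(DefaultStyleHandler().mul(3, y=4))  # mul((3,), {'y': 4})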

View File

@@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
-add_loop_inductor,compile_time_instruction_count,29630000000,0.015
+add_loop_inductor,compile_time_instruction_count,30150000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015
@@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015
@@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000,
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015


View File

@@ -1,6 +1,6 @@
 import dataclasses
 import itertools
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING
 import sympy
@@ -9,7 +9,6 @@ from torch._inductor import config
 from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
 from torch._inductor.index_propagation import SymPyOps, TypedExpr
-from .ops_handler import DefaultHandler, OpsHandler
 from .virtualized import StoreMode, V
@@ -21,7 +20,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol:
         return sympy.Symbol(f"unknown_{count}")
-class PreservesZeros(SymPyOps, DefaultHandler):
+class PreservesZeros(SymPyOps):
     """
     For prologue kernels where the loads are masked, does the final store of this kernel preserve
     the zeros.
@@ -55,20 +54,18 @@ class PreservesZeros(SymPyOps, DefaultHandler):
         self = V.get_ops_handler()
         return construct_symbol(next(self.count), torch.int32)
-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+    def __getattr__(self, name: str) -> Callable[..., Any]:
         from torch._inductor.codegen.common import OpDecompositions
-        if hasattr(OpDecompositions, name):
-            return getattr(OpDecompositions, name)(*args, **kwargs).value
+        def inner(*args: Any, **kwargs: Any) -> TypedExpr:
+            if hasattr(OpDecompositions, name):
+                return getattr(OpDecompositions, name)(*args, **kwargs).value
-        dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
-        return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
+            nonlocal self
+            dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
+            return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
-if TYPE_CHECKING:
-    class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]):
-        pass
+        return inner
 def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool:
@@ -91,7 +88,7 @@ class DTypeContainer:
     is_scalar: bool = False
-class RecordLowPrecisionOps(DefaultHandler):
+class RecordLowPrecisionOps:
     def __init__(self) -> None:
         self.low_precision_numeric_op = False
         self.dtype_prop = DtypePropagationOpsHandler()
@@ -114,31 +111,28 @@ class RecordLowPrecisionOps(DefaultHandler):
     def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
         return sympy.S.Zero
-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
-        out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
-        if name == "constant":
-            out = DTypeContainer(torch.float, is_scalar=True)
+    def __getattr__(self, name: str) -> Callable[..., Any]:
+        def low_prec_float(dtype: torch.dtype) -> bool:
+            return dtype.is_floating_point and dtype.itemsize < 4
-        uses_low_prec = any(
-            isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype)
-            for dtype_cont in itertools.chain((out,), args, kwargs.values())
-        )
+        def inner(*args: Any, **kwargs: Any) -> DTypeContainer:
+            out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
+            out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
+            if name == "constant":
+                out = DTypeContainer(torch.float, is_scalar=True)
-        if uses_low_prec and name not in self.non_numeric_ops:
-            self.low_precision_numeric_op = True
+            uses_low_prec = any(
+                isinstance(dtype_cont, DTypeContainer)
+                and low_prec_float(dtype_cont.dtype)
+                for dtype_cont in itertools.chain((out,), args, kwargs.values())
+            )
-        return out
+            if uses_low_prec and name not in self.non_numeric_ops:
+                self.low_precision_numeric_op = True
+            return out
-if TYPE_CHECKING:
-    class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]):
-        pass
-def low_prec_float(dtype: torch.dtype) -> bool:
-    return dtype.is_floating_point and dtype.itemsize < 4
+        return inner
def can_codegen_without_upcasts(
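Both PreservesZeros and RecordLowPrecisionOps intercept every op the same way: ask a dtype-propagation helper for the result dtype and record a property of the kernel as a side effect. Below is a stdlib-only sketch of that shape; DType and infer_dtype are invented stand-ins for torch.dtype and DtypePropagationOpsHandler, not the real API.

# Hedged sketch of the RecordLowPrecisionOps shape: every op is intercepted,
# a stand-in dtype inference runs, and a flag is latched when a low-precision
# float participates. DType and infer_dtype are made up for illustration.
import dataclasses
import itertools
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class DType:
    is_floating_point: bool
    itemsize: int


def infer_dtype(name: str, *args: Any, **kwargs: Any) -> DType:
    # stand-in for dtype propagation: pretend every op returns fp16
    return DType(is_floating_point=True, itemsize=2)


def low_prec_float(dtype: DType) -> bool:
    return dtype.is_floating_point and dtype.itemsize < 4


class RecordsLowPrecision:
    def __init__(self) -> None:
        self.low_precision_numeric_op = False

    def __getattr__(self, name: str) -> Callable[..., DType]:
        def inner(*args: Any, **kwargs: Any) -> DType:
            out = infer_dtype(name, *args, **kwargs)
            # latch the flag if any operand or the result is a sub-4-byte float
            if any(
                isinstance(d, DType) and low_prec_float(d)
                for d in itertools.chain((out,), args, kwargs.values())
            ):
                self.low_precision_numeric_op = True
            return out

        return inner


h = RecordsLowPrecision()
h.add(DType(True, 2), DType(True, 2))
print(h.low_precision_numeric_op)  # True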

View File

@@ -41,7 +41,7 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val
 from .. import config, metrics
 from ..dtype_propagation import DtypePropagationOpsHandler
-from ..ops_handler import BasicMathOps, DefaultHandler
+from ..ops_handler import BasicMathOps
 from ..utils import (
     boolean_ops,
     DeferredLineBase,
@@ -2253,7 +2253,7 @@ class KernelTemplate:
         raise NotImplementedError
-class CSEProxy(DefaultHandler):
+class CSEProxy:
     name = "CSEProxy"
     def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
@@ -2262,66 +2262,69 @@ class CSEProxy(DefaultHandler):
         self.kernel = kernel
         self.parent_handler = parent_handler
-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        bounds = self._bound_variable(name, *args, **kwargs)
+    def __getattr__(self, name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
+        def inner(*args: Any, **kwargs: Any) -> CSEVariable:
+            bounds = self._bound_variable(name, *args, **kwargs)
-        value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
-        dtype_handler = DtypePropagationOpsHandler()
+            value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
+            dtype_handler = DtypePropagationOpsHandler()
-        output_idx = 0
+            output_idx = 0
-        def do_cse(v: str) -> CSEVariable:
-            # cpp backend doesnt set current device - TODO: fix
-            if V.graph.current_device is not None:
-                device_str = V.graph.get_current_device_or_throw().type
-                triton_backend = (
-                    config.cpu_backend == "triton"
-                    if device_str == "cpu"
-                    else config.cuda_backend == "triton"
-                    if device_str != "mps"
-                    else False
-                )
-            else:
-                triton_backend = False
-            # only triton backend tracks dtype currently
-            if triton_backend:
-                if name == "masked":
-                    output_dtype = value.dtype
+            def do_cse(v: str) -> CSEVariable:
+                # cpp backend doesnt set current device - TODO: fix
+                if V.graph.current_device is not None:
+                    device_str = V.graph.get_current_device_or_throw().type
+                    triton_backend = (
+                        config.cpu_backend == "triton"
+                        if device_str == "cpu"
+                        else config.cuda_backend == "triton"
+                        if device_str != "mps"
+                        else False
+                    )
+                else:
-                else:
-                    output_dtype = getattr(
-                        dtype_handler,
-                        name,
-                    )(*args, **kwargs)
-            else:
-                # cpp backend doesnt track dtype yet
-                output_dtype = None
+                    triton_backend = False
-            csevar = V.kernel.cse.generate(
-                V.kernel.compute,
-                v,
-                bounds=bounds,
-                dtype=output_dtype,
-            )
+                # only triton backend tracks dtype currently
+                if triton_backend:
+                    if name == "masked":
+                        output_dtype = value.dtype
+                    else:
+                        output_dtype = getattr(
+                            dtype_handler,
+                            name,
+                        )(*args, **kwargs)
+                else:
+                    # cpp backend doesnt track dtype yet
+                    output_dtype = None
-            nonlocal output_idx
-            if config.test_configs.runtime_triton_dtype_assert and triton_backend:
-                from torch._inductor.codegen.triton import triton_type
-                # we tree_map over the output, so we need to fetch corresponding dtype
-                if isinstance(output_dtype, (list, tuple)):
-                    output_dtype = output_dtype[output_idx]
-                V.kernel.compute.writeline(
-                    f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
-                )
+                csevar = V.kernel.cse.generate(
+                    V.kernel.compute,
+                    v,
+                    bounds=bounds,
+                    dtype=output_dtype,
+                )
-            output_idx += 1
-            csevar.update_on_args(name, args, kwargs)
+                nonlocal output_idx
+                if config.test_configs.runtime_triton_dtype_assert and triton_backend:
+                    from torch._inductor.codegen.triton import triton_type
-            return csevar
+                    # we tree_map over the output, so we need to fetch corresponding dtype
+                    if isinstance(output_dtype, (list, tuple)):
+                        output_dtype = output_dtype[output_idx]
-        return pytree.tree_map(do_cse, value)
+                    V.kernel.compute.writeline(
+                        f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
+                    )
+                output_idx += 1
+                csevar.update_on_args(name, args, kwargs)
+                return csevar
+            return pytree.tree_map(do_cse, value)
+        return inner
     def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
         """

View File

@@ -30,7 +30,6 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty
 from ...utils._sympy.value_ranges import ValueRanges
 from .. import config, ir, metrics
 from ..codecache import code_hash, get_path, PyCodeCache
-from ..ops_handler import DefaultHandler
 from ..runtime.benchmarking import benchmarker
 from ..runtime.hints import (
     AutotuneHint,
@@ -2849,23 +2848,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
         dtype_handler = DtypePropagationOpsHandler()
-        class CSEProxy(DefaultHandler):
-            def _default(
-                self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
-            ) -> Any:
-                nonlocal helper_name
-                helper_name += f"_{name}"
+        class CSEProxy:
+            def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
+                def inner(*args, **kwargs):
+                    nonlocal helper_name
+                    helper_name += f"_{name}"
-                output_dtype = getattr(
-                    dtype_handler,
-                    name,
-                )(*args, **kwargs)
+                    output_dtype = getattr(
+                        dtype_handler,
+                        name,
+                    )(*args, **kwargs)
-                return cse.generate(
-                    helper,
-                    getattr(overrides, name)(*args, **kwargs),
-                    dtype=output_dtype,
-                )
+                    return cse.generate(
+                        helper,
+                        getattr(overrides, name)(*args, **kwargs),
+                        dtype=output_dtype,
+                    )
+                return inner
with helper.indent(), V.set_ops_handler(CSEProxy()):
outputs = fn(*args)
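The helper above installs a local CSEProxy as the active ops handler while fn is traced, renaming the helper as ops are encountered (the nonlocal helper_name trick). A self-contained sketch of that pattern follows; set_ops_handler and the ops namespace below are toy stand-ins for V.set_ops_handler and the virtualized ops object, not the real API.

# Minimal stand-in for "install a proxy handler while tracing a helper".
import contextlib
from typing import Any, Callable, Iterator


class _OpsNamespace:
    handler: Any = None

    def __getattr__(self, name: str) -> Callable[..., Any]:
        return getattr(self.handler, name)


ops = _OpsNamespace()


@contextlib.contextmanager
def set_ops_handler(handler: Any) -> Iterator[None]:
    prev, ops.handler = ops.handler, handler
    try:
        yield
    finally:
        ops.handler = prev


class NamingProxy:
    """Renames the helper as ops are traced, like the nonlocal helper_name trick."""

    def __init__(self) -> None:
        self.helper_name = "helper"

    def __getattr__(self, name: str) -> Callable[..., str]:
        def inner(*args: Any, **kwargs: Any) -> str:
            self.helper_name += f"_{name}"
            return f"{name}({', '.join(map(str, args))})"

        return inner


proxy = NamingProxy()
with set_ops_handler(proxy):
    print(ops.add("x", "y"))  # add(x, y)
print(proxy.helper_name)      # helper_add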

View File

@@ -4,17 +4,7 @@ import itertools
 import logging
 import re
 from collections.abc import Sequence
-from typing import (
-    Any,
-    Callable,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
 from unittest.mock import patch
 import sympy
@@ -25,7 +15,6 @@ from torch.utils._ordered_set import OrderedSet
 from ..utils._sympy.symbol import make_symbol, SymT
 from .codegen.common import index_prevent_reordering
-from .ops_handler import DefaultHandler
 from .utils import (
     get_dtype_size,
     reduction_num_outputs,
@@ -748,16 +737,19 @@ def canonicalization_prefix() -> str:
 # ops handler which computes all the free unbacked symbols for an IR
-class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
+class FreeUnbackedSymbolsOpsHandler:
     symbols: OrderedSet[sympy.Symbol]
     def __init__(self) -> None:
         self.symbols = OrderedSet()
-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        for a in itertools.chain(args, kwargs.values()):
-            if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
-                self.symbols |= free_unbacked_symbols(a)
+    def __getattr__(self, name: str) -> Callable[..., Any]:
+        def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None:
+            for a in itertools.chain(args, kwargs.values()):
+                if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
+                    self.symbols |= free_unbacked_symbols(a)
+        return inner
     def indirect_indexing(
         self,
@@ -799,12 +791,10 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
         body()
-if TYPE_CHECKING:
-    class _typecheck_FreeUnbackedSymbolsOpsHandler(
-        FreeUnbackedSymbolsOpsHandler, OpsHandler[None]
-    ):
-        pass
+def _typecheck_FreeUnbackedSymbolsOpsHandler(
+    h: FreeUnbackedSymbolsOpsHandler,
+) -> OpsHandler[None]:
+    return h
def extract_free_unbacked_symbols(
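FreeUnbackedSymbolsOpsHandler ignores what each op computes and only harvests symbols from its arguments. Below is a toy version of that idea; the "unbacked" test is reduced to a name prefix rather than the real free_unbacked_symbols helper, so it is illustrative only.

# Toy version of FreeUnbackedSymbolsOpsHandler: every op ignores its semantics
# and just collects symbols from the arguments. The "unbacked" test here
# (names starting with "u") is a stand-in for free_unbacked_symbols.
import itertools
from typing import Any, Callable

import sympy


class CollectUnbacked:
    def __init__(self) -> None:
        self.symbols = set()

    def __getattr__(self, name: str) -> Callable[..., None]:
        def inner(*args: Any, **kwargs: Any) -> None:
            for a in itertools.chain(args, kwargs.values()):
                if isinstance(a, sympy.Expr):
                    self.symbols |= {
                        s for s in a.free_symbols if s.name.startswith("u")
                    }

        return inner


u0, x = sympy.symbols("u0 x")
h = CollectUnbacked()
h.load("buf0", u0 + x)
print(h.symbols)  # {u0}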

View File

@@ -21,9 +21,8 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
 """
 import itertools
-from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union
+from typing import Any, Callable, Literal, Optional, overload, Union
 from typing_extensions import TypeAlias
 import sympy
@@ -33,7 +32,6 @@ from torch._prims_common import dtype_to_type, is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
-from .ops_handler import DefaultHandler, OpsHandler
 from .sizevars import evaluate_expr
 from .utils import generate_assert
 from .virtualized import V
@@ -187,7 +185,7 @@ class IndexPropVar:
 IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
-class IndexPropagation(DefaultHandler):
+class IndexPropagation:
     """Ops wrapper that tries to propagate constant and index_expr values through the computation.
     This aims to maximize the compile time simplification possible, and convert
@@ -249,19 +247,19 @@ class IndexPropagation(DefaultHandler):
     def fallback(
         self,
         name: Literal["indirect_indexing"],
-        args: Sequence[Any],
+        args: tuple[Any, ...],
         kwargs: dict[str, Any],
     ) -> IndexPropVar:
         ...
     @overload
     def fallback(
-        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
+        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         ...
     def fallback(
-        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
+        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         # Fallback to the wrapped handler
         new_args = [self.unwrap(a) for a in args]
@@ -269,7 +267,7 @@ class IndexPropagation(DefaultHandler):
         return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
     def propagate_sympy(
-        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
+        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         # Build a new SymPy expression from this ops call
         def unwrap(a: Union[Any, IndexPropVar]) -> Any:
@@ -290,19 +288,22 @@ class IndexPropagation(DefaultHandler):
             return self.fallback(name, args, kwargs)
         return IndexPropVar.new_symbolic(new_expr)
-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        if not hasattr(SymPyOps, name):
-            return self.fallback(name, args, kwargs)
+    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
+        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
+            if not hasattr(SymPyOps, name):
+                return self.fallback(name, args, kwargs)
-        var_arguments = [
-            a
-            for a in itertools.chain(args, kwargs.values())
-            if isinstance(a, IndexPropVar)
-        ]
-        if not all(v.is_symbolic for v in var_arguments):
-            return self.fallback(name, args, kwargs)
+            var_arguments = [
+                a
+                for a in itertools.chain(args, kwargs.values())
+                if isinstance(a, IndexPropVar)
+            ]
+            if not all(v.is_symbolic for v in var_arguments):
+                return self.fallback(name, args, kwargs)
-        return self.propagate_sympy(name, args, kwargs)
+            return self.propagate_sympy(name, args, kwargs)
+        return inner
     def statically_true(self, e):
         """
@@ -370,9 +371,3 @@ class IndexPropagation(DefaultHandler):
             "indirect_indexing", (index, size, check, wrap_neg), {}
         ).value
         return indirect_var
-if TYPE_CHECKING:
-    class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]):
-        pass
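IndexPropagation's restored __getattr__ tries a symbolic (SymPy) implementation first and falls back to the wrapped handler when the op or its arguments are not symbolic. A hedged sketch with stand-in classes follows; SymbolicOps plays the role of SymPyOps and StringHandler the role of the wrapped backend handler.

# Sketch of the IndexPropagation dispatch: evaluate symbolically when possible,
# otherwise fall back to the inner handler. All names here are illustrative.
from typing import Any, Callable

import sympy


class SymbolicOps:
    @staticmethod
    def add(a: sympy.Expr, b: sympy.Expr) -> sympy.Expr:
        return a + b


class Propagate:
    def __init__(self, inner_handler: Any) -> None:
        self._inner = inner_handler

    def fallback(self, name: str, args: tuple, kwargs: dict) -> Any:
        return getattr(self._inner, name)(*args, **kwargs)

    def __getattr__(self, name: str) -> Callable[..., Any]:
        def inner(*args: Any, **kwargs: Any) -> Any:
            if not hasattr(SymbolicOps, name):
                return self.fallback(name, args, kwargs)
            if not all(isinstance(a, sympy.Expr) for a in args):
                return self.fallback(name, args, kwargs)
            return getattr(SymbolicOps, name)(*args, **kwargs)

        return inner


class StringHandler:
    def __getattr__(self, name: str) -> Callable[..., str]:
        return lambda *args, **kwargs: f"{name}{args}"


x, y = sympy.symbols("x y")
p = Propagate(StringHandler())
print(p.add(x, y))       # x + y (simplified symbolically)
print(p.load("buf", x))  # load('buf', x) (falls back)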

View File

@@ -17,7 +17,6 @@ from torch.utils._sympy.symbol import SymT
 from . import config, dependencies
 from .codegen.common import index_prevent_reordering
-from .ops_handler import DefaultHandler
 from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
 from .virtualized import ops, V
@@ -654,11 +653,11 @@ class LoopBodyBlock:
         return copy
-class CountOps(DefaultHandler):
+class CountOps:
     def __init__(self, inner: Any, counts: collections.Counter[str]):
         self._inner = inner
         self._counts = counts
-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+    def __getattr__(self, name):
         self._counts[name] += 1
-        return getattr(self._inner, name)(*args, **kwargs)
+        return getattr(self._inner, name)
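The restored CountOps simply bumps a counter on every attribute lookup before forwarding to the inner handler. A runnable usage sketch with a no-op inner handler (NullHandler is an invented stand-in):

# Self-contained sketch of the CountOps idea: count how many times each op is
# requested while forwarding everything to an inner handler.
import collections
from typing import Any, Callable


class CountOps:
    def __init__(self, inner: Any, counts: collections.Counter) -> None:
        self._inner = inner
        self._counts = counts

    def __getattr__(self, name: str) -> Any:
        self._counts[name] += 1
        return getattr(self._inner, name)


class NullHandler:
    def __getattr__(self, name: str) -> Callable[..., None]:
        return lambda *args, **kwargs: None


counts = collections.Counter()
h = CountOps(NullHandler(), counts)
h.add(1, 2)
h.add(3, 4)
print(counts)  # Counter({'add': 2})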

View File

@@ -64,7 +64,6 @@ from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union
 from torch.utils._ordered_set import OrderedSet
 from .ops_handler import ( # noqa: F401
-    DefaultHandler,
     KernelFormatterHandler,
     MockHandler,
     OpsHandler,
@@ -275,15 +274,18 @@ class OpsValue:
         return ops.bitwise_left_shift(self, n)
-class OpsWrapper(DefaultHandler):
+class OpsWrapper:
     """This wraps any returned IR values into an `OpsValue` instance, so that we
     can overload the magic methods for writing mathematical expressions fluently.
     """
-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        new_args = [OpsWrapper._unwrap(a) for a in args]
-        new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
-        return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
+    def __getattr__(self, name):
+        def inner(*args, **kwargs):
+            new_args = [OpsWrapper._unwrap(a) for a in args]
+            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
+            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
+        return inner
@staticmethod
def _unwrap(x):
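OpsWrapper pairs with OpsValue so that Python operators on returned values build further ops fluently. A toy sketch of that wrap/unwrap round trip; the classes and string payloads are illustrative stand-ins, not the real torch._inductor types.

# Toy sketch of the OpsWrapper/OpsValue idea: wrap raw op results so that
# Python operators route back through the wrapper and build further ops.
from typing import Any, Callable


class OpsValue:
    def __init__(self, value: str) -> None:
        self.value = value

    def __add__(self, other: "OpsValue") -> "OpsValue":
        return ops.add(self, other)  # operators route back through the wrapper

    def __repr__(self) -> str:
        return self.value


class OpsWrapper:
    def __getattr__(self, name: str) -> Callable[..., OpsValue]:
        def inner(*args: Any) -> OpsValue:
            unwrapped = [a.value if isinstance(a, OpsValue) else a for a in args]
            return OpsValue(f"{name}({', '.join(map(str, unwrapped))})")

        return inner


ops = OpsWrapper()
x = ops.load("buf0", 0)
y = ops.load("buf1", 0)
print(x + y)  # add(load(buf0, 0), load(buf1, 0))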