Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[inductor] Refactor op handlers part 4 (#146255)
This replaces the `__getattr__()` pattern used in remaining OpHandlers with a `DefaultHandler` class defined in part 2. Some compile time wins from this as well:

```
2025-02-02T19:46:32.2033010Z
2025-02-02T19:46:32.2036607Z WIN: benchmark ('add_loop_inductor', 'compile_time_instruction_count') failed, actual result 29633182927 is -1.71% lower than expected 30150000000 ±1.50% please update the expected results.
2025-02-02T19:46:32.2037575Z
2025-02-02T19:46:32.2037907Z please update all results that changed significantly, and not only the failed ones
2025-02-02T19:46:32.2039291Z PASS: benchmark ('add_loop_inductor_dynamic_gpu', 'compile_time_instruction_count') pass, actual result 43986879172 -1.02% is within expected 44440000000 ±2.50%
2025-02-02T19:46:32.2040131Z
2025-02-02T19:46:32.2041180Z WIN: benchmark ('add_loop_inductor_gpu', 'compile_time_instruction_count') failed, actual result 26246225695 is -1.85% lower than expected 26740000000 ±1.50% please update the expected results.
2025-02-02T19:46:32.2042188Z
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146255
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254
committed by PyTorch MergeBot
parent 0e31e5932b
commit 403db2faee
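The refactor described in the commit message swaps dynamic `__getattr__` dispatch for a single `_default(name, args, kwargs)` hook provided by `DefaultHandler` (defined in `torch/_inductor/ops_handler.py` by the earlier PRs in this stack). Below is a minimal, self-contained sketch of the pattern change; `SketchDefaultHandler`, `OldStyleHandler`, and `NewStyleHandler` are illustrative names, not the real classes, and the real `DefaultHandler` covers the full `OpsHandler` op set rather than the two hand-written methods shown here.

```python
from typing import Any, Callable


class OldStyleHandler:
    """Old pattern: unknown ops are resolved via __getattr__, which mints a
    fresh `inner` closure on every attribute lookup."""

    def __getattr__(self, name: str) -> Callable[..., Any]:
        def inner(*args: Any, **kwargs: Any) -> Any:
            return f"{name}({args}, {kwargs})"

        return inner


class SketchDefaultHandler:
    """Simplified stand-in for DefaultHandler: plain methods forward to a
    single _default hook that subclasses override (hypothetical subset)."""

    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        raise NotImplementedError

    def add(self, *args: Any, **kwargs: Any) -> Any:
        return self._default("add", args, kwargs)

    def mul(self, *args: Any, **kwargs: Any) -> Any:
        return self._default("mul", args, kwargs)


class NewStyleHandler(SketchDefaultHandler):
    """New pattern: one _default method, no per-lookup closure allocation."""

    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        return f"{name}({args}, {kwargs})"


if __name__ == "__main__":
    print(OldStyleHandler().add(1, 2))  # -> add((1, 2), {})
    print(NewStyleHandler().add(1, 2))  # -> add((1, 2), {})
```

Besides skipping a closure allocation on every op call, the new shape gives static type checkers concrete methods to see, which is presumably why several diffs below also add `_typecheck_*` classes under `TYPE_CHECKING` that check each handler against `OpsHandler`.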
@@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
-add_loop_inductor,compile_time_instruction_count,30150000000,0.015
+add_loop_inductor,compile_time_instruction_count,29630000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015
@@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015
@@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000,
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015
@@ -1,6 +1,6 @@
 import dataclasses
 import itertools
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING

 import sympy
@@ -9,6 +9,7 @@ from torch._inductor import config
 from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
 from torch._inductor.index_propagation import SymPyOps, TypedExpr
+from .ops_handler import DefaultHandler, OpsHandler
 from .virtualized import StoreMode, V
@@ -20,7 +21,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol:
     return sympy.Symbol(f"unknown_{count}")


-class PreservesZeros(SymPyOps):
+class PreservesZeros(SymPyOps, DefaultHandler):
     """
     For prologue kernels where the loads are masked, does the final store of this kernel preserve
     the zeros.
@@ -54,18 +55,20 @@ class PreservesZeros(SymPyOps):
         self = V.get_ops_handler()
         return construct_symbol(next(self.count), torch.int32)

-    def __getattr__(self, name: str) -> Callable[..., Any]:
+    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
         from torch._inductor.codegen.common import OpDecompositions

-        def inner(*args: Any, **kwargs: Any) -> TypedExpr:
-            if hasattr(OpDecompositions, name):
-                return getattr(OpDecompositions, name)(*args, **kwargs).value
-
-            nonlocal self
-            dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
-            return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
-
-        return inner
+        if hasattr(OpDecompositions, name):
+            return getattr(OpDecompositions, name)(*args, **kwargs).value
+
+        dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
+        return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
+
+
+if TYPE_CHECKING:
+
+    class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]):
+        pass


 def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool:
@@ -88,7 +91,7 @@ class DTypeContainer:
     is_scalar: bool = False


-class RecordLowPrecisionOps:
+class RecordLowPrecisionOps(DefaultHandler):
     def __init__(self) -> None:
         self.low_precision_numeric_op = False
         self.dtype_prop = DtypePropagationOpsHandler()
@@ -111,28 +114,31 @@ class RecordLowPrecisionOps:
     def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
         return sympy.S.Zero

-    def __getattr__(self, name: str) -> Callable[..., Any]:
-        def low_prec_float(dtype: torch.dtype) -> bool:
-            return dtype.is_floating_point and dtype.itemsize < 4
-
-        def inner(*args: Any, **kwargs: Any) -> DTypeContainer:
-            out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
-            out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
-            if name == "constant":
-                out = DTypeContainer(torch.float, is_scalar=True)
-
-            uses_low_prec = any(
-                isinstance(dtype_cont, DTypeContainer)
-                and low_prec_float(dtype_cont.dtype)
-                for dtype_cont in itertools.chain((out,), args, kwargs.values())
-            )
-
-            if uses_low_prec and name not in self.non_numeric_ops:
-                self.low_precision_numeric_op = True
-
-            return out
-
-        return inner
+    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+        out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
+        out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
+        if name == "constant":
+            out = DTypeContainer(torch.float, is_scalar=True)
+
+        uses_low_prec = any(
+            isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype)
+            for dtype_cont in itertools.chain((out,), args, kwargs.values())
+        )
+
+        if uses_low_prec and name not in self.non_numeric_ops:
+            self.low_precision_numeric_op = True
+
+        return out
+
+
+if TYPE_CHECKING:
+
+    class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]):
+        pass
+
+
+def low_prec_float(dtype: torch.dtype) -> bool:
+    return dtype.is_floating_point and dtype.itemsize < 4


 def can_codegen_without_upcasts(
@@ -41,7 +41,7 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val

 from .. import config, metrics
 from ..dtype_propagation import DtypePropagationOpsHandler
-from ..ops_handler import BasicMathOps
+from ..ops_handler import BasicMathOps, DefaultHandler
 from ..utils import (
     boolean_ops,
     DeferredLineBase,
@@ -2263,7 +2263,7 @@ class KernelTemplate:
         raise NotImplementedError


-class CSEProxy:
+class CSEProxy(DefaultHandler):
     name = "CSEProxy"

     def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
@@ -2272,69 +2272,66 @@ class CSEProxy:
         self.kernel = kernel
         self.parent_handler = parent_handler

-    def __getattr__(self, name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
-        def inner(*args: Any, **kwargs: Any) -> CSEVariable:
-            bounds = self._bound_variable(name, *args, **kwargs)
-
-            value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
-            dtype_handler = DtypePropagationOpsHandler()
-
-            output_idx = 0
-
-            def do_cse(v: str) -> CSEVariable:
-                # cpp backend doesnt set current device - TODO: fix
-                if V.graph.current_device is not None:
-                    device_str = V.graph.get_current_device_or_throw().type
-                    triton_backend = (
-                        config.cpu_backend == "triton"
-                        if device_str == "cpu"
-                        else config.cuda_backend == "triton"
-                        if device_str != "mps"
-                        else False
-                    )
-                else:
-                    triton_backend = False
-
-                # only triton backend tracks dtype currently
-                if triton_backend:
-                    if name == "masked":
-                        output_dtype = value.dtype
-                    else:
-                        output_dtype = getattr(
-                            dtype_handler,
-                            name,
-                        )(*args, **kwargs)
-                else:
-                    # cpp backend doesnt track dtype yet
-                    output_dtype = None
-
-                csevar = V.kernel.cse.generate(
-                    V.kernel.compute,
-                    v,
-                    bounds=bounds,
-                    dtype=output_dtype,
-                )
-
-                nonlocal output_idx
-                if config.test_configs.runtime_triton_dtype_assert and triton_backend:
-                    from torch._inductor.codegen.triton import triton_type
-
-                    # we tree_map over the output, so we need to fetch corresponding dtype
-                    if isinstance(output_dtype, (list, tuple)):
-                        output_dtype = output_dtype[output_idx]
-
-                    V.kernel.compute.writeline(
-                        f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
-                    )
-                    output_idx += 1
-
-                csevar.update_on_args(name, args, kwargs)
-
-                return csevar
-
-            return pytree.tree_map(do_cse, value)
-
-        return inner
+    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+        bounds = self._bound_variable(name, *args, **kwargs)
+
+        value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
+        dtype_handler = DtypePropagationOpsHandler()
+
+        output_idx = 0
+
+        def do_cse(v: str) -> CSEVariable:
+            # cpp backend doesnt set current device - TODO: fix
+            if V.graph.current_device is not None:
+                device_str = V.graph.get_current_device_or_throw().type
+                triton_backend = (
+                    config.cpu_backend == "triton"
+                    if device_str == "cpu"
+                    else config.cuda_backend == "triton"
+                    if device_str != "mps"
+                    else False
+                )
+            else:
+                triton_backend = False
+
+            # only triton backend tracks dtype currently
+            if triton_backend:
+                if name == "masked":
+                    output_dtype = value.dtype
+                else:
+                    output_dtype = getattr(
+                        dtype_handler,
+                        name,
+                    )(*args, **kwargs)
+            else:
+                # cpp backend doesnt track dtype yet
+                output_dtype = None
+
+            csevar = V.kernel.cse.generate(
+                V.kernel.compute,
+                v,
+                bounds=bounds,
+                dtype=output_dtype,
+            )
+
+            nonlocal output_idx
+            if config.test_configs.runtime_triton_dtype_assert and triton_backend:
+                from torch._inductor.codegen.triton import triton_type
+
+                # we tree_map over the output, so we need to fetch corresponding dtype
+                if isinstance(output_dtype, (list, tuple)):
+                    output_dtype = output_dtype[output_idx]
+
+                V.kernel.compute.writeline(
+                    f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
+                )
+                output_idx += 1
+
+            csevar.update_on_args(name, args, kwargs)
+
+            return csevar
+
+        return pytree.tree_map(do_cse, value)

     def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
         """
@@ -31,6 +31,7 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty
 from ...utils._sympy.value_ranges import ValueRanges
 from .. import config, ir, metrics
 from ..codecache import code_hash, get_path, PyCodeCache
+from ..ops_handler import DefaultHandler
 from ..runtime.benchmarking import benchmarker
 from ..runtime.hints import (
     AutotuneHint,
@@ -2872,24 +2873,23 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

        dtype_handler = DtypePropagationOpsHandler()

-        class CSEProxy:
-            def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
-                def inner(*args, **kwargs):
-                    nonlocal helper_name
-                    helper_name += f"_{name}"
-
-                    output_dtype = getattr(
-                        dtype_handler,
-                        name,
-                    )(*args, **kwargs)
-
-                    return cse.generate(
-                        helper,
-                        getattr(overrides, name)(*args, **kwargs),
-                        dtype=output_dtype,
-                    )
-
-                return inner
+        class CSEProxy(DefaultHandler):
+            def _default(
+                self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+            ) -> Any:
+                nonlocal helper_name
+                helper_name += f"_{name}"
+
+                output_dtype = getattr(
+                    dtype_handler,
+                    name,
+                )(*args, **kwargs)
+
+                return cse.generate(
+                    helper,
+                    getattr(overrides, name)(*args, **kwargs),
+                    dtype=output_dtype,
+                )

        with helper.indent(), V.set_ops_handler(CSEProxy()):
            outputs = fn(*args)
@@ -4,7 +4,17 @@ import itertools
 import logging
 import re
 from collections.abc import Sequence
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TYPE_CHECKING,
+    TypeVar,
+    Union,
+)
 from unittest.mock import patch

 import sympy
@@ -15,6 +25,7 @@ from torch.utils._ordered_set import OrderedSet

 from ..utils._sympy.symbol import make_symbol, SymT
 from .codegen.common import index_prevent_reordering
+from .ops_handler import DefaultHandler
 from .utils import (
     get_dtype_size,
     reduction_num_outputs,
@@ -737,19 +748,16 @@ def canonicalization_prefix() -> str:


 # ops handler which computes all the free unbacked symbols for an IR
-class FreeUnbackedSymbolsOpsHandler:
+class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
     symbols: OrderedSet[sympy.Symbol]

     def __init__(self) -> None:
         self.symbols = OrderedSet()

-    def __getattr__(self, name: str) -> Callable[..., Any]:
-        def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None:
-            for a in itertools.chain(args, kwargs.values()):
-                if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
-                    self.symbols |= free_unbacked_symbols(a)
-
-        return inner
+    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+        for a in itertools.chain(args, kwargs.values()):
+            if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
+                self.symbols |= free_unbacked_symbols(a)

     def indirect_indexing(
         self,
@@ -791,10 +799,12 @@ class FreeUnbackedSymbolsOpsHandler:
         body()


-def _typecheck_FreeUnbackedSymbolsOpsHandler(
-    h: FreeUnbackedSymbolsOpsHandler,
-) -> OpsHandler[None]:
-    return h
+if TYPE_CHECKING:
+
+    class _typecheck_FreeUnbackedSymbolsOpsHandler(
+        FreeUnbackedSymbolsOpsHandler, OpsHandler[None]
+    ):
+        pass


 def extract_free_unbacked_symbols(
@@ -21,8 +21,9 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.

 """
 import itertools
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional, overload, Union
+from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union
 from typing_extensions import TypeAlias

 import sympy
@@ -32,6 +33,7 @@ from torch._prims_common import dtype_to_type, is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

+from .ops_handler import DefaultHandler, OpsHandler
 from .sizevars import evaluate_expr
 from .utils import generate_assert
 from .virtualized import V
@@ -185,7 +187,7 @@ class IndexPropVar:
 IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]


-class IndexPropagation:
+class IndexPropagation(DefaultHandler):
     """Ops wrapper that tries to propagate constant and index_expr values through the computation.

     This aims to maximize the compile time simplification possible, and convert
@@ -247,19 +249,19 @@ class IndexPropagation:
     def fallback(
         self,
         name: Literal["indirect_indexing"],
-        args: tuple[Any, ...],
+        args: Sequence[Any],
         kwargs: dict[str, Any],
     ) -> IndexPropVar:
         ...

     @overload
     def fallback(
-        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         ...

     def fallback(
-        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         # Fallback to the wrapped handler
         new_args = [self.unwrap(a) for a in args]
@@ -267,7 +269,7 @@ class IndexPropagation:
         return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))

     def propagate_sympy(
-        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         # Build a new SymPy expression from this ops call
         def unwrap(a: Union[Any, IndexPropVar]) -> Any:
@@ -288,22 +290,19 @@ class IndexPropagation:
             return self.fallback(name, args, kwargs)
         return IndexPropVar.new_symbolic(new_expr)

-    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
-        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
-            if not hasattr(SymPyOps, name):
-                return self.fallback(name, args, kwargs)
-
-            var_arguments = [
-                a
-                for a in itertools.chain(args, kwargs.values())
-                if isinstance(a, IndexPropVar)
-            ]
-            if not all(v.is_symbolic for v in var_arguments):
-                return self.fallback(name, args, kwargs)
-
-            return self.propagate_sympy(name, args, kwargs)
-
-        return inner
+    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+        if not hasattr(SymPyOps, name):
+            return self.fallback(name, args, kwargs)
+
+        var_arguments = [
+            a
+            for a in itertools.chain(args, kwargs.values())
+            if isinstance(a, IndexPropVar)
+        ]
+        if not all(v.is_symbolic for v in var_arguments):
+            return self.fallback(name, args, kwargs)
+
+        return self.propagate_sympy(name, args, kwargs)

     def statically_true(self, e):
         """
@@ -371,3 +370,9 @@ class IndexPropagation:
             "indirect_indexing", (index, size, check, wrap_neg), {}
         ).value
         return indirect_var
+
+
+if TYPE_CHECKING:
+
+    class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]):
+        pass
@@ -17,6 +17,7 @@ from torch.utils._sympy.symbol import SymT

 from . import config, dependencies
 from .codegen.common import index_prevent_reordering
+from .ops_handler import DefaultHandler
 from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
 from .virtualized import ops, V

@@ -653,11 +654,11 @@ class LoopBodyBlock:
         return copy


-class CountOps:
+class CountOps(DefaultHandler):
     def __init__(self, inner: Any, counts: collections.Counter[str]):
         self._inner = inner
         self._counts = counts

-    def __getattr__(self, name):
+    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
         self._counts[name] += 1
-        return getattr(self._inner, name)
+        return getattr(self._inner, name)(*args, **kwargs)
@@ -64,6 +64,7 @@ from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union
 from torch.utils._ordered_set import OrderedSet

 from .ops_handler import (  # noqa: F401
+    DefaultHandler,
     KernelFormatterHandler,
     MockHandler,
     OpsHandler,
@@ -274,18 +275,15 @@ class OpsValue:
         return ops.bitwise_left_shift(self, n)


-class OpsWrapper:
+class OpsWrapper(DefaultHandler):
    """This wraps any returned IR values into an `OpsValue` instance, so that we
    can overload the magic methods for writing mathematical expressions fluently.
    """

-    def __getattr__(self, name):
-        def inner(*args, **kwargs):
-            new_args = [OpsWrapper._unwrap(a) for a in args]
-            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
-            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
-
-        return inner
+    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+        new_args = [OpsWrapper._unwrap(a) for a in args]
+        new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
+        return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))

     @staticmethod
     def _unwrap(x):