Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[inductor] Refactor op handlers part 4 (#146255)"
This reverts commit 7aced455c542f629ffcd4f79c6af259bb966add8. Reverted https://github.com/pytorch/pytorch/pull/146255 on behalf of https://github.com/atalman due to Sorry need to revert https://github.com/pytorch/pytorch/pull/146252 ([comment](https://github.com/pytorch/pytorch/pull/146255#issuecomment-2638258089))
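Every code hunk below makes the same kind of change: the refactor being reverted had moved Inductor's ops handlers onto a shared `DefaultHandler` base class, whose single `_default(name, args, kwargs)` hook receives any op the handler does not explicitly implement, and the revert restores the older pattern in which each handler defines `__getattr__` and hands back a per-op closure. The sketch below is illustrative only; it is not the Inductor code, `DefaultStyleHandler` and `GetattrStyleHandler` are made-up names, and the `__getattr__`-to-`_default` bridge that `ops_handler.DefaultHandler` normally provides is inlined here so the example is self-contained.

    from typing import Any, Callable


    class DefaultStyleHandler:
        # Style removed by the revert: unknown ops funnel into one _default hook.
        def __init__(self) -> None:
            self.calls = []

        def _default(self, name: str, args: tuple, kwargs: dict) -> Any:
            self.calls.append(name)
            return f"{name}{args}"

        def __getattr__(self, name: str) -> Callable[..., Any]:
            # stand-in for the per-op forwarding that DefaultHandler generates
            return lambda *a, **kw: self._default(name, a, kw)


    class GetattrStyleHandler:
        # Style restored by the revert: __getattr__ returns a fresh closure per op.
        def __init__(self) -> None:
            self.calls = []

        def __getattr__(self, name: str) -> Callable[..., Any]:
            def inner(*args: Any, **kwargs: Any) -> Any:
                self.calls.append(name)
                return f"{name}{args}"

            return inner


    for handler in (DefaultStyleHandler(), GetattrStyleHandler()):
        print(handler.add(1, 2), handler.mul(3, 4), handler.calls)

Both styles look the same to the caller; the difference is where dispatch happens (one typed method versus dynamically created closures), which is exactly what the hunks below flip back.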
@@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
-add_loop_inductor,compile_time_instruction_count,29630000000,0.015
+add_loop_inductor,compile_time_instruction_count,30150000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015
@@ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015
@@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000,
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015
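The rows above come from the expected-results table for the PR-time benchmarks; each row reads as benchmark name, metric, expected compile-time instruction count, and what appears to be a relative noise threshold, so the revert simply restores the previously recorded counts. Assuming that reading of the last column, a check against such a row could look like the hypothetical helper below (not the actual benchmark harness):

    def within_threshold(measured: int, expected: int, rel_tol: float) -> bool:
        # expected=30150000000 with rel_tol=0.015 accepts roughly +/- 1.5%
        return abs(measured - expected) <= rel_tol * expected


    print(within_threshold(30_400_000_000, 30_150_000_000, 0.015))  # True
    print(within_threshold(31_000_000_000, 30_150_000_000, 0.015))  # False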
@@ -1,6 +1,6 @@
 import dataclasses
 import itertools
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING

 import sympy
@@ -9,7 +9,6 @@ from torch._inductor import config
 from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
 from torch._inductor.index_propagation import SymPyOps, TypedExpr

-from .ops_handler import DefaultHandler, OpsHandler
 from .virtualized import StoreMode, V
@@ -21,7 +20,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol:
     return sympy.Symbol(f"unknown_{count}")


-class PreservesZeros(SymPyOps, DefaultHandler):
+class PreservesZeros(SymPyOps):
    """
    For prologue kernels where the loads are masked, does the final store of this kernel preserve
    the zeros.
@@ -55,20 +54,18 @@ class PreservesZeros(SymPyOps, DefaultHandler):
         self = V.get_ops_handler()
         return construct_symbol(next(self.count), torch.int32)

-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+    def __getattr__(self, name: str) -> Callable[..., Any]:
         from torch._inductor.codegen.common import OpDecompositions

-        if hasattr(OpDecompositions, name):
-            return getattr(OpDecompositions, name)(*args, **kwargs).value
-
-        dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
-        return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
-
-
-if TYPE_CHECKING:
-
-    class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]):
-        pass
+        def inner(*args: Any, **kwargs: Any) -> TypedExpr:
+            if hasattr(OpDecompositions, name):
+                return getattr(OpDecompositions, name)(*args, **kwargs).value
+
+            nonlocal self
+            dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
+            return TypedExpr(construct_symbol(next(self.count), dtype), dtype)
+
+        return inner


 def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool:
@@ -91,7 +88,7 @@ class DTypeContainer:
     is_scalar: bool = False


-class RecordLowPrecisionOps(DefaultHandler):
+class RecordLowPrecisionOps:
     def __init__(self) -> None:
         self.low_precision_numeric_op = False
         self.dtype_prop = DtypePropagationOpsHandler()
@@ -114,31 +111,28 @@ class RecordLowPrecisionOps(DefaultHandler):
     def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr:
         return sympy.S.Zero

-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
-        out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
-        if name == "constant":
-            out = DTypeContainer(torch.float, is_scalar=True)
-
-        uses_low_prec = any(
-            isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype)
-            for dtype_cont in itertools.chain((out,), args, kwargs.values())
-        )
-
-        if uses_low_prec and name not in self.non_numeric_ops:
-            self.low_precision_numeric_op = True
-
-        return out
-
-
-if TYPE_CHECKING:
-
-    class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]):
-        pass
-
-
-def low_prec_float(dtype: torch.dtype) -> bool:
-    return dtype.is_floating_point and dtype.itemsize < 4
+    def __getattr__(self, name: str) -> Callable[..., Any]:
+        def low_prec_float(dtype: torch.dtype) -> bool:
+            return dtype.is_floating_point and dtype.itemsize < 4
+
+        def inner(*args: Any, **kwargs: Any) -> DTypeContainer:
+            out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs)
+            out = DTypeContainer(out_dtype, is_scalar=(name == "constant"))
+            if name == "constant":
+                out = DTypeContainer(torch.float, is_scalar=True)
+
+            uses_low_prec = any(
+                isinstance(dtype_cont, DTypeContainer)
+                and low_prec_float(dtype_cont.dtype)
+                for dtype_cont in itertools.chain((out,), args, kwargs.values())
+            )
+
+            if uses_low_prec and name not in self.non_numeric_ops:
+                self.low_precision_numeric_op = True
+
+            return out
+
+        return inner


 def can_codegen_without_upcasts(
@@ -41,7 +41,7 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val
 from .. import config, metrics
 from ..dtype_propagation import DtypePropagationOpsHandler
-from ..ops_handler import BasicMathOps, DefaultHandler
+from ..ops_handler import BasicMathOps
 from ..utils import (
     boolean_ops,
     DeferredLineBase,
@@ -2253,7 +2253,7 @@ class KernelTemplate:
         raise NotImplementedError


-class CSEProxy(DefaultHandler):
+class CSEProxy:
     name = "CSEProxy"

     def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]):
@@ -2262,66 +2262,69 @@ class CSEProxy(DefaultHandler):
         self.kernel = kernel
         self.parent_handler = parent_handler

-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        bounds = self._bound_variable(name, *args, **kwargs)
-
-        value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
-        dtype_handler = DtypePropagationOpsHandler()
-
-        output_idx = 0
-
-        def do_cse(v: str) -> CSEVariable:
-            # cpp backend doesnt set current device - TODO: fix
-            if V.graph.current_device is not None:
-                device_str = V.graph.get_current_device_or_throw().type
-                triton_backend = (
-                    config.cpu_backend == "triton"
-                    if device_str == "cpu"
-                    else config.cuda_backend == "triton"
-                    if device_str != "mps"
-                    else False
-                )
-            else:
-                triton_backend = False
-
-            # only triton backend tracks dtype currently
-            if triton_backend:
-                if name == "masked":
-                    output_dtype = value.dtype
-                else:
-                    output_dtype = getattr(
-                        dtype_handler,
-                        name,
-                    )(*args, **kwargs)
-            else:
-                # cpp backend doesnt track dtype yet
-                output_dtype = None
-
-            csevar = V.kernel.cse.generate(
-                V.kernel.compute,
-                v,
-                bounds=bounds,
-                dtype=output_dtype,
-            )
-
-            nonlocal output_idx
-            if config.test_configs.runtime_triton_dtype_assert and triton_backend:
-                from torch._inductor.codegen.triton import triton_type
-
-                # we tree_map over the output, so we need to fetch corresponding dtype
-                if isinstance(output_dtype, (list, tuple)):
-                    output_dtype = output_dtype[output_idx]
-
-                V.kernel.compute.writeline(
-                    f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
-                )
-            output_idx += 1
-
-            csevar.update_on_args(name, args, kwargs)
-
-            return csevar
-
-        return pytree.tree_map(do_cse, value)
+    def __getattr__(self, name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
+        def inner(*args: Any, **kwargs: Any) -> CSEVariable:
+            bounds = self._bound_variable(name, *args, **kwargs)
+
+            value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
+            dtype_handler = DtypePropagationOpsHandler()
+
+            output_idx = 0
+
+            def do_cse(v: str) -> CSEVariable:
+                # cpp backend doesnt set current device - TODO: fix
+                if V.graph.current_device is not None:
+                    device_str = V.graph.get_current_device_or_throw().type
+                    triton_backend = (
+                        config.cpu_backend == "triton"
+                        if device_str == "cpu"
+                        else config.cuda_backend == "triton"
+                        if device_str != "mps"
+                        else False
+                    )
+                else:
+                    triton_backend = False
+
+                # only triton backend tracks dtype currently
+                if triton_backend:
+                    if name == "masked":
+                        output_dtype = value.dtype
+                    else:
+                        output_dtype = getattr(
+                            dtype_handler,
+                            name,
+                        )(*args, **kwargs)
+                else:
+                    # cpp backend doesnt track dtype yet
+                    output_dtype = None
+
+                csevar = V.kernel.cse.generate(
+                    V.kernel.compute,
+                    v,
+                    bounds=bounds,
+                    dtype=output_dtype,
+                )
+
+                nonlocal output_idx
+                if config.test_configs.runtime_triton_dtype_assert and triton_backend:
+                    from torch._inductor.codegen.triton import triton_type
+
+                    # we tree_map over the output, so we need to fetch corresponding dtype
+                    if isinstance(output_dtype, (list, tuple)):
+                        output_dtype = output_dtype[output_idx]
+
+                    V.kernel.compute.writeline(
+                        f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
+                    )
+                output_idx += 1
+
+                csevar.update_on_args(name, args, kwargs)
+
+                return csevar
+
+            return pytree.tree_map(do_cse, value)
+
+        return inner

     def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]:
         """
@@ -30,7 +30,6 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty
 from ...utils._sympy.value_ranges import ValueRanges
 from .. import config, ir, metrics
 from ..codecache import code_hash, get_path, PyCodeCache
-from ..ops_handler import DefaultHandler
 from ..runtime.benchmarking import benchmarker
 from ..runtime.hints import (
     AutotuneHint,
@@ -2849,23 +2848,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

         dtype_handler = DtypePropagationOpsHandler()

-        class CSEProxy(DefaultHandler):
-            def _default(
-                self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
-            ) -> Any:
-                nonlocal helper_name
-                helper_name += f"_{name}"
-
-                output_dtype = getattr(
-                    dtype_handler,
-                    name,
-                )(*args, **kwargs)
-
-                return cse.generate(
-                    helper,
-                    getattr(overrides, name)(*args, **kwargs),
-                    dtype=output_dtype,
-                )
+        class CSEProxy:
+            def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
+                def inner(*args, **kwargs):
+                    nonlocal helper_name
+                    helper_name += f"_{name}"
+
+                    output_dtype = getattr(
+                        dtype_handler,
+                        name,
+                    )(*args, **kwargs)
+
+                    return cse.generate(
+                        helper,
+                        getattr(overrides, name)(*args, **kwargs),
+                        dtype=output_dtype,
+                    )
+
+                return inner

         with helper.indent(), V.set_ops_handler(CSEProxy()):
             outputs = fn(*args)
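The context line above shows how these proxy objects are used: the handler is installed as the active ops handler with `V.set_ops_handler(...)`, and while that context manager is live every `ops.<op>(...)` call is routed through it. A rough, self-contained usage sketch follows; the `RecordingHandler` class and the specific ops called are made up for illustration, while `V`, `ops`, and `V.set_ops_handler` appear in the diff itself:

    from torch._inductor.virtualized import V, ops


    class RecordingHandler:
        # minimal __getattr__-style handler that just records which ops were hit
        def __init__(self) -> None:
            self.seen = []

        def __getattr__(self, name):
            def inner(*args, **kwargs):
                self.seen.append(name)
                return f"{name}_result"

            return inner


    handler = RecordingHandler()
    with V.set_ops_handler(handler):
        ops.add("a", "b")
        ops.mul("a", "b")
    print(handler.seen)  # expected: ['add', 'mul']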
@@ -4,17 +4,7 @@ import itertools
 import logging
 import re
 from collections.abc import Sequence
-from typing import (
-    Any,
-    Callable,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
 from unittest.mock import patch

 import sympy
@@ -25,7 +15,6 @@ from torch.utils._ordered_set import OrderedSet

 from ..utils._sympy.symbol import make_symbol, SymT
 from .codegen.common import index_prevent_reordering
-from .ops_handler import DefaultHandler
 from .utils import (
     get_dtype_size,
     reduction_num_outputs,
@@ -748,16 +737,19 @@ def canonicalization_prefix() -> str:

 # ops handler which computes all the free unbacked symbols for an IR
-class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
+class FreeUnbackedSymbolsOpsHandler:
     symbols: OrderedSet[sympy.Symbol]

     def __init__(self) -> None:
         self.symbols = OrderedSet()

-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        for a in itertools.chain(args, kwargs.values()):
-            if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
-                self.symbols |= free_unbacked_symbols(a)
+    def __getattr__(self, name: str) -> Callable[..., Any]:
+        def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None:
+            for a in itertools.chain(args, kwargs.values()):
+                if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
+                    self.symbols |= free_unbacked_symbols(a)
+
+        return inner

     def indirect_indexing(
         self,
@@ -799,12 +791,10 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler):
         body()


-if TYPE_CHECKING:
-
-    class _typecheck_FreeUnbackedSymbolsOpsHandler(
-        FreeUnbackedSymbolsOpsHandler, OpsHandler[None]
-    ):
-        pass
+def _typecheck_FreeUnbackedSymbolsOpsHandler(
+    h: FreeUnbackedSymbolsOpsHandler,
+) -> OpsHandler[None]:
+    return h


 def extract_free_unbacked_symbols(
@@ -21,9 +21,8 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.

 """
 import itertools
-from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union
+from typing import Any, Callable, Literal, Optional, overload, Union
 from typing_extensions import TypeAlias

 import sympy
@@ -33,7 +32,6 @@ from torch._prims_common import dtype_to_type, is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

-from .ops_handler import DefaultHandler, OpsHandler
 from .sizevars import evaluate_expr
 from .utils import generate_assert
 from .virtualized import V
@@ -187,7 +185,7 @@ class IndexPropVar:
 IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]


-class IndexPropagation(DefaultHandler):
+class IndexPropagation:
     """Ops wrapper that tries to propagate constant and index_expr values through the computation.

     This aims to maximize the compile time simplification possible, and convert
@@ -249,19 +247,19 @@ class IndexPropagation(DefaultHandler):
     def fallback(
         self,
         name: Literal["indirect_indexing"],
-        args: Sequence[Any],
+        args: tuple[Any, ...],
         kwargs: dict[str, Any],
     ) -> IndexPropVar:
         ...

     @overload
     def fallback(
-        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
+        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         ...

     def fallback(
-        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
+        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         # Fallback to the wrapped handler
         new_args = [self.unwrap(a) for a in args]
@@ -269,7 +267,7 @@ class IndexPropagation(DefaultHandler):
         return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))

     def propagate_sympy(
-        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
+        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
     ) -> IndexPropResult:
         # Build a new SymPy expression from this ops call
         def unwrap(a: Union[Any, IndexPropVar]) -> Any:
@@ -290,19 +288,22 @@ class IndexPropagation(DefaultHandler):
             return self.fallback(name, args, kwargs)
         return IndexPropVar.new_symbolic(new_expr)

-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        if not hasattr(SymPyOps, name):
-            return self.fallback(name, args, kwargs)
-
-        var_arguments = [
-            a
-            for a in itertools.chain(args, kwargs.values())
-            if isinstance(a, IndexPropVar)
-        ]
-        if not all(v.is_symbolic for v in var_arguments):
-            return self.fallback(name, args, kwargs)
-
-        return self.propagate_sympy(name, args, kwargs)
+    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
+        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
+            if not hasattr(SymPyOps, name):
+                return self.fallback(name, args, kwargs)
+
+            var_arguments = [
+                a
+                for a in itertools.chain(args, kwargs.values())
+                if isinstance(a, IndexPropVar)
+            ]
+            if not all(v.is_symbolic for v in var_arguments):
+                return self.fallback(name, args, kwargs)
+
+            return self.propagate_sympy(name, args, kwargs)
+
+        return inner

     def statically_true(self, e):
         """
@@ -370,9 +371,3 @@ class IndexPropagation(DefaultHandler):
             "indirect_indexing", (index, size, check, wrap_neg), {}
         ).value
         return indirect_var
-
-
-if TYPE_CHECKING:
-
-    class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]):
-        pass
@@ -17,7 +17,6 @@ from torch.utils._sympy.symbol import SymT

 from . import config, dependencies
 from .codegen.common import index_prevent_reordering
-from .ops_handler import DefaultHandler
 from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
 from .virtualized import ops, V
@@ -654,11 +653,11 @@ class LoopBodyBlock:
         return copy


-class CountOps(DefaultHandler):
+class CountOps:
     def __init__(self, inner: Any, counts: collections.Counter[str]):
         self._inner = inner
         self._counts = counts

-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
+    def __getattr__(self, name):
         self._counts[name] += 1
-        return getattr(self._inner, name)(*args, **kwargs)
+        return getattr(self._inner, name)
@@ -64,7 +64,6 @@ from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union
 from torch.utils._ordered_set import OrderedSet

 from .ops_handler import (  # noqa: F401
-    DefaultHandler,
     KernelFormatterHandler,
     MockHandler,
     OpsHandler,
@@ -275,15 +274,18 @@ class OpsValue:
         return ops.bitwise_left_shift(self, n)


-class OpsWrapper(DefaultHandler):
+class OpsWrapper:
     """This wraps any returned IR values into an `OpsValue` instance, so that we
     can overload the magic methods for writing mathematical expressions fluently.
     """

-    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
-        new_args = [OpsWrapper._unwrap(a) for a in args]
-        new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
-        return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
+    def __getattr__(self, name):
+        def inner(*args, **kwargs):
+            new_args = [OpsWrapper._unwrap(a) for a in args]
+            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
+            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))
+
+        return inner

     @staticmethod
     def _unwrap(x):