mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Revert "[inductor] Refactor op handlers part 4 (#146255)"
This reverts commit 7aced455c542f629ffcd4f79c6af259bb966add8. Reverted https://github.com/pytorch/pytorch/pull/146255 on behalf of https://github.com/atalman due to Sorry need to revert https://github.com/pytorch/pytorch/pull/146252 ([comment](https://github.com/pytorch/pytorch/pull/146255#issuecomment-2638258089))
This commit is contained in:
		| @ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_inductor,compile_time_instruction_count,29630000000,0.015 | ||||
| add_loop_inductor,compile_time_instruction_count,30150000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43980000000,0.025 | ||||
| add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025 | ||||
|  | ||||
|  | ||||
|  | ||||
| add_loop_inductor_gpu,compile_time_instruction_count,26240000000,0.015 | ||||
| add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -26,7 +26,7 @@ basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000, | ||||
|  | ||||
|  | ||||
|  | ||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17150000000,0.015 | ||||
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000, | ||||
|  | ||||
|  | ||||
|  | ||||
| aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10390000000,0.015 | ||||
| aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10340000000,0.015 | ||||
|  | ||||
| 
 | 
| @ -1,6 +1,6 @@ | ||||
| import dataclasses | ||||
| import itertools | ||||
| from typing import Any, Optional, TYPE_CHECKING | ||||
| from typing import Any, Callable, Optional, TYPE_CHECKING | ||||
|  | ||||
| import sympy | ||||
|  | ||||
| @ -9,7 +9,6 @@ from torch._inductor import config | ||||
| from torch._inductor.dtype_propagation import DtypePropagationOpsHandler | ||||
| from torch._inductor.index_propagation import SymPyOps, TypedExpr | ||||
|  | ||||
| from .ops_handler import DefaultHandler, OpsHandler | ||||
| from .virtualized import StoreMode, V | ||||
|  | ||||
|  | ||||
| @ -21,7 +20,7 @@ def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol: | ||||
|     return sympy.Symbol(f"unknown_{count}") | ||||
|  | ||||
|  | ||||
| class PreservesZeros(SymPyOps, DefaultHandler): | ||||
| class PreservesZeros(SymPyOps): | ||||
|     """ | ||||
|     For prologue kernels where the loads are masked, does the final store of this kernel preserve | ||||
|     the zeros. | ||||
| @ -55,20 +54,18 @@ class PreservesZeros(SymPyOps, DefaultHandler): | ||||
|         self = V.get_ops_handler() | ||||
|         return construct_symbol(next(self.count), torch.int32) | ||||
|  | ||||
|     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: | ||||
|     def __getattr__(self, name: str) -> Callable[..., Any]: | ||||
|         from torch._inductor.codegen.common import OpDecompositions | ||||
|  | ||||
|         if hasattr(OpDecompositions, name): | ||||
|             return getattr(OpDecompositions, name)(*args, **kwargs).value | ||||
|         def inner(*args: Any, **kwargs: Any) -> TypedExpr: | ||||
|             if hasattr(OpDecompositions, name): | ||||
|                 return getattr(OpDecompositions, name)(*args, **kwargs).value | ||||
|  | ||||
|         dtype = getattr(self.dtype_prop, name)(*args, **kwargs) | ||||
|         return TypedExpr(construct_symbol(next(self.count), dtype), dtype) | ||||
|             nonlocal self | ||||
|             dtype = getattr(self.dtype_prop, name)(*args, **kwargs) | ||||
|             return TypedExpr(construct_symbol(next(self.count), dtype), dtype) | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|  | ||||
|     class _typecheck_PreservesZeros(PreservesZeros, OpsHandler[Any]): | ||||
|         pass | ||||
|         return inner | ||||
|  | ||||
|  | ||||
| def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool: | ||||
| @ -91,7 +88,7 @@ class DTypeContainer: | ||||
|     is_scalar: bool = False | ||||
|  | ||||
|  | ||||
| class RecordLowPrecisionOps(DefaultHandler): | ||||
| class RecordLowPrecisionOps: | ||||
|     def __init__(self) -> None: | ||||
|         self.low_precision_numeric_op = False | ||||
|         self.dtype_prop = DtypePropagationOpsHandler() | ||||
| @ -114,31 +111,28 @@ class RecordLowPrecisionOps(DefaultHandler): | ||||
|     def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: | ||||
|         return sympy.S.Zero | ||||
|  | ||||
|     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: | ||||
|         out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs) | ||||
|         out = DTypeContainer(out_dtype, is_scalar=(name == "constant")) | ||||
|         if name == "constant": | ||||
|             out = DTypeContainer(torch.float, is_scalar=True) | ||||
|     def __getattr__(self, name: str) -> Callable[..., Any]: | ||||
|         def low_prec_float(dtype: torch.dtype) -> bool: | ||||
|             return dtype.is_floating_point and dtype.itemsize < 4 | ||||
|  | ||||
|         uses_low_prec = any( | ||||
|             isinstance(dtype_cont, DTypeContainer) and low_prec_float(dtype_cont.dtype) | ||||
|             for dtype_cont in itertools.chain((out,), args, kwargs.values()) | ||||
|         ) | ||||
|         def inner(*args: Any, **kwargs: Any) -> DTypeContainer: | ||||
|             out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs) | ||||
|             out = DTypeContainer(out_dtype, is_scalar=(name == "constant")) | ||||
|             if name == "constant": | ||||
|                 out = DTypeContainer(torch.float, is_scalar=True) | ||||
|  | ||||
|         if uses_low_prec and name not in self.non_numeric_ops: | ||||
|             self.low_precision_numeric_op = True | ||||
|             uses_low_prec = any( | ||||
|                 isinstance(dtype_cont, DTypeContainer) | ||||
|                 and low_prec_float(dtype_cont.dtype) | ||||
|                 for dtype_cont in itertools.chain((out,), args, kwargs.values()) | ||||
|             ) | ||||
|  | ||||
|         return out | ||||
|             if uses_low_prec and name not in self.non_numeric_ops: | ||||
|                 self.low_precision_numeric_op = True | ||||
|  | ||||
|             return out | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|  | ||||
|     class _typecheck_RecordLowPrecisionOps(RecordLowPrecisionOps, OpsHandler[Any]): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| def low_prec_float(dtype: torch.dtype) -> bool: | ||||
|     return dtype.is_floating_point and dtype.itemsize < 4 | ||||
|         return inner | ||||
|  | ||||
|  | ||||
| def can_codegen_without_upcasts( | ||||
|  | ||||
| @ -41,7 +41,7 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val | ||||
|  | ||||
| from .. import config, metrics | ||||
| from ..dtype_propagation import DtypePropagationOpsHandler | ||||
| from ..ops_handler import BasicMathOps, DefaultHandler | ||||
| from ..ops_handler import BasicMathOps | ||||
| from ..utils import ( | ||||
|     boolean_ops, | ||||
|     DeferredLineBase, | ||||
| @ -2253,7 +2253,7 @@ class KernelTemplate: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class CSEProxy(DefaultHandler): | ||||
| class CSEProxy: | ||||
|     name = "CSEProxy" | ||||
|  | ||||
|     def __init__(self, kernel: Kernel[Any], parent_handler: OpsHandler[Any]): | ||||
| @ -2262,66 +2262,69 @@ class CSEProxy(DefaultHandler): | ||||
|         self.kernel = kernel | ||||
|         self.parent_handler = parent_handler | ||||
|  | ||||
|     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: | ||||
|         bounds = self._bound_variable(name, *args, **kwargs) | ||||
|     def __getattr__(self, name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc] | ||||
|         def inner(*args: Any, **kwargs: Any) -> CSEVariable: | ||||
|             bounds = self._bound_variable(name, *args, **kwargs) | ||||
|  | ||||
|         value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type] | ||||
|         dtype_handler = DtypePropagationOpsHandler() | ||||
|             value = getattr(self.parent_handler, name)(*args, **kwargs)  # type: ignore[has-type] | ||||
|             dtype_handler = DtypePropagationOpsHandler() | ||||
|  | ||||
|         output_idx = 0 | ||||
|             output_idx = 0 | ||||
|  | ||||
|         def do_cse(v: str) -> CSEVariable: | ||||
|             # cpp backend doesnt set current device - TODO: fix | ||||
|             if V.graph.current_device is not None: | ||||
|                 device_str = V.graph.get_current_device_or_throw().type | ||||
|                 triton_backend = ( | ||||
|                     config.cpu_backend == "triton" | ||||
|                     if device_str == "cpu" | ||||
|                     else config.cuda_backend == "triton" | ||||
|                     if device_str != "mps" | ||||
|                     else False | ||||
|                 ) | ||||
|             else: | ||||
|                 triton_backend = False | ||||
|  | ||||
|             # only triton backend tracks dtype currently | ||||
|             if triton_backend: | ||||
|                 if name == "masked": | ||||
|                     output_dtype = value.dtype | ||||
|             def do_cse(v: str) -> CSEVariable: | ||||
|                 # cpp backend doesnt set current device - TODO: fix | ||||
|                 if V.graph.current_device is not None: | ||||
|                     device_str = V.graph.get_current_device_or_throw().type | ||||
|                     triton_backend = ( | ||||
|                         config.cpu_backend == "triton" | ||||
|                         if device_str == "cpu" | ||||
|                         else config.cuda_backend == "triton" | ||||
|                         if device_str != "mps" | ||||
|                         else False | ||||
|                     ) | ||||
|                 else: | ||||
|                     output_dtype = getattr( | ||||
|                         dtype_handler, | ||||
|                         name, | ||||
|                     )(*args, **kwargs) | ||||
|             else: | ||||
|                 # cpp backend doesnt track dtype yet | ||||
|                 output_dtype = None | ||||
|                     triton_backend = False | ||||
|  | ||||
|             csevar = V.kernel.cse.generate( | ||||
|                 V.kernel.compute, | ||||
|                 v, | ||||
|                 bounds=bounds, | ||||
|                 dtype=output_dtype, | ||||
|             ) | ||||
|                 # only triton backend tracks dtype currently | ||||
|                 if triton_backend: | ||||
|                     if name == "masked": | ||||
|                         output_dtype = value.dtype | ||||
|                     else: | ||||
|                         output_dtype = getattr( | ||||
|                             dtype_handler, | ||||
|                             name, | ||||
|                         )(*args, **kwargs) | ||||
|                 else: | ||||
|                     # cpp backend doesnt track dtype yet | ||||
|                     output_dtype = None | ||||
|  | ||||
|             nonlocal output_idx | ||||
|             if config.test_configs.runtime_triton_dtype_assert and triton_backend: | ||||
|                 from torch._inductor.codegen.triton import triton_type | ||||
|  | ||||
|                 # we tree_map over the output, so we need to fetch corresponding dtype | ||||
|                 if isinstance(output_dtype, (list, tuple)): | ||||
|                     output_dtype = output_dtype[output_idx] | ||||
|  | ||||
|                 V.kernel.compute.writeline( | ||||
|                     f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" | ||||
|                 csevar = V.kernel.cse.generate( | ||||
|                     V.kernel.compute, | ||||
|                     v, | ||||
|                     bounds=bounds, | ||||
|                     dtype=output_dtype, | ||||
|                 ) | ||||
|             output_idx += 1 | ||||
|  | ||||
|             csevar.update_on_args(name, args, kwargs) | ||||
|                 nonlocal output_idx | ||||
|                 if config.test_configs.runtime_triton_dtype_assert and triton_backend: | ||||
|                     from torch._inductor.codegen.triton import triton_type | ||||
|  | ||||
|             return csevar | ||||
|                     # we tree_map over the output, so we need to fetch corresponding dtype | ||||
|                     if isinstance(output_dtype, (list, tuple)): | ||||
|                         output_dtype = output_dtype[output_idx] | ||||
|  | ||||
|         return pytree.tree_map(do_cse, value) | ||||
|                     V.kernel.compute.writeline( | ||||
|                         f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})" | ||||
|                     ) | ||||
|                 output_idx += 1 | ||||
|  | ||||
|                 csevar.update_on_args(name, args, kwargs) | ||||
|  | ||||
|                 return csevar | ||||
|  | ||||
|             return pytree.tree_map(do_cse, value) | ||||
|  | ||||
|         return inner | ||||
|  | ||||
|     def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[Any]: | ||||
|         """ | ||||
|  | ||||
| @ -30,7 +30,6 @@ from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_ty | ||||
| from ...utils._sympy.value_ranges import ValueRanges | ||||
| from .. import config, ir, metrics | ||||
| from ..codecache import code_hash, get_path, PyCodeCache | ||||
| from ..ops_handler import DefaultHandler | ||||
| from ..runtime.benchmarking import benchmarker | ||||
| from ..runtime.hints import ( | ||||
|     AutotuneHint, | ||||
| @ -2849,23 +2848,24 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): | ||||
|  | ||||
|         dtype_handler = DtypePropagationOpsHandler() | ||||
|  | ||||
|         class CSEProxy(DefaultHandler): | ||||
|             def _default( | ||||
|                 self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] | ||||
|             ) -> Any: | ||||
|                 nonlocal helper_name | ||||
|                 helper_name += f"_{name}" | ||||
|         class CSEProxy: | ||||
|             def __getattr__(self, name: str) -> Callable[..., CSEVariable]: | ||||
|                 def inner(*args, **kwargs): | ||||
|                     nonlocal helper_name | ||||
|                     helper_name += f"_{name}" | ||||
|  | ||||
|                 output_dtype = getattr( | ||||
|                     dtype_handler, | ||||
|                     name, | ||||
|                 )(*args, **kwargs) | ||||
|                     output_dtype = getattr( | ||||
|                         dtype_handler, | ||||
|                         name, | ||||
|                     )(*args, **kwargs) | ||||
|  | ||||
|                 return cse.generate( | ||||
|                     helper, | ||||
|                     getattr(overrides, name)(*args, **kwargs), | ||||
|                     dtype=output_dtype, | ||||
|                 ) | ||||
|                     return cse.generate( | ||||
|                         helper, | ||||
|                         getattr(overrides, name)(*args, **kwargs), | ||||
|                         dtype=output_dtype, | ||||
|                     ) | ||||
|  | ||||
|                 return inner | ||||
|  | ||||
|         with helper.indent(), V.set_ops_handler(CSEProxy()): | ||||
|             outputs = fn(*args) | ||||
|  | ||||
| @ -4,17 +4,7 @@ import itertools | ||||
| import logging | ||||
| import re | ||||
| from collections.abc import Sequence | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Callable, | ||||
|     Iterable, | ||||
|     List, | ||||
|     Optional, | ||||
|     Tuple, | ||||
|     TYPE_CHECKING, | ||||
|     TypeVar, | ||||
|     Union, | ||||
| ) | ||||
| from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import sympy | ||||
| @ -25,7 +15,6 @@ from torch.utils._ordered_set import OrderedSet | ||||
|  | ||||
| from ..utils._sympy.symbol import make_symbol, SymT | ||||
| from .codegen.common import index_prevent_reordering | ||||
| from .ops_handler import DefaultHandler | ||||
| from .utils import ( | ||||
|     get_dtype_size, | ||||
|     reduction_num_outputs, | ||||
| @ -748,16 +737,19 @@ def canonicalization_prefix() -> str: | ||||
|  | ||||
|  | ||||
| # ops handler which computes all the free unbacked symbols for an IR | ||||
| class FreeUnbackedSymbolsOpsHandler(DefaultHandler): | ||||
| class FreeUnbackedSymbolsOpsHandler: | ||||
|     symbols: OrderedSet[sympy.Symbol] | ||||
|  | ||||
|     def __init__(self) -> None: | ||||
|         self.symbols = OrderedSet() | ||||
|  | ||||
|     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: | ||||
|         for a in itertools.chain(args, kwargs.values()): | ||||
|             if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): | ||||
|                 self.symbols |= free_unbacked_symbols(a) | ||||
|     def __getattr__(self, name: str) -> Callable[..., Any]: | ||||
|         def inner(*args: Sequence[Any], **kwargs: Dict[Any, Any]) -> None: | ||||
|             for a in itertools.chain(args, kwargs.values()): | ||||
|                 if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): | ||||
|                     self.symbols |= free_unbacked_symbols(a) | ||||
|  | ||||
|         return inner | ||||
|  | ||||
|     def indirect_indexing( | ||||
|         self, | ||||
| @ -799,12 +791,10 @@ class FreeUnbackedSymbolsOpsHandler(DefaultHandler): | ||||
|         body() | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|  | ||||
|     class _typecheck_FreeUnbackedSymbolsOpsHandler( | ||||
|         FreeUnbackedSymbolsOpsHandler, OpsHandler[None] | ||||
|     ): | ||||
|         pass | ||||
| def _typecheck_FreeUnbackedSymbolsOpsHandler( | ||||
|     h: FreeUnbackedSymbolsOpsHandler, | ||||
| ) -> OpsHandler[None]: | ||||
|     return h | ||||
|  | ||||
|  | ||||
| def extract_free_unbacked_symbols( | ||||
|  | ||||
| @ -21,9 +21,8 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing. | ||||
|  | ||||
| """ | ||||
| import itertools | ||||
| from collections.abc import Sequence | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union | ||||
| from typing import Any, Callable, Literal, Optional, overload, Union | ||||
| from typing_extensions import TypeAlias | ||||
|  | ||||
| import sympy | ||||
| @ -33,7 +32,6 @@ from torch._prims_common import dtype_to_type, is_integer_dtype | ||||
| from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where | ||||
| from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges | ||||
|  | ||||
| from .ops_handler import DefaultHandler, OpsHandler | ||||
| from .sizevars import evaluate_expr | ||||
| from .utils import generate_assert | ||||
| from .virtualized import V | ||||
| @ -187,7 +185,7 @@ class IndexPropVar: | ||||
| IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]] | ||||
|  | ||||
|  | ||||
| class IndexPropagation(DefaultHandler): | ||||
| class IndexPropagation: | ||||
|     """Ops wrapper that tries to propagate constant and index_expr values through the computation. | ||||
|  | ||||
|     This aims to maximize the compile time simplification possible, and convert | ||||
| @ -249,19 +247,19 @@ class IndexPropagation(DefaultHandler): | ||||
|     def fallback( | ||||
|         self, | ||||
|         name: Literal["indirect_indexing"], | ||||
|         args: Sequence[Any], | ||||
|         args: tuple[Any, ...], | ||||
|         kwargs: dict[str, Any], | ||||
|     ) -> IndexPropVar: | ||||
|         ... | ||||
|  | ||||
|     @overload | ||||
|     def fallback( | ||||
|         self, name: str, args: Sequence[Any], kwargs: dict[str, Any] | ||||
|         self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] | ||||
|     ) -> IndexPropResult: | ||||
|         ... | ||||
|  | ||||
|     def fallback( | ||||
|         self, name: str, args: Sequence[Any], kwargs: dict[str, Any] | ||||
|         self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] | ||||
|     ) -> IndexPropResult: | ||||
|         # Fallback to the wrapped handler | ||||
|         new_args = [self.unwrap(a) for a in args] | ||||
| @ -269,7 +267,7 @@ class IndexPropagation(DefaultHandler): | ||||
|         return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) | ||||
|  | ||||
|     def propagate_sympy( | ||||
|         self, name: str, args: Sequence[Any], kwargs: dict[str, Any] | ||||
|         self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] | ||||
|     ) -> IndexPropResult: | ||||
|         # Build a new SymPy expression from this ops call | ||||
|         def unwrap(a: Union[Any, IndexPropVar]) -> Any: | ||||
| @ -290,19 +288,22 @@ class IndexPropagation(DefaultHandler): | ||||
|             return self.fallback(name, args, kwargs) | ||||
|         return IndexPropVar.new_symbolic(new_expr) | ||||
|  | ||||
|     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: | ||||
|         if not hasattr(SymPyOps, name): | ||||
|             return self.fallback(name, args, kwargs) | ||||
|     def __getattr__(self, name: str) -> Callable[..., IndexPropResult]: | ||||
|         def inner(*args: Any, **kwargs: Any) -> IndexPropResult: | ||||
|             if not hasattr(SymPyOps, name): | ||||
|                 return self.fallback(name, args, kwargs) | ||||
|  | ||||
|         var_arguments = [ | ||||
|             a | ||||
|             for a in itertools.chain(args, kwargs.values()) | ||||
|             if isinstance(a, IndexPropVar) | ||||
|         ] | ||||
|         if not all(v.is_symbolic for v in var_arguments): | ||||
|             return self.fallback(name, args, kwargs) | ||||
|             var_arguments = [ | ||||
|                 a | ||||
|                 for a in itertools.chain(args, kwargs.values()) | ||||
|                 if isinstance(a, IndexPropVar) | ||||
|             ] | ||||
|             if not all(v.is_symbolic for v in var_arguments): | ||||
|                 return self.fallback(name, args, kwargs) | ||||
|  | ||||
|         return self.propagate_sympy(name, args, kwargs) | ||||
|             return self.propagate_sympy(name, args, kwargs) | ||||
|  | ||||
|         return inner | ||||
|  | ||||
|     def statically_true(self, e): | ||||
|         """ | ||||
| @ -370,9 +371,3 @@ class IndexPropagation(DefaultHandler): | ||||
|             "indirect_indexing", (index, size, check, wrap_neg), {} | ||||
|         ).value | ||||
|         return indirect_var | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|  | ||||
|     class _typecheck_IndexPropagation(IndexPropagation, OpsHandler[Any]): | ||||
|         pass | ||||
|  | ||||
| @ -17,7 +17,6 @@ from torch.utils._sympy.symbol import SymT | ||||
|  | ||||
| from . import config, dependencies | ||||
| from .codegen.common import index_prevent_reordering | ||||
| from .ops_handler import DefaultHandler | ||||
| from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs | ||||
| from .virtualized import ops, V | ||||
|  | ||||
| @ -654,11 +653,11 @@ class LoopBodyBlock: | ||||
|         return copy | ||||
|  | ||||
|  | ||||
| class CountOps(DefaultHandler): | ||||
| class CountOps: | ||||
|     def __init__(self, inner: Any, counts: collections.Counter[str]): | ||||
|         self._inner = inner | ||||
|         self._counts = counts | ||||
|  | ||||
|     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: | ||||
|     def __getattr__(self, name): | ||||
|         self._counts[name] += 1 | ||||
|         return getattr(self._inner, name)(*args, **kwargs) | ||||
|         return getattr(self._inner, name) | ||||
|  | ||||
| @ -64,7 +64,6 @@ from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union | ||||
| from torch.utils._ordered_set import OrderedSet | ||||
|  | ||||
| from .ops_handler import (  # noqa: F401 | ||||
|     DefaultHandler, | ||||
|     KernelFormatterHandler, | ||||
|     MockHandler, | ||||
|     OpsHandler, | ||||
| @ -275,15 +274,18 @@ class OpsValue: | ||||
|         return ops.bitwise_left_shift(self, n) | ||||
|  | ||||
|  | ||||
| class OpsWrapper(DefaultHandler): | ||||
| class OpsWrapper: | ||||
|     """This wraps any returned IR values into an `OpsValue` instance, so that we | ||||
|     can overload the magic methods for writing mathematical expressions fluently. | ||||
|     """ | ||||
|  | ||||
|     def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: | ||||
|         new_args = [OpsWrapper._unwrap(a) for a in args] | ||||
|         new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} | ||||
|         return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) | ||||
|     def __getattr__(self, name): | ||||
|         def inner(*args, **kwargs): | ||||
|             new_args = [OpsWrapper._unwrap(a) for a in args] | ||||
|             new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} | ||||
|             return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) | ||||
|  | ||||
|         return inner | ||||
|  | ||||
|     @staticmethod | ||||
|     def _unwrap(x): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user