Fall back non-vectorized ops to a scalar implementation plus a for loop.
Example code:
```cpp
cpp_fused_igammac_0 = async_compile.cpp_pybinding(['const double*', 'const double*', 'double*'], '''
#include "/tmp/torchinductor_root/z4/cz4j2mmotlx3z2b7u4fbjtdt4x6plhd67ljwzg5bk7ekv4xz6y7q.h"
extern "C"  void kernel(const double* in_ptr0,
                       const double* in_ptr1,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(48L); x0+=static_cast<int64_t>(8L))
        {
            auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), 8);
            auto tmp1 = in_ptr1[static_cast<int64_t>(0L)];
            auto tmp2 = at::vec::VectorizedN<double,2>(tmp1);
            auto tmp3 =
            [&]()
            {
                __at_align__ std::array<double, 8> tmpbuf0;
                tmp0.store(tmpbuf0.data(), 8);
                __at_align__ std::array<double, 8> tmpbuf1;
                tmp2.store(tmpbuf1.data(), 8);
                __at_align__ std::array<double, 8> tmpbuf_out;
                for (int i = 0; i < 8; i++)
                {
                    tmpbuf_out[i] = calc_igammac(tmpbuf0[i], tmpbuf1[i]);
                }
                return at::vec::VectorizedN<double, 2>::loadu(tmpbuf_out.data(), 8);
            }
            ()
            ;
            tmp3.store(out_ptr0 + static_cast<int64_t>(x0), 8);
        }
        #pragma omp simd simdlen(4)
        for(int64_t x0=static_cast<int64_t>(48L); x0<static_cast<int64_t>(50L); x0+=static_cast<int64_t>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
            auto tmp1 = in_ptr1[static_cast<int64_t>(0L)];
            auto tmp2 = calc_igammac(tmp0, tmp1);
            out_ptr0[static_cast<int64_t>(x0)] = tmp2;
        }
    }
}
''')
```
`frexp` is difficult to handle with the common `fallback` since it returns two `cse_var`s: 2ba60a1618/torch/_inductor/codegen/cpp.py (L752-L766)
So we added a special function to handle it.
```cpp
cpp_fused_frexp_0 = async_compile.cpp_pybinding(['const double*', 'double*', 'int32_t*'], '''
#include "/tmp/torchinductor_root/z4/cz4j2mmotlx3z2b7u4fbjtdt4x6plhd67ljwzg5bk7ekv4xz6y7q.h"
extern "C"  void kernel(const double* in_ptr0,
                       double* out_ptr0,
                       int32_t* out_ptr1)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(16L); x0+=static_cast<int64_t>(8L))
        {
            auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), 8);
            at::vec::Vectorized<int32_t> tmp1;
            at::vec::VectorizedN<double, 2> tmp2;
            [&]()
            {
                __at_align__ std::array<double, 8> tmpbuf;
                tmp0.store(tmpbuf.data(), 8);
                __at_align__ std::array<int32_t, 8> tmpbuf_exponent;
                __at_align__ std::array<double, 8> tmpbuf_mantissa;
                for (int i = 0; i < 8; i++)
                {
                    tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);
                }
                tmp1 = at::vec::Vectorized<int32_t>::loadu(tmpbuf_exponent.data(), 8);
                tmp2 = at::vec::VectorizedN<double, 2>::loadu(tmpbuf_mantissa.data(), 8);
            }
            ();
            tmp2.store(out_ptr0 + static_cast<int64_t>(x0), 8);
            tmp1.store(out_ptr1 + static_cast<int64_t>(x0), 8);
        }
        #pragma omp simd simdlen(4)
        for(int64_t x0=static_cast<int64_t>(16L); x0<static_cast<int64_t>(20L); x0+=static_cast<int64_t>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
            int32_t tmp1;
            auto tmp2 = std::frexp(tmp0, &tmp1);
            out_ptr0[static_cast<int64_t>(x0)] = tmp2;
            out_ptr1[static_cast<int64_t>(x0)] = tmp1;
        }
    }
}
''')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134569
Approved by: https://github.com/jgong5, https://github.com/jansel
2166 lines · 74 KiB · Python
# mypy: allow-untyped-defs
import contextlib
import dataclasses
import functools
import itertools
import logging
import math
import operator
import re
from enum import auto, Enum
from itertools import chain
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges

from .. import config, metrics
from ..utils import (
    DeferredLineBase,
    generate_assert,
    IndentedBuffer,
    sympy_dot,
    sympy_subs,
    unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V


schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def data_type_logger(msg):
    if schedule_log.isEnabledFor(logging.DEBUG):
        schedule_log.debug("Data type propagation: %s", msg)


@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    nbytes: sympy.Expr
    zero_fill: bool


@dataclasses.dataclass
class TensorArg:
    name: str
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = sympy.Integer(0)  # c++ only
    alias_of: Optional[str] = None  # halide only


@dataclasses.dataclass
class SizeArg:
    name: str
    expr: sympy.Expr

    @property
    def alias_of(self):
        return None


@dataclasses.dataclass
class DeviceCodegen:
    scheduling: Any
    wrapper_codegen: type
    cpp_wrapper_codegen: type = type(None)


KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
    def import_get_raw_stream_as(self, name):
        raise NotImplementedError

    def set_device(self, device_idx):
        raise NotImplementedError

    def synchronize(self):
        raise NotImplementedError

    def device_guard(self, device_idx):
        raise NotImplementedError


device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is determined by different Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: Any,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    device_codegens[device] = DeviceCodegen(
        device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
    )
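# Hedged usage sketch (illustrative only, not part of the original file):
# MyScheduling, MyWrapperCodeGen, and MyCppWrapperCodeGen are hypothetical
# out-of-tree classes a backend author would implement.
#
#   register_backend_for_device(
#       "my_device",
#       MyScheduling,         # kernel codegen, analogous to CppScheduling
#       MyWrapperCodeGen,     # Python wrapper codegen
#       MyCppWrapperCodeGen,  # optional C++ wrapper codegen
#   )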


class BackendFeature(Enum):
    FOREACH = auto()
    BUCKETIZE = auto()
    INPLACE_BUFFERS = auto()
    MASKED_SCATTER_WITH_INDEX = auto()
    SCAN = auto()
    SORT = auto()
    TUPLE_REDUCTION = auto()
    PREFER_STORE_LOOP_ORDER = auto()
    TRITON_TEMPLATES = auto()
    REDUCE_TO_SINGLE_ELEMENT = auto()


def get_backend_features(device: Union[torch.device, str]):
    init_backend_registration()
    if isinstance(device, torch.device):
        device_type = device.type
    else:
        assert isinstance(device, str)
        device_type = device
        device = torch.device(device_type)
    scheduling = get_scheduling_for_device(device_type)
    return scheduling(None).get_backend_features(device)


def has_backend_feature(device, feature):
    """See also V.graph.has_feature"""
    assert isinstance(feature, BackendFeature)
    return feature in get_backend_features(device)

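# Hedged usage sketch for the two helpers above (illustrative only):
#
#   if has_backend_feature("cuda", BackendFeature.SCAN):
#       ...  # the backend can emit a native scan kernel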

def get_scheduling_for_device(device: str):
    return device_codegens[device].scheduling if device in device_codegens else None


def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    if device in device_codegens:
        wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
        return (
            wrapper_codegen_obj.cpp_wrapper_codegen
            if cpp_wrapper
            else wrapper_codegen_obj.wrapper_codegen
        )
    else:
        return None


@functools.lru_cache(None)
def init_backend_registration():
    from .cpp import CppScheduling
    from .cpp_wrapper_cpu import CppWrapperCpu
    from .cpp_wrapper_cuda import CppWrapperCuda
    from .cuda_combined_scheduling import CUDACombinedScheduling
    from .halide import HalideScheduling
    from .triton import TritonScheduling
    from .wrapper import WrapperCodeGen

    if get_scheduling_for_device("cpu") is None:
        cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling}
        register_backend_for_device(
            "cpu",
            lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs),
            WrapperCodeGen,
            CppWrapperCpu,
        )

    if get_scheduling_for_device("cuda") is None:
        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}
        register_backend_for_device(
            "cuda",
            lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs),
            WrapperCodeGen,
            CppWrapperCuda,
        )

    if get_scheduling_for_device("xpu") is None:
        register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)

    private_backend = torch._C._get_privateuse1_backend_name()
    if (
        private_backend != "privateuseone"
        and get_scheduling_for_device(private_backend) is None
    ):
        from torch.utils.backend_registration import _get_custom_mod_func

        try:
            device_scheduling = _get_custom_mod_func("Scheduling")
            wrapper_codegen = _get_custom_mod_func("WrapperCodeGen")
            cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen")
            if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
                register_backend_for_device(
                    private_backend,
                    device_scheduling,
                    wrapper_codegen,
                    cpp_wrapper_codegen,
                )
        except RuntimeError:
            pass


def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    from ..ir import FlexibleLayout

    # added contiguous index prevents reordering
    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]


def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    device_op_overrides_dict[device] = device_op_overrides


def get_device_op_overrides(device: str):
    assert isinstance(device, str)

    if not device_op_overrides_dict.keys():
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    if device in device_op_overrides_dict.keys():
        return device_op_overrides_dict[device]


@functools.lru_cache(None)
def boolean_ops():
    return (
        "isinf",
        "isnan",
        "logical_not",
        "signbit",
        "le",
        "lt",
        "ge",
        "gt",
        "eq",
        "ne",
    )


DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}
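# Hedged note (illustrative, not part of the original file): reduced-precision
# float types compute in fp32; every other dtype computes in itself, e.g.
#   DTYPE_TO_COMPUTATION_DTYPE[torch.bfloat16]  # -> torch.float32
#   DTYPE_TO_COMPUTATION_DTYPE[torch.int32]     # -> torch.int32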


def deduce_output_dtype_by_name(
    op_name: str,
    *args,
    **kwargs,
) -> Optional[torch.dtype]:
    """
    Given op name and a list of input dtypes, deduce the output dtype
    """
    if op_name in boolean_ops():
        return torch.bool
    elif op_name in (
        "to_dtype",
        "index_expr",
    ):
        return kwargs["dtype"] if "dtype" in kwargs else args[-1]
    elif op_name in (
        "rand",
        "randn",
    ):
        return torch.float
    elif op_name in (
        "get_index",
        "randint64",
        "load_seed",
    ):
        return torch.int64
    elif op_name == "reduction":
        return kwargs["dtype"] if "dtype" in kwargs else args[1]
    elif op_name == "constant":
        dtype = kwargs["dtype"] if "dtype" in kwargs else args[-1]
        return DTYPE_TO_COMPUTATION_DTYPE[dtype]  # type: ignore[index]
    elif op_name in (
        "load",
        "store",
        "store_reduction",
    ):
        buf_name = args[1]
        return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]
    return None
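# Hedged examples for deduce_output_dtype_by_name (illustrative only):
# comparison ops are boolean, RNG ops are float, index-producing ops are int64.
#
#   deduce_output_dtype_by_name("eq", a, b)         # -> torch.bool
#   deduce_output_dtype_by_name("randn", seed, 0)   # -> torch.float
#   deduce_output_dtype_by_name("get_index", "x0")  # -> torch.int64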


class DataTypePropagation:
    def __init__(self, body) -> None:
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        if node.op == "placeholder":
            return None

        if node.target == "output" and len(node.args) != 1:
            # we can only infer the output node if it has exactly 1 arg
            return None

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        if (
            output_dtype := deduce_output_dtype_by_name(
                node.target,
                *node.args,
                **node.kwargs,
            )
        ) is not None:
            return output_dtype

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)


# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(Printer):
    @staticmethod
    def paren(string):
        def all_in_parens(string):
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                if count == 0 and i != len(string) - 2:
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"
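    # Hedged examples for paren (illustrative, not part of the original file):
    #   paren("x0")      -> "x0"       (single token needs no parens)
    #   paren("(a + b)") -> "(a + b)"  (already fully parenthesized)
    #   paren("a + b")   -> "(a + b)"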

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloatTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr):
        return self._print(expr.args[0])

    def _print_GreaterThan(self, expr):
        # GreaterThan:          >=
        # StrictlyGreaterThan:  >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    # NB: The C implementation is injected into codegen at
    # torch/_inductor/codegen/wrapper.py
    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"

    # This must be implemented because sympy will collect x * x into Pow(x, 2),
    # without any explicit intervention.  We print it just like x * x; notably,
    # we never generate sympy.Pow with floats.
    #
    # NB: this is pow by natural; you should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  This
    # means exp is guaranteed to be integer.
    def _print_Pow(self, expr):
        base, exp = expr.args
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        assert exp >= 0
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        else:  # exp == 0
            return "1"
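    # Hedged example (illustrative, not part of the original file):
    #   Pow(x, 3) prints as "x*x*x" and Pow(x, 0) prints as "1".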

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr):
        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")

    def _print_Infinity(self, expr):
        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")

    def _print_NegativeInfinity(self, expr):
        raise NotImplementedError(
            f"_print_NegativeInfinity not implemented for {type(self)}"
        )

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_PythonMod(self, expr):
        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")

    def _print_IntTrueDiv(self, expr):
        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")

    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr):
        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")

    def _print_TruncToInt(self, expr):
        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")

    def _print_RoundToInt(self, expr):
        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")

    def _print_RoundDecimal(self, expr):
        raise NotImplementedError(
            f"_print_RoundDecimal not implemented for {type(self)}"
        )

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr):
        raise NotImplementedError(
            f"_print_TruncToFloat not implemented for {type(self)}"
        )

    def doprint(self, expr, *, simplify: bool = True):
        # TODO: why are people passing strings to the printer here :think:
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)


class PythonPrinter(ExprPrinter):
    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        # This could also have been int(); they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"


class OpOverrides:
    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def erfc(x):
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x):
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x):
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x):
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x):
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x):
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x):
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def libdevice_sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    @staticmethod
    def relu(x):
        return ops.maximum(x, ops.constant(0, torch.int32))

    @staticmethod
    def libdevice_abs(x):
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x):
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x):
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x):
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x):
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x):
        return ops.exp(x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)
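    # Hedged example (illustrative, not part of the original file): ops.mod is
    # C-style (truncated), so remainder fixes up the sign to match Python:
    #   mod(-3, 4) -> -3, while remainder(-3, 4) -> 1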

    @staticmethod
    def trunc_to_int(a, dtype):
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))


@dataclasses.dataclass
class OverridesData:
    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)


# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    return h


class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        if all(
            self.name not in x
            for x in (
                V.graph.removed_buffers,
                V.kernel.removed_buffers,
                V.graph.inplaced_to_remove,
                V.kernel.inplaced_to_remove,
            )
        ):
            return self.line
        return None

    def _new_line(self, line):
        return DeferredLine(self.name, line)
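    # Hedged example (illustrative, not part of the original file): a store
    # wrapped as DeferredLine("buf0", "out_ptr0[x0] = tmp3;") emits nothing
    # if "buf0" later ends up in removed_buffers / inplaced_to_remove.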


class BracesBuffer(IndentedBuffer):
    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()
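    # Hedged example (illustrative, not part of the original file): a positive
    # offset opens C++ braces on entry and closes them on exit, so
    #
    #   with buf.indent():
    #       buf.writeline("int x = 0;")
    #
    # emits "{", then the indented declaration, then "}".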


class InplacedBuffer(NamedTuple):
    inner_name: str
    other_names: List[str]


class KernelArgs:
    @staticmethod
    def _lookup(prefix, odict, name):
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        self.sizevars = sizevars or {}
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf
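    # Hedged example (illustrative, not part of the original file): after
    # make_inplace("buf0", "buf1"), both names share one InplacedBuffer whose
    # inner_name is "in_out_ptr0", so the kernel reads and writes one pointer.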

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset
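    # Hedged example (illustrative, not part of the original file): requests
    # are packed into one allocation, so successive calls return growing
    # offsets, e.g. workspace(128, False) -> ("ws_ptr", 0), then
    # workspace(64, True) -> ("ws_ptr", 128) with zero_fill upgraded to True.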

    def seed_offset(self, name, value):
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types
| 
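    # For illustration (hypothetical graph with one float input and one float
    # output), cpp_argdefs would produce something like:
    #
    #   arg_defs  == ["const float* in_ptr0", "float* out_ptr0"]
    #   arg_types == ["const float*", "float*"]
    #   call_args == ["arg0_1", "buf0"]   # outer buffer names for the caller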
    def python_argdefs(self):
        arg_defs: List[str] = []
        call_args: List[str] = []
        arg_types: List[torch.dtype] = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))  # type: ignore[arg-type]
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args, arg_types
    def aliases(self):
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name
    def is_removed(self, name):
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = OrderedSet()  # type: ignore[var-annotated]
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs
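    # Continuing the earlier make_inplace sketch (hypothetical names): with
    # buf0/buf1 folded into in_out_ptr0 and a separate output buf2,
    # live_output_buffers() would contain {"buf1", "buf2"} -- the last alias
    # of each inplaced buffer plus every non-removed output.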
class CSEVariable:
    """A CSEVariable is just a name for an expression, but it is useful to be able to annotate it on a backend-dependent basis.
    To do so, the backends can simply overload `Kernel.create_cse_var`.
    The "CSEVariable.update_on_args" method gives you a hook for annotations.
    See the example of TritonCSEVariable in triton.py.
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds
        self.use_count = 1  # track how many times this expression is used

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name!r})"
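# A minimal sketch of the annotation hook described above (hypothetical
# backend): a subclass can record per-variable metadata as ops are applied,
# and the owning kernel returns it from `create_cse_var`.
#
#   class MyBackendCSEVariable(CSEVariable):
#       def update_on_args(self, name, args, kwargs):
#           # e.g. mark the variable as boolean-valued after a comparison op
#           self.is_bool = name in ("lt", "le", "gt", "ge", "eq", "ne")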
class CppWrapperKernelArgs(KernelArgs):
    def wrap_ptr_arg(self, buf, dtype):
        from .cpp_utils import DTYPE_TO_CPP

        if config.abi_compatible:
            # In the abi_compatible model, we just return the buf here.
            # We will form correct call args later in wrapper.generate_kernel_all.
            return buf
        else:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        return f"{size}"

class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        self.cache = {}
        self.name_prefix = name_prefix
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = OrderedSet()  # type: ignore[var-annotated]
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: OrderedSet[str]):
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds)
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            var.bounds = var.bounds.tighten(bounds)
            var.use_count += 1

        return var
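    # Sketch of the dedup behavior (hypothetical expressions, inside an
    # active kernel context): generating the same string twice yields the
    # same variable and only one emitted line.
    #
    #   v1 = cse.generate(buf, "a + b")   # writes "tmp0 = a + b", returns tmp0
    #   v2 = cse.generate(buf, "a + b")   # cache hit: returns tmp0 again,
    #                                     # bumps tmp0.use_count, writes nothing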
    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var


class CodeGen:
    def __init__(self) -> None:
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

class ScopedDict:
    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict

    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)
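# ScopedDict gives copy-on-write layering: reads fall through to the wrapped
# dict, writes stay local. A quick sketch:
#
#   base = {"x": 1}
#   scoped = ScopedDict(base)
#   scoped["y"] = 2
#   scoped["x"], scoped["y"]   # -> 1, 2
#   "y" in base                # -> False; the outer scope is untouched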
class Kernel(CodeGen):
    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()

        self.num_load = 0
        self.num_reduction = 0

        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = OrderedSet()  # type: ignore[var-annotated]
        self.store_buffer_names = OrderedSet()  # type: ignore[var-annotated]
        self._load_mask = None
        self._load_other = None
        # OrderedSet in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None

        self.removed_buffers = OrderedSet()  # type: ignore[var-annotated]
        self.inplaced_to_remove = OrderedSet()  # type: ignore[var-annotated]

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        #   the buffer specified by key
        self.inplace_update_buffers = {}
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        def scope_cse(cse):
            new_cse = cse.clone()
            new_cse.cache = ScopedDict(cse.cache)
            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
            new_cse.store_cache = ScopedDict(cse.store_cache)
            return new_cse

        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = scope_cse(cse)
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse
    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        raise NotImplementedError

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior
    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        raise NotImplementedError

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        raise NotImplementedError

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[
            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
        ],
        values: Tuple[CSEVariable, ...],
    ) -> Tuple[CSEVariable, ...]:
        raise NotImplementedError

    def sort(
        self,
        dtypes: Tuple[torch.dtype, ...],
        values: Tuple[CSEVariable, ...],
        stable: bool,
        descending: bool,
    ) -> Tuple[CSEVariable, ...]:
        raise NotImplementedError

    def var_ranges(self):
        raise NotImplementedError

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError

    @property
    def assert_function(self) -> str:
        raise NotImplementedError
    def indirect_assert(
        self,
        var: Union[CSEVariable, str],
        lower: Optional[str],
        upper: Optional[str],
        mask: Optional[Union[CSEVariable, str]] = None,
    ) -> str:
        if isinstance(var, CSEVariable):
            var = str(var)
        assert isinstance(var, str)
        assert lower is None or isinstance(lower, str)
        assert upper is None or isinstance(upper, str)
        if lower and upper:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower} <= {var} < {upper}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = cond
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = cond

        if mask:
            cond = f"({cond}) | ~({mask})"

        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
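    # For illustration, with a hypothetical backend whose assert_function is
    # "tl.device_assert", indirect_assert("tmp3", "0", "ks0") would return:
    #
    #   tl.device_assert((0 <= tmp3) & (tmp3 < ks0), "index out of bounds: 0 <= tmp3 < ks0")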
    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ):
        raise NotImplementedError

    def index_to_str(self, index: sympy.Expr) -> str:
        raise NotImplementedError

    def __enter__(self):
        # TODO: hoist this to top level
        class CSEProxy:
            self.name = "CSEProxy"
            vr_analysis = ValueRangeAnalysis()

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                def inner(*args, **kwargs):
                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = V.kernel.cse.generate(
                            V.kernel.compute, v, bounds=bounds
                        )
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    return pytree.tree_map(do_cse, value)

                return inner
            @staticmethod
            def _bound_variable(name, *args, **kwargs):
                """
                If the variable comes from an FX node, we forward the bound we have already computed.
                Otherwise, if the variable is created while codegen'ing another op, we try to compute its bounds.
                """
                from ..select_algorithm import TritonTemplateKernel

                if isinstance(V.kernel, TritonTemplateKernel):
                    return ValueRanges.unknown()

                fx_node = V.interpreter.current_node
                if fx_node.target == name and self.node_to_bounds is not None:
                    assert isinstance(self.node_to_bounds, dict)
                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
                    # These create lots of inner strings. We would need to compute the bounds at the ops
                    # We will also likely not get much from computing VRs on these nodes
                    if any(
                        s in fx_node.target
                        for s in ("set_indirect", "reduction", "scan")
                    ):
                        return ValueRanges.unknown()

                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.

                    # If there is no FX bound but we know how to compute one we do so
                    assert not kwargs

                    def arg_to_bound(x):
                        if isinstance(x, CSEVariable):
                            return x.bounds
                        elif isinstance(x, sympy.Expr):
                            return bound_sympy(x)
                        else:
                            return x

                    arg_bounds = list(map(arg_to_bound, args))
                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
                else:
                    return ValueRanges.unknown()
            @staticmethod
            def indirect_indexing(
                var: CSEVariable,
                size: Union[sympy.Expr, int],
                check: bool = True,
                wrap_neg=True,
            ):
                if isinstance(size, int):
                    size = sympy.Integer(size)
                assert isinstance(size, sympy.Expr), size
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:  # type: ignore[operator]
                    if wrap_neg:
                        stm = ops.add(var, ops.index_expr(size, torch.long))
                        # Mixed negative and non-negative
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            lt = ops.lt(var, 0)
                            stm = ops.where(lt, stm, var)
                    else:
                        stm = var

                    # Propagate bounds as we know how to compute them properly
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
                        new_bounds = ValueRanges(
                            neg_bounds.lower + size, neg_bounds.upper + size
                        )
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, int_oo)
                            new_bounds = new_bounds | pos

                    var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                sympy_var = parent_handler.indirect_indexing(var, size, check)
                if generate_assert(check):
                    assert_lower = not (var.bounds.lower >= 0)
                    # value ranges cannot prove x < s when x and s are symbols
                    assert_upper = not isinstance(size, sympy.Number) or not (
                        var.bounds.upper < size
                    )
                    self.check_bounds(sympy_var, size, assert_lower, assert_upper)
                return sympy_var
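            # Worked example of the wrapping above (hypothetical numbers):
            # with size == 10 and var.bounds == [-3, 5], the negative part
            # [-3, -1] wraps to [7, 9], the non-negative part stays [0, 5],
            # and the propagated bound is the union [0, 9].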
            @staticmethod
            def check_bounds(
                expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
            ):
                return self.check_bounds(expr, size, lower, upper)

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_is_type(index, SymT.TMP):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    return store_cache[name]
                out = self.load(name, index)
                # count load that is not in the store_cache, and also not in the
                # cse cache.
                if out.use_count == 1:
                    self.num_load += 1
                return out

            @staticmethod
            def _update_store_cache(name: str, value: CSEVariable):
                self.cse.store_cache[name] = value
                if self.current_node and name in V.graph.name_to_buffer:
                    buf = self.current_node.get_output(name)
                    for other_name in buf.get_mutations():
                        self.cse.store_cache[other_name] = value

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    CSEProxy._update_store_cache(name, value)
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                CSEProxy._update_store_cache(name, value)

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                self.num_reduction += 1
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtypes: Tuple[torch.dtype, ...],
                combine_fn: Callable[
                    [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
                    Tuple[CSEVariable, ...],
                ],
                values: Tuple[CSEVariable, ...],
            ) -> Tuple[CSEVariable, ...]:
                return self.scan(dtypes, combine_fn, values)

            @staticmethod
            def sort(
                dtypes: Tuple[torch.dtype, ...],
                values: Tuple[CSEVariable, ...],
                stable: bool,
                descending: bool,
            ) -> Tuple[CSEVariable, ...]:
                return self.sort(dtypes, values, stable, descending)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )
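            # The docstring example follows searchsorted-style semantics; as a
            # sanity-check sketch of the same numbers in eager mode (assuming
            # torch is available):
            #
            #   torch.bucketize(torch.tensor([-1, 0, 1, 2, 3, 4, 5, 9]),
            #                   torch.tensor([0, 4, 4, 8]), right=True)
            #   # -> tensor([0, 1, 1, 1, 1, 3, 3, 4])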
        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if symbol_is_type(
                x,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                ),
            )
        }
        return sympy_subs(index, replacements)
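    # For illustration (hypothetical symbols): an index expression such as
    # s0*x1 + x0 would have its graph-level size symbol s0 registered as a
    # kernel size arg and renamed, yielding something like ks0*x1 + x0, while
    # iteration variables (x0, x1) are left untouched.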
    def create_cse_var(self, *args, **kwargs):
        return CSEVariable(*args, **kwargs)


@dataclasses.dataclass
class OptimizationContext:
    key: ClassVar[str] = "opt_ctx"

    dtype: Optional[torch.dtype] = None
    ops_name: str = ""


@functools.lru_cache(None)
def jinja2_env():
    try:
        import jinja2

        return jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None

class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def indent_except_first(source: str, num_indents: int, indents_spacing=4):
        lines = source.splitlines(True)
        if len(lines) > 1:
            lines[1:] = [
                (" " * indents_spacing * num_indents) + line for line in lines[1:]
            ]
        return "".join(lines)
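    # Quick sketch of the Jinja filter above: every line except the first is
    # re-indented, so a multi-line snippet can be spliced after existing text.
    #
    #   KernelTemplate.indent_except_first("a = 1\nb = 2\n", 1)
    #   # -> "a = 1\n    b = 2\n"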
    @staticmethod
    def _template_from_string(source):
        env = jinja2_env()
        if env is not None:
            env.filters["indent_except_first"] = KernelTemplate.indent_except_first
            from jinja2 import TemplateSyntaxError

            class DetailedTemplateSyntaxError(TemplateSyntaxError):
                def __init__(self, original_error):
                    super().__init__(
                        original_error.message,
                        original_error.lineno,
                        original_error.name,
                        original_error.filename,
                    )
                    self.original_error = original_error

                def __str__(self):
                    error_info = f"Error in template at line {self.lineno}\n"
                    error_info += f"Error message: {self.message}\n"
                    if hasattr(self.original_error, "source"):
                        lines = self.original_error.source.split("\n")
                        error_info += "Context:\n"
                        start = max(0, self.lineno - 2)
                        end = min(len(lines), self.lineno + 2)
                        for i in range(start, end):
                            if i == self.lineno - 1:
                                error_info += f"{i+1}: --> {lines[i]}\n"
                                if hasattr(self.original_error, "column"):
                                    error_info += (
                                        "     "
                                        + " " * (self.original_error.column - 1)
                                        + "^\n"
                                    )
                            else:
                                error_info += f"{i+1}:     {lines[i]}\n"
                    return error_info

            try:
                return env.from_string(source)
            except TemplateSyntaxError as e:
                raise DetailedTemplateSyntaxError(e) from e

        return None

    @staticmethod
    def _fake_get_dtype(fake_out):
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype
    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """

        try:
            choices.append(self.generate(**kwargs))
        except NotImplementedError:
            pass

    def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
        """
        Generates a ChoiceCaller instance from the given arguments.
        """

        raise NotImplementedError