import contextlib
import dataclasses
import functools
import itertools
import logging
import operator
import re
from itertools import chain
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.value_ranges import ValueRanges

from .. import config, metrics
from ..utils import (
    DeferredLineBase,
    do_bench,
    free_symbol_startswith,
    IndentedBuffer,
    sympy_dot,
    sympy_index_symbol,
    sympy_subs,
    unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V

if TYPE_CHECKING:
    from ..ir import TensorBox

schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def data_type_logger(msg):
    if schedule_log.isEnabledFor(logging.DEBUG):
        schedule_log.debug("Data type propagation: %s", msg)


@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    nbytes: sympy.Expr
    zero_fill: bool


@dataclasses.dataclass
class TensorArg:
    name: str
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = sympy.Integer(0)


@dataclasses.dataclass
class SizeArg:
    name: str
    expr: sympy.Expr


@dataclasses.dataclass
class DeviceCodegen:
    scheduling: type
    wrapper_codegen: type


KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
    def import_get_raw_stream_as(self, name):
        raise NotImplementedError()

    def set_device(self, device_idx):
        raise NotImplementedError()

    def synchronize(self):
        raise NotImplementedError()

    def device_guard(self, device_idx):
        raise NotImplementedError()


device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is determined by the Scheduling in use. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes used for code generation, such as CppKernel and TritonKernel, typically form
# part of the logic of either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen
# interfaces provide flexibility to the backend: a backend can choose to implement these classes
# from scratch, or reuse them by extending and overriding as necessary. Inductor provides the
# registration API, register_backend_for_device, to plug a new backend in at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these
# interfaces. This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str, device_scheduling: type, device_wrapper_codegen: type
):
    device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen)


def get_scheduling_for_device(device: str):
    return device_codegens[device].scheduling if device in device_codegens else None


def get_wrapper_codegen_for_device(device: str):
    return (
        device_codegens[device].wrapper_codegen if device in device_codegens else None
    )


def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    from ..ir import FlexibleLayout

    # added contiguous index prevents reordering
    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]


def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    device_op_overrides_dict[device] = device_op_overrides


def get_device_op_overrides(device: str):
    assert isinstance(device, str)

    if not device_op_overrides_dict.keys():
        from .cuda import device_op_overrides  # noqa: F401

    if device in device_op_overrides_dict.keys():
        return device_op_overrides_dict[device]

    return DeviceOpOverrides()
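# A minimal sketch of the registration flow described above. `MyScheduling`
# and `MyWrapperCodeGen` are hypothetical placeholder classes, not part of
# this repo:
#
#     from torch._inductor.codegen.common import (
#         get_scheduling_for_device,
#         register_backend_for_device,
#     )
#
#     register_backend_for_device("privateuse1", MyScheduling, MyWrapperCodeGen)
#     assert get_scheduling_for_device("privateuse1") is MyScheduling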
@functools.lru_cache(None)
def boolean_ops():
    return (
        "is_inf",
        "is_nan",
        "bitwise_xor",
        "logical_not",
        "signbit",
        "le",
        "lt",
        "ge",
        "gt",
        "eq",
        "ne",
    )


DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}


class DataTypePropagation:
    def __init__(self, body) -> None:
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # we can only infer the dtype of the output node when it has
            # exactly one arg
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        if node.target in (
            "get_index",
            "index_expr",
        ):
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use the output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None.
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)


class ExprPrinter(Printer):
    @staticmethod
    def paren(string):
        def all_in_parens(string):
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                if count == 0 and i != len(string) - 2:
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # GreaterThan:          >=
        # StrictlyGreaterThan:  >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"


class PythonPrinter(ExprPrinter):
    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_Pow(self, expr):
        # Pow() confuses triton
        base, exp = expr.args
        # NB: Remember this is sizevar computation!  You don't typically
        # expect to have to do floating-point computation, including
        # exponents, in sizevar compute.  Instead of adding support for
        # floating-point pow, you should make upstream retranslate the sympy
        # expression into Tensor expressions earlier and do that instead.
        if exp == 0.5:
            return self._helper_sqrt(base)
        elif exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        elif exp < 0:
            return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
        else:  # exp == 0
            return "1"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
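# For illustration, these printers turn sympy index expressions into Python
# source strings. A sketch (assuming FloorDiv/ModularIndexing from
# torch.utils._sympy.functions, where Inductor's index functions live):
#
#     import sympy
#     from torch.utils._sympy.functions import FloorDiv, ModularIndexing
#
#     x = sympy.Symbol("x0")
#     PythonPrinter().doprint(FloorDiv(x, 8))             # "(x0 // 8)"
#     PythonPrinter().doprint(ModularIndexing(x, 8, 16))  # "(x0 // 8) % 16"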
class OpOverrides:
    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv("1", x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        r = ops.mod(a, b)
        return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        assert target in {"triton", "cpp", "cppvec"}, target

        def pointwise_factory_1(impl):
            def func(x):
                return impl.format(x=x)

            return func

        def pointwise_factory_2(impl):
            def func(x, y):
                return impl.format(x=x, y=y)

            return func

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if isinstance(impl, str):
                nof_args = 2 if "{y}" in impl else 1
                # extend the following dictionary with factory
                # functions for a specific number of arguments as
                # needed:
                factory = {1: pointwise_factory_1, 2: pointwise_factory_2}[nof_args]
                setattr(cls, funcname, staticmethod(factory(impl)))
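# Two notes on the helpers above, as worked examples rather than new behavior:
#
# * `remainder` builds a Python-style modulo (result takes the divisor's sign)
#   from a C-style `ops.mod`: for a = -3, b = 4, r = ops.mod(a, b) is -3, and
#   since r is nonzero with a different sign than b, the emitted `where`
#   selects r + b = 1, matching Python's -3 % 4.
#
# * `_initialize_pointwise_overrides` stamps out format-string methods. For
#   the "triton" target, the `bessel_j0` entry below becomes, roughly:
#
#       @staticmethod
#       def bessel_j0(x):
#           return "tl.math.j0({x})".format(x=x)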
@dataclasses.dataclass
class OverridesData:
    name: str
    cpp: str
    triton: Optional[str] = None  # None when not impl in libdevice/triton
    cppvec: Optional[str] = None  # None when not impl in aten/.../vec
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_j0_forward({x})",
        triton="tl.math.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_j1_forward({x})",
        triton="tl.math.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_y0_forward({x})",
        triton="tl.math.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_y1_forward({x})",
        triton="tl.math.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_digamma({x})",
        cppvec="{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_erfcx({x})",
        triton="tl.math.erfcx({x})",
        name="special_erfcx",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i0({x})",
        triton="tl.math.cyl_bessel_i0({x})",
        cppvec="{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i0e({x})",
        cppvec="{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i1({x})",
        triton="tl.math.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_i0_forward({x})",
        triton="tl.math.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_i1_forward({x})",
        triton="tl.math.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)


# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    return h


class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        if all(
            self.name not in x
            for x in (
                V.graph.removed_buffers,
                V.kernel.removed_buffers,
                V.graph.inplaced_to_remove,
                V.kernel.inplaced_to_remove,
            )
        ):
            return self.line
        return None

    def _new_line(self, line):
        return DeferredLine(self.name, line)


class BracesBuffer(IndentedBuffer):
    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()
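# A small illustration of BracesBuffer: `indent()` with a positive offset
# writes an opening brace and indents, and the matching closing brace is
# written when the `with` block exits, so C++ scopes mirror the Python
# nesting (schematic; writeline/getvalue come from IndentedBuffer):
#
#     code = BracesBuffer()
#     code.writeline("for (int i = 0; i < n; ++i)")
#     with code.indent():
#         code.writeline("acc += x[i];")
#     # code.getvalue():
#     #     for (int i = 0; i < n; ++i)
#     #     {
#     #         acc += x[i];
#     #     }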
class InplacedBuffer(NamedTuple):
    inner_name: str
    other_names: List[str]


class KernelArgs:
    @staticmethod
    def _lookup(prefix, odict, name):
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = dict()
        self.output_buffers = dict()
        self.inplace_buffers = dict()
        self.sizevars = sizevars or dict()
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        from .cpp import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU"
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        arg_defs = []
        call_args = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args

    def aliases(self):
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs
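# Sketch of how KernelArgs assigns stable inner argument names (input/output
# consult V.graph, so this only runs inside a lowering; the buffer names are
# illustrative):
#
#     args = KernelArgs()
#     args.input("buf0")    # -> "in_ptr0"
#     args.input("buf1")    # -> "in_ptr1"
#     args.output("buf2")   # -> "out_ptr0"
#     args.make_inplace("buf0", "buf3")
#     args.output("buf3")   # -> "in_out_ptr0", reusing buf0's storage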
class CSEVariable:
    """A CSEVariable is just a name for an expression, but it is useful to be
    able to annotate them on a backend-dependent basis.
    To do so, the backends can simply overload `Kernel.create_cse_var`.
    The "CSEVariable.update_on_args" method gives you a hook for annotations.
    See the example of TritonCSEVariable in triton.py.
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass


class CppWrapperKernelArgs(KernelArgs):
    def wrap_ptr_arg(self, buf, dtype):
        from .cpp import DTYPE_TO_CPP

        if config.abi_compatible:
            # In the abi_compatible model, we just return the buf here.
            # We will form the correct call args later in wrapper.generate_kernel_all.
            return buf
        else:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        return f"{size}"
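# Backends annotate CSE variables by subclassing CSEVariable and overloading
# Kernel.create_cse_var (TritonCSEVariable in triton.py is the in-tree
# example). A hypothetical sketch:
#
#     class MyCSEVariable(CSEVariable):
#         def update_on_args(self, name, args, kwargs):
#             ...  # e.g. record whether any input was masked
#
#     class MyKernel(Kernel):
#         def create_cse_var(self, *args, **kwargs):
#             return MyCSEVariable(*args, **kwargs)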
class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        self.cache = {}
        self.name_prefix = name_prefix
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = set()
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: Set[str]):
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds) if assignment else None
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            var.bounds = var.bounds.tighten(bounds)

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var
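# The caching contract, schematically (newvar() requires a kernel handler to
# be installed on V.kernel, and the "tmp2" name depends on the counter state):
#
#     compute = IndentedBuffer()
#     v1 = cse.generate(compute, "tmp0 + tmp1")  # writes "tmp2 = tmp0 + tmp1"
#     v2 = cse.generate(compute, "tmp0 + tmp1")  # cache hit; writes nothing
#     assert v1 is v2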
class IndirectAssertLine(DeferredLineBase):
    def __init__(self, line, assert_fn, var, mask, size_map):
        self.var = var
        self.mask = mask
        self.line = line
        self.assert_fn = assert_fn
        self.size_map = size_map

    def __call__(self):
        size, size_str = self.size_map[(self.var, self.mask)]

        # We assert if we've not been able to prove the bound
        assert_min = (self.var.bounds.lower >= 0) != sympy.true
        assert_max = (self.var.bounds.upper < size) != sympy.true

        # FooBar interview question
        if not (assert_min or assert_max):
            return None
        elif assert_min and assert_max:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
            cond_print = f"0 <= {self.var} < {size_str}"
        elif assert_min:
            cond = f"0 <= {self.var}"
            cond_print = cond
        else:
            assert assert_max
            cond = f"{self.var} < {size_str}"
            cond_print = cond

        if self.mask:
            cond = f"({cond}) | ~{self.mask}"

        return self.line.format(
            assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
        )

    def _new_line(self, line):
        return IndirectAssertLine(
            line, self.assert_fn, self.var, self.mask, self.size_map
        )


class CodeGen:
    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
class Kernel(CodeGen):
    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()
        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.store_buffer_names = set()
        self._load_mask = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
        # Upper bounds for indirect_indexing and their str representation
        # NB: None, None is never stored in map, but it is the assumed
        # "not set" value for the dict
        self.indirect_max_sizes: Dict[
            Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]]
        ] = {}

        self.removed_buffers = set()
        self.inplaced_to_remove = set()

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        #   the buffer specified by key
        self.inplace_update_buffers = dict()
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = cse.clone()
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        raise NotImplementedError()

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        raise NotImplementedError()

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError()

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        raise NotImplementedError()

    def scan(
        self,
        dtype: torch.dtype,
        combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
        value: CSEVariable,
        init: int,
    ) -> CSEVariable:
        raise NotImplementedError()

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError()

    @property
    def assert_function(self) -> str:
        raise NotImplementedError()

    def index_to_str(self, index: sympy.Expr) -> str:
        raise NotImplementedError()
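    # Within a `with kernel:` block (see __enter__ below), V.ops calls such as
    # ops.add(a, b) are routed through CSEProxy: the backend ops handler
    # renders the expression, CSE assigns it a tmp variable in self.compute,
    # and repeated identical expressions reuse the same variable.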
    def __enter__(self):
        # TODO: hoist this to top level
        class CSEProxy:
            self.name = "CSEProxy"

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                def inner(*args, **kwargs):
                    # TritonTemplateKernel has no current_node
                    buf_bounds = ValueRanges.unknown()
                    if hasattr(V.interpreter, "current_node"):
                        fx_node = V.interpreter.current_node
                        assert isinstance(self.node_to_bounds, dict)
                        buf_bounds = self.node_to_bounds.get(
                            fx_node, ValueRanges.unknown()
                        )

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def indirect_indexing(
                var: CSEVariable, size: sympy.Expr, check: bool = True
            ):
                # Skip CSE since this doesn't return an expression
                if var.bounds.lower < 0:  # type: ignore[operator]
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take the union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where,
                        # as we have info on the cond
                        neg = var.bounds & ValueRanges(-sympy.oo, -1)
                        new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, sympy.oo)
                            new_bounds = new_bounds | pos

                    stm = ops.add(var, self.rename_indexing(size))
                    # Mixed negative and non-negative
                    if var.bounds.upper >= 0:  # type: ignore[operator]
                        lt = ops.lt(var, "0")
                        stm = ops.where(lt, stm, var)
                    new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                    new_var.update_on_args("index_wrap", (var,), {})
                    var = new_var

                if self.generate_assert(check):
                    mask = self.load_mask(var)

                    # An assertion line may have been written already, if so just
                    # update the max size.
                    map_key = (var, mask)
                    existing_size, _ = self.indirect_max_sizes.get(
                        map_key, (None, None)
                    )
                    if existing_size is not None:
                        size = sympy.Min(size, existing_size)
                    else:
                        line = (
                            '{assert_fn}({cond}, "index out of bounds: {cond_print}")'
                        )
                        self.compute.writeline(
                            IndirectAssertLine(
                                line,
                                self.assert_function,
                                var,
                                mask,
                                self.indirect_max_sizes,
                            )
                        )

                    self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
                return sympy_index_symbol(str(var))

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_startswith(index, "tmp"):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    return store_cache[name]
                return self.load(name, index)

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                self.cse.store_cache[name] = value
                if self.current_node:
                    for other_name in self.current_node.get_mutations():
                        self.cse.store_cache[other_name] = value

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtype: torch.dtype,
                combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
                value: CSEVariable,
                init: int,
            ) -> CSEVariable:
                return self.scan(dtype, combine_fn, value, init)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def generate_assert(self, check):
        return (check or config.debug_index_asserts) and config.assert_indirect_indexing

    def load_mask(self, var) -> str:
        # only the triton kernel requires mask
        return ""

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if x.name.startswith(("s", "u", "ps"))
            or (x.name.startswith("i") and not x.name.startswith("idx"))
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        return CSEVariable(*args, **kwargs)


@dataclasses.dataclass
class OptimizationContext:
    key: ClassVar[str] = "opt_ctx"

    # Load value as mask
    is_load_as_mask: bool = False

    dtype: Optional[torch.dtype] = None
    ops_name: str = ""

    # Load uint8/int8 value as float32
    is_load_int8_as_float: bool = False


@functools.lru_cache(None)
def jinja2_env():
    try:
        import jinja2

        return jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None


class ChoiceCaller:
    """
    Represents a possible choice used in autotune_process.py.
    During autotuning, self.benchmark() is first called to get the benchmark
    result, and if this choice is selected, self.output_node() is called to
    get the output_node.

    Children classes: TritonTemplateCaller, CUDATemplateCaller.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes

    def benchmark(self, *args, out) -> float:
        algo = self.to_callable()
        return do_bench(lambda: algo(*args, out=out))

    def call_name(self) -> str:
        raise NotImplementedError()

    def to_callable(self):
        raise NotImplementedError()

    def hash_key(self) -> str:
        raise NotImplementedError()

    def output_node(self) -> "TensorBox":
        raise NotImplementedError()
""" try: choices.append(self.generate(**kwargs)) except NotImplementedError: pass def generate(self, **kwargs) -> ChoiceCaller: """ Generates a ChoiceCaller instance from the given arguments. """ raise NotImplementedError()