# Owner(s): ["module: inductor"]
import functools
import logging
from typing import Any, Callable, Optional, Union

import torch
from torch import _ops
from torch._inductor import config
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
from torch._inductor.lowering import lowerings, validate_ir
from torch._inductor.select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
)
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)


class CustomOpConfig:
    """Config for custom op autotuning.

    Specifies an optional decomposition function with parameter values.
    Each config creates exactly one variant (no Cartesian product).

    Args:
        decomposition: Optional function to autotune. If not provided, the
            custom op's default implementation will be used.
        **params: Parameters passed to the function

    Examples:
        CustomOpConfig(attention_impl, head_dim=32, method='chunked')
        CustomOpConfig(head_dim=32, method='chunked')
    """

    def __init__(
        self,
        decomposition: Optional[Callable[..., Any]] = None,
        **params: Any,
    ):
        if decomposition is not None and not callable(decomposition):
            raise TypeError(
                f"decomposition must be callable, got {type(decomposition)}"
            )

        self.decomposition = decomposition
        self.params = params

        # Generate a descriptive name for this variant
        decomp_name = (
            decomposition.__name__ if decomposition is not None else "default"
        )
        if self.params:
            param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(self.params.items()))
            self.name = f"{decomp_name}_{param_suffix}"
        else:
            self.name = decomp_name

    def get_decomposition(
        self, default_impl: Optional[Callable[..., Any]] = None
    ) -> Callable[..., Any]:
        """Return the decomposition function for this config.

        When decomposition is not specified, return the default implementation
        from the custom op's registration.
        """
        if self.decomposition is not None:
            return self.decomposition

        # If no decomposition is specified in the config, get the Python
        # implementation from the custom op registration
        if default_impl and isinstance(default_impl, _ops.OpOverload):
            from torch._library.custom_ops import _maybe_get_opdef

            op_def = _maybe_get_opdef(default_impl)
            if op_def is not None and hasattr(op_def, "_init_fn"):
                return op_def._init_fn

        raise TypeError(
            f"Could not extract Python implementation from {default_impl}. "
            f"Please register the custom op or provide a decomposition function."
        )

    def create_variant(
        self, default_impl: Optional[Callable[..., Any]] = None
    ) -> Callable[..., Any]:
        """Create a callable with parameters pre-applied using functools.partial."""
        decomposition = self.get_decomposition(default_impl)
        if self.params:
            variant = functools.partial(decomposition, **self.params)
            variant.__name__ = self.name  # type: ignore[attr-defined]
            return variant
        return decomposition

    def __repr__(self) -> str:
        decomp_name = self.decomposition.__name__ if self.decomposition else "default"
        if self.params:
            params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
            return f"CustomOpConfig({decomp_name}, {params_str})"
        return f"CustomOpConfig({decomp_name})"
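

# Illustrative sketch (comments only, not executed): how a CustomOpConfig
# resolves to a concrete variant. `my_impl` is a hypothetical decomposition,
# not part of this module.
#
#     def my_impl(x, y, scale=1.0):
#         return (x + y) * scale
#
#     cfg = CustomOpConfig(my_impl, scale=2.0)
#     variant = cfg.create_variant()  # functools.partial(my_impl, scale=2.0)
#     variant.__name__                # "my_impl_scale_2.0"
#     CustomOpConfig(scale=2.0).name  # "default_scale_2.0" (op's default impl)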


__all__ = [
    "autotune_custom_op",
    "register_custom_op_autotuning",
    "CustomOpConfig",
]


def _extract_tensor_inputs(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
    """Extract tensor inputs from mixed args/kwargs.

    Separates tensors (for autotuning input_nodes) from non-tensor parameters.
    Non-tensor kwargs are later functools.partial'd into the decomposition functions.

    Args:
        args: Positional arguments (mix of tensors and scalars)
        kwargs: Keyword arguments (mix of tensors and scalars)

    Returns:
        Tuple of (tensor_inputs_list, non_tensor_kwargs)
    """
    tensor_inputs = []
    non_tensor_kwargs = {}

    # Process args and kwargs: separate tensor inputs from non-tensor args
    for i, arg in enumerate(args):
        if isinstance(arg, (TensorBox, Buffer)):
            tensor_inputs.append(arg)
        else:
            # Add non-tensor positional args to kwargs with generated names
            non_tensor_kwargs[f"arg_{i}"] = arg

    for key, value in kwargs.items():
        if isinstance(value, (TensorBox, Buffer)):
            tensor_inputs.append(value)
        else:
            non_tensor_kwargs[key] = value

    return tensor_inputs, non_tensor_kwargs
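

# Illustrative sketch (comments only): given args = (tbox_a, 0.5) and
# kwargs = {"weight": tbox_w, "method": "tiled"}, where tbox_* are TensorBox IR
# nodes, _extract_tensor_inputs returns
#     ([tbox_a, tbox_w], {"arg_1": 0.5, "method": "tiled"})
# i.e. non-tensor positional args are keyed as "arg_<position>".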


def _merge_config_and_runtime_kwargs(
    config_params: dict[str, Any],
    runtime_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Merge config parameters with runtime kwargs. Runtime kwargs take precedence.
    If there are conflicts, log a warning and use the runtime value.

    Args:
        config_params: Parameters from CustomOpConfig
        runtime_kwargs: Runtime non-tensor kwargs from _extract_tensor_inputs

    Returns:
        Merged kwargs dictionary with runtime values taking precedence
    """
    merged_kwargs = config_params.copy()

    # Check for conflicts and let runtime kwargs dominate
    conflicts = OrderedSet(config_params.keys()).intersection(runtime_kwargs.keys())

    for key in conflicts:
        log.warning(
            "Parameter '%s' specified both in CustomOpConfig (%s) "
            "and at runtime (%s). Using runtime value.",
            key,
            config_params[key],
            runtime_kwargs[key],
        )

    # Runtime kwargs override config params
    merged_kwargs.update(runtime_kwargs)

    return merged_kwargs
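

# Illustrative sketch (comments only): with config_params = {"method": "chunked",
# "block": 64} and runtime_kwargs = {"method": "tiled"}, the merge logs a warning
# for "method" and returns {"method": "tiled", "block": 64} -- runtime wins.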


def _adapt_user_input_gen_fns(
    inputs: list[Any],
    arg_names: list[str],
    user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
) -> dict[int, Callable[[Any], torch.Tensor]]:
    """Convert user input generators from name-based to index-based format.
    Inductor autotuning's input_gen_fns expects indices into arg_names as keys.

    Uses V.graph.sizevars.size_hints() to pick concrete sizes for dynamic shapes.
    """

    name_to_index = {name: i for i, name in enumerate(arg_names)}
    index_based_fns = {}

    for name, gen_fn in user_input_gen_fns.items():
        if name in name_to_index:
            index_based_fns[name_to_index[name]] = gen_fn
        else:
            log.warning(
                "Unknown argument name '%s' in input_gen_fns. "
                "Available argument names: %s",
                name,
                list(name_to_index.keys()),
            )

    def create_internal_input_gen_fn(
        user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
    ) -> Callable[[Any], torch.Tensor]:
        """Create an internal input generator that converts an IR buffer to the user's fake tensor."""

        def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
            raw_shape = ir_buffer.get_size()
            concrete_shape = V.graph.sizevars.size_hints(
                raw_shape, fallback=config.unbacked_symint_fallback
            )

            fake_tensor = torch.empty(
                concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
            )
            return user_function(fake_tensor)

        return internal_input_gen_fn

    return {
        i: create_internal_input_gen_fn(
            user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
        )
        for i, user_gen_fn in index_based_fns.items()
        if i < len(inputs)
    }
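

# Illustrative sketch (comments only): for a decomposition def attn(query, key):
# ..., arg_names is ["query", "key"], so {"key": fn} becomes {1: wrapped_fn},
# where wrapped_fn receives the IR buffer for input 1, materializes a meta
# tensor with concrete size hints, and hands it to the user's generator `fn`.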


def _create_fallback_choice(
    name: str,
    default_impl: Callable[..., Any],
    fake_output: torch.Tensor,
    kwargs: dict[str, Any],
) -> ExternKernelChoice:
    """Create fallback choice for default implementation."""

    def fallback_wrapper(*args: Any) -> Any:
        return default_impl(*args, **kwargs)

    return ExternKernelChoice(
        kernel=fallback_wrapper,
        name=f"{name}_fallback_default",
        has_out_variant=False,
        op_overload=default_impl,
        use_fallback_kernel=True,
    )


def autotune_custom_op(
    name: str,
    decompositions: list[Callable[..., Any]],
    inputs: list[Any],
    non_tensor_args: list[dict[str, Any]],
    default_impl: Optional[Callable[..., Any]] = None,
    user_input_gen_fns: Optional[
        dict[str, Callable[[torch.Tensor], torch.Tensor]]
    ] = None,
    enable_epilogue_fusion: bool = False,
    enable_prologue_fusion: bool = False,
    disable_fallback: bool = False,
) -> Union[TensorBox, Any]:
    """Autotune custom operations by comparing multiple decomposition implementations.

    Currently supports SINGLE OUTPUT custom ops only.
    TODO: Add support for multiple output custom ops (tuple/list returns).

    This function generates multiple implementation choices for a custom operation and
    uses Inductor's autotuning system to select the best performing variant at runtime.
    After selecting the best choice, optionally applies inline epilogue fusion.

    Args:
        name: Unique identifier for the autotuning operation
        decompositions: List of alternative implementation functions to benchmark
        inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects)
        non_tensor_args: List of kwargs dicts, each paired with the corresponding
            decomposition
        default_impl: Original custom op implementation used as fallback
        user_input_gen_fns: Optional custom input generators for benchmarking.
                           Maps argument names to functions that take fake tensors
                           and return real tensors for performance measurement.
        enable_epilogue_fusion: If True, apply inline epilogue fusion to the best choice
        enable_prologue_fusion: If True, apply prologue fusion to the best choice
        disable_fallback: If True, do not add the default implementation as a
            fallback choice

    Returns:
        IR node representing the optimized operation result

    Raises:
        TypeError: If decompositions is not a list/tuple
        RuntimeError: If no inputs or no valid choices generated
    """
    if not isinstance(decompositions, (list, tuple)):
        raise TypeError(
            f"decompositions must be a list or tuple of callables, got {type(decompositions)}"
        )

    if not inputs:
        raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning")

    if len(decompositions) != len(non_tensor_args):
        raise ValueError(
            f"decompositions and non_tensor_args must have same length, "
            f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
        )

    template = SubgraphTemplate(name=name)
    choices = template.generate_custom_op_choices(
        name=name,
        decompositions=decompositions,
        input_nodes=list(inputs),
        non_tensor_args=non_tensor_args,
    )

    # Add the default implementation as a fallback (unless disabled)
    if default_impl and hasattr(default_impl, "_op") and not disable_fallback:
        fallback_name = f"{name}_fallback_default"
        from torch._inductor.select_algorithm import extern_kernels

        # Skip if the extern kernel is already registered to avoid a duplicate
        # registration error
        if not hasattr(extern_kernels, fallback_name):
            with V.fake_mode:
                fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
                fallback_kwargs = non_tensor_args[0] if non_tensor_args else {}
                fake_output = default_impl(*fake_inputs, **fallback_kwargs)

            fallback_choice = _create_fallback_choice(
                name, default_impl, fake_output, fallback_kwargs
            )
            fallback_choice.maybe_append_choice(
                choices=choices,
                input_nodes=list(inputs),
                layout=FixedLayout(
                    device=fake_output.device,
                    dtype=fake_output.dtype,
                    size=fake_output.shape,
                    stride=fake_output.stride(),
                ),
            )

    if not choices:
        raise RuntimeError(f"No valid choices generated for {name}")

    # Convert user input generation functions to the internal index-based format
    input_gen_fns = {}
    if user_input_gen_fns:
        import inspect

        arg_names = (
            list(inspect.signature(decompositions[0]).parameters.keys())
            if decompositions
            else []
        )
        input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns)

    # Run autotuning to select the best choice
    selected_result = autotune_select_algorithm(
        name=name,
        choices=choices,
        input_nodes=list(inputs),
        layout=choices[0].layout,
        input_gen_fns=input_gen_fns,
    )

    # Apply inlining if epilogue fusion is enabled
    if enable_epilogue_fusion and isinstance(selected_result, TensorBox):
        winning_choice = choices[0]  # TODO: use the selected choice instead of choices[0]
        inlined_result = _inline_custom_op_choice(winning_choice, inputs, name)
        return inlined_result

    return selected_result
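

# Illustrative sketch (comments only): the pairing contract between
# decompositions and non_tensor_args. `rmsnorm_loop`, `rmsnorm_fused`, and the
# "mylib" op are hypothetical names used only for illustration.
#
#     autotune_custom_op(
#         name="rmsnorm_autotuned",
#         decompositions=[rmsnorm_loop, rmsnorm_fused],
#         inputs=[x_node, weight_node],                    # TensorBox IR nodes
#         non_tensor_args=[{"eps": 1e-6}, {"eps": 1e-6}],  # one dict per decomposition
#         default_impl=torch.ops.mylib.rmsnorm.default,
#     )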


def _inline_custom_op_choice(
    winning_choice: Any, inputs: list[Any], name: str
) -> TensorBox:
    """Inline the winning custom op choice by converting its FX operations to individual IR nodes.

    This converts the custom op from a single ExternKernel (unfusable) to multiple ComputedBuffer
    nodes (fusable), enabling epilogue fusion with subsequent operations.

    Args:
        winning_choice: The winning SubgraphChoiceCaller from autotuning
        inputs: Original input nodes
        name: Custom op name for debugging

    Returns:
        TensorBox containing the final operation result as individual IR nodes
    """
    # Get the GraphModule containing the operations
    gm = winning_choice.gm

    # Create a mapping from placeholder nodes to actual inputs
    node_to_value = {}
    placeholder_idx = 0

    # Process each node in the winning choice's graph
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            # Map placeholder to actual input
            if placeholder_idx < len(inputs):
                node_to_value[node] = inputs[placeholder_idx]
                placeholder_idx += 1
            else:
                raise RuntimeError(f"Not enough inputs for placeholder {node.name}")

        elif node.op == "call_function":
            # Convert the FX operation to IR nodes using existing lowerings
            target = node.target
            args = [
                node_to_value[arg] if arg in node_to_value else arg for arg in node.args
            ]
            kwargs = {
                k: node_to_value[v] if v in node_to_value else v
                for k, v in node.kwargs.items()
            }

            # Call the appropriate lowering function
            if target in lowerings:
                result = lowerings[target](*args, **kwargs)
                node_to_value[node] = result
            else:
                # Fallback: try calling the target directly
                result = target(*args, **kwargs)
                node_to_value[node] = result

        elif node.op == "output":
            # Return the final result
            output_arg = node.args[0]
            if isinstance(output_arg, (list, tuple)):
                # Multi-output case (not yet supported)
                raise RuntimeError(
                    "Multi-output custom ops not yet supported for inlining"
                )
            else:
                # Single-output case
                final_result = node_to_value[output_arg]
                return final_result

        else:
            raise RuntimeError(f"Unsupported node type: {node.op}")

    raise RuntimeError("No output node found in custom op graph")


def register_custom_op_autotuning(
    custom_op: torch._ops.OpOverload,
    configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
    name: Optional[str] = None,
    input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
    enable_epilogue_fusion: bool = False,
    enable_prologue_fusion: bool = False,
    disable_fallback: bool = False,
) -> None:
    """Register a custom op for autotuning with configs, where each config
    specifies a decomposition implementation function with its parameter values.

    Args:
        custom_op: Custom operation to register
        configs: List of CustomOpConfig objects or callable functions
        name: Operation name (default: "{op_name}_autotuned")
        input_gen_fns: Custom input generators for benchmarking
        enable_epilogue_fusion: If True, inline the best choice for epilogue fusion
        enable_prologue_fusion: If True, apply prologue fusion to the best choice
        disable_fallback: If True, do not add the default implementation as a
            fallback choice

    Examples:
        register_custom_op_autotuning(
            torch.ops.mylib.attention.default,
            configs=[
                CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
                CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
                CustomOpConfig(fallback_impl),  # No params
            ],
            input_gen_fns={
                "query": lambda fake: torch.randn_like(fake, device='cuda'),
                "key": lambda fake: torch.randn_like(fake, device='cuda'),
                "value": lambda fake: torch.randn_like(fake, device='cuda'),
            }
        )
    """
    if not isinstance(configs, (list, tuple)):
        raise TypeError(f"configs must be a list or tuple, got {type(configs)}")

    processed_configs = []
    for cfg in configs:
        if isinstance(cfg, CustomOpConfig):
            processed_configs.append(cfg)
        elif callable(cfg):
            # Bare callables are wrapped in a parameterless CustomOpConfig
            processed_configs.append(CustomOpConfig(cfg))
        else:
            raise TypeError(
                f"Each config must be a CustomOpConfig object or callable, got {type(cfg)}"
            )

    if not processed_configs:
        raise ValueError("At least one config must be provided")

    if name is None:
        name = f"{custom_op._name}_autotuned"

    @functools.wraps(custom_op)
    def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
        """Inductor lowering function that replaces custom op calls with autotuned versions."""
        # Extract tensor inputs and non-tensor parameters (runtime kwargs)
        tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)

        # Prepare decompositions and kwargs by merging custom op config params
        # with runtime kwargs
        decompositions = []
        non_tensor_args = []

        for cfg in processed_configs:
            decomp = cfg.get_decomposition(default_impl=custom_op)
            decompositions.append(decomp)

            # Merge config params with runtime kwargs (runtime takes precedence)
            merged_kwargs = _merge_config_and_runtime_kwargs(
                cfg.params, runtime_kwargs
            )
            non_tensor_args.append(merged_kwargs)

        result = autotune_custom_op(
            name=name,
            decompositions=decompositions,
            inputs=tensor_inputs,
            non_tensor_args=non_tensor_args,
            default_impl=custom_op,
            user_input_gen_fns=input_gen_fns,
            enable_epilogue_fusion=enable_epilogue_fusion,
            enable_prologue_fusion=enable_prologue_fusion,
            disable_fallback=disable_fallback,
        )

        validate_ir(result)
        return result

    lowerings[custom_op] = autotuning_lowering
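

# End-to-end sketch (comments only, not executed): registering autotuning for a
# hypothetical op defined with torch.library.custom_op. The "mylib" namespace
# and all function names below are assumptions for illustration.
#
#     @torch.library.custom_op("mylib::scaled_add", mutates_args=())
#     def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         return x + 2.0 * y
#
#     @scaled_add.register_fake
#     def _(x, y):
#         return torch.empty_like(x)
#
#     def scaled_add_alpha(x, y, scale=2.0):
#         return torch.add(x, y, alpha=scale)
#
#     register_custom_op_autotuning(
#         torch.ops.mylib.scaled_add.default,
#         configs=[
#             CustomOpConfig(),                 # the op's default implementation
#             CustomOpConfig(scaled_add_alpha), # alternative decomposition
#         ],
#         input_gen_fns={"x": lambda fake: torch.randn_like(fake, device="cuda")},
#     )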