Compare commits

...

23 Commits

Author SHA1 Message Date
c992543eca param suffix and test refine 2025-10-20 16:04:12 -07:00
1868744409 update API morepscificiely and use config 2025-10-20 15:36:34 -07:00
06890019d2 resolve comments:fix input_gen_fns API; tensor shape infer with siz_hints instead of ir_node_to_tensor; seperated non_tensor_args for Functools.partial for kwargs ; added Single output - layout assertion; lint 2025-10-16 00:48:02 -07:00
8daee127db simplify API and improve code readability 2025-10-13 22:15:50 -07:00
58590fb37f fix cpu tests and lint 2025-10-12 23:52:14 -07:00
69688d49a9 simplify logic 2025-10-12 17:30:52 -07:00
584bd31a10 support custom op tuning for parameters; passing tuning knob as an arg 2025-10-12 17:05:05 -07:00
1c77a09da5 resolve input_gen_fn call and register faketensor input 2025-10-11 22:57:08 -07:00
60cd4b5730 remove input_gen_fns from user's interface; moving to inductor internal 2025-10-08 22:39:09 -07:00
fd6938766a revert changes and clean up code 2025-10-06 21:58:25 -07:00
89283b4fb9 Tianren/custom op autotune fix (#164689)
* remove redundant catch and callable list; refine decorator to avoid using register_lowering; fix document

* wip fallback to default

* fix lint

* refine test case

* include default implemenation to the choices

* ensure test passed for default implementation correctly; fix lint

* clean up test

* refine test structure

* refine test, fix lint, remove new tempalte

* clean up code and refine

* clean up code

* simply code and lint
2025-10-06 21:50:06 -07:00
13bedfdfd3 fix lint issue and deal with device properly 2025-10-02 10:46:55 -07:00
ea6d1ff025 add fallback and seperate cpu and gpu backend 2025-10-02 00:33:26 -07:00
579ff95850 fix lint issues 2025-09-30 23:47:34 -07:00
6efa559a0e fix lint 2025-09-30 22:24:19 -07:00
aae722c5a8 linter fix 2025-09-30 17:27:12 -07:00
ffc17077c9 lint 2025-09-30 15:55:08 -07:00
75bf74d926 refine tests and update customop 2025-09-30 15:55:08 -07:00
ce751dcb45 refine test for rmsnorm variants and custom op, clean up code, test passed for 3 decompositions 2025-09-30 15:55:08 -07:00
46759ac0d2 update test 2025-09-30 15:55:08 -07:00
7206224dc8 add test 2025-09-30 15:55:08 -07:00
807e35f76c add tests 2025-09-30 15:55:08 -07:00
9303113015 initial tests and modified custom op kernel template 2025-09-30 15:55:08 -07:00
4 changed files with 1127 additions and 3 deletions

View File

@@ -0,0 +1,572 @@
# Owner(s): ["module: inductor"]
"""
Tests for custom operation autotuning with PyTorch Inductor.
Users can register custom ops with multiple decomposition implementations and let
Inductor automatically select the best performing variant. Key features tested:
- Name-based input generators (use argument names instead of indices)
- Dynamic shape handling across multiple compilations
- Parametric tuning via CustomOpConfig parameters (tuning knobs) for combinatorial exploration
- Numerical correctness and performance validation
"""
import torch
from torch._inductor import config
from torch._inductor.kernel.custom_op import (
CustomOpConfig,
register_custom_op_autotuning,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU
torch.set_float32_matmul_precision("high")
class TestCustomOpAutoTune(TestCase):
"""Test custom operation autotuning functionality."""
def setUp(self) -> None:
"""Set up test environment with appropriate device and dtype."""
super().setUp()
self.device = "cuda" if HAS_GPU else "cpu"
self.dtype = torch.float16 if self.device == "cuda" else torch.float32
def _create_test_configs(self):
"""Create common test configurations for different sizes."""
return [
{"batch_size": 1, "seq_len": 32, "hidden_dim": 128},
{"batch_size": 2, "seq_len": 64, "hidden_dim": 256},
]
def _run_autotune_test(self, op_object, inputs, expected, test_name):
"""Shared test infrastructure for autotuning tests."""
@torch.compile
def test_model(*args):
return op_object(*args)
torch._dynamo.reset()
autotune_backends = "TRITON" if self.device == "cuda" else "ATEN"
with config.patch(
max_autotune=True,
max_autotune_gemm_backends=autotune_backends,
fx_graph_cache=False,
benchmark_kernel=True,
):
compiled_result = test_model(*inputs)
self.assertEqual(
compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
)
torch.testing.assert_close(
compiled_result,
expected,
rtol=2e-1,
atol=5e-1,
msg=f"{test_name} numerical mismatch",
)
def _assert_implementations_equivalent(self, decompositions, inputs, op_name):
"""Utility to assert that all implementations produce equivalent results."""
implementations = [(func.__name__, func) for func in decompositions]
results = {}
for name, impl in implementations:
result = impl(*inputs)
results[name] = result
# Basic sanity checks
self.assertTrue(
torch.isfinite(result).all(),
f"{op_name} {name} produced non-finite values",
)
# Verify numerical equivalence
reference_name, reference_result = next(iter(results.items()))
for name, result in results.items():
if name != reference_name:
rtol = 1e-1 if "Approximated" in name else 1e-2
atol = 1e-1 if "Approximated" in name else 1e-2
torch.testing.assert_close(
result,
reference_result,
rtol=rtol,
atol=atol,
msg=f"{op_name} {name} differs from {reference_name}",
)
def _create_rmsnorm_inputs(self, batch_size=8, seq_len=1024, hidden_dim=512):
"""Create test inputs for RMSNorm operations."""
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
weight = torch.randn(
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
)
return input_tensor, weight
def _create_mlp_inputs(
self,
batch_size=2,
seq_len=32,
hidden_dim=512,
intermediate_dim=1024,
output_dim=256,
):
"""Create test inputs for MLP operations."""
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
gate_weight = torch.randn(
hidden_dim,
intermediate_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
up_weight = torch.randn(
hidden_dim,
intermediate_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
down_weight = torch.randn(
intermediate_dim,
output_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
return input_tensor, gate_weight, up_weight, down_weight
@skipIfXpu
def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
"""Test RMSNorm autotuning decomposition variants compared to fallback default with dynamic shapes."""
test_op_name = f"test_lib::rmsnorm_{id(self)}"
def rmsnorm_decomposition1(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""Variance-based approach: compute variance then rsqrt."""
variance = x.pow(2).mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(variance + eps)
return x * rstd * weight
def rmsnorm_decomposition2(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""vLLM-style RMSNorm implementation - variance computation first approach."""
x_var = x # In vLLM, this could be sliced for variance_size_override
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
if weight is not None:
x = x * weight
return x
def rmsnorm_decomposition3(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""vLLM-style RMSNorm with extended variance computation pattern."""
x_squared = x.pow(2)
variance = x_squared.mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(variance + eps)
normalized = x * rstd
# Apply weight scaling
if weight is not None:
normalized = normalized * weight
return normalized
@torch.library.custom_op(test_op_name, mutates_args=())
def test_rmsnorm_op(
input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
return torch.nn.functional.rms_norm(
input_tensor, input_tensor.shape[-1:], weight, eps=eps
)
@test_rmsnorm_op.register_fake
def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
return torch.empty_like(input_tensor)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
decompositions = [
rmsnorm_decomposition1,
rmsnorm_decomposition2,
rmsnorm_decomposition3,
]
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(decomp) for decomp in decompositions
],
name="test_rmsnorm_autotuned",
input_gen_fns={
"x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
"weight": lambda weight: torch.ones_like(weight, device=self.device),
},
)
# Test multiple shapes to verify dynamic shape handling
test_shapes = [(2, 16, 128), (8, 32, 256)]
for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
weight = torch.randn(
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
)
# Test numerical equivalence for all decompositions
self._assert_implementations_equivalent(
decompositions, (input_tensor, weight), f"RMSNorm_{i}"
)
# Test autotuning
expected = rmsnorm_decomposition1(input_tensor, weight)
self._run_autotune_test(
op_object, (input_tensor, weight), expected, f"RMSNorm_{i}"
)
@skipIfXpu
def test_mlp_custom_op_autotune(self):
"""Test MLP autotuning with method parameter controlling different decomposition variants"""
test_op_name = f"test_lib::mlp_{id(self)}"
def mlp_variants(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
"""MLP implementation with different computational approaches controlled by method parameter."""
if method == 0:
# Separate matmuls: standard implementation with torch.matmul
gate_proj = torch.matmul(input_tensor, gate_weight)
up_proj = torch.matmul(input_tensor, up_weight)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)
elif method == 1:
# Batched approach: uses torch.mm with reshaped tensors
batch_shape = input_tensor.shape[:-1]
hidden_dim = input_tensor.shape[-1]
output_dim = down_weight.shape[-1]
input_2d = input_tensor.view(-1, hidden_dim)
gate_proj = torch.mm(input_2d, gate_weight)
up_proj = torch.mm(input_2d, up_weight)
gated = torch.relu(gate_proj) * up_proj
output_2d = torch.mm(gated, down_weight)
return output_2d.view(*batch_shape, output_dim)
elif method == 2:
# Fused weights approach: concatenate then split weights
# Concatenate gate and up weights for one matrix multiply
fused_weight = torch.cat([gate_weight, up_weight], dim=1)
fused_proj = torch.matmul(input_tensor, fused_weight)
intermediate_dim = gate_weight.shape[1]
gate_proj, up_proj = fused_proj.split(
[intermediate_dim, intermediate_dim], dim=-1
)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)
@torch.library.custom_op(test_op_name, mutates_args=())
def test_mlp_op(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
) -> torch.Tensor:
return mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=0
)
@test_mlp_op.register_fake
def _(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
):
return torch.empty(
input_tensor.shape[:-1] + (down_weight.shape[-1],),
device=input_tensor.device,
dtype=input_tensor.dtype,
)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
# Use explicit configs with method parameter as tuning knob
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(mlp_variants, method=1), # Batched approach
CustomOpConfig(mlp_variants, method=2), # Fused weights
],
name="test_mlp_autotuned",
input_gen_fns={
"input_tensor": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"gate_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"up_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"down_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
},
)
# Create test inputs using the original helper method
input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()
# Test that all method variants produce numerically equivalent results
expected = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=0
)
for method in [1, 2]:
result = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=method
)
torch.testing.assert_close(
result,
expected,
rtol=1e-5,
atol=1e-5,
msg=f"Method {method} not equivalent to method 0",
)
# Test autotuning - all should be mathematically equivalent
self._run_autotune_test(
op_object,
(input_tensor, gate_weight, up_weight, down_weight),
expected,
"MLP",
)
def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
"""Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
# Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256
a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False)
b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False)
return a, b
@skipIfXpu
def test_decompose_k_custom_op_autotune(self):
"""Test decompose_k autotuning with parameter tuning for k_splits values."""
test_op_name = f"test_lib::decompose_k_{id(self)}"
def decompose_k_implementation(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition - parameter-tuned implementation."""
m = a.shape[0]
n = b.shape[1]
k = a.shape[1]
k_parts = k // k_splits
B = k_splits
a_reshaped = torch.permute(
a.reshape(m, B, k_parts), (1, 0, 2)
) # [B, m, k_parts]
b_reshaped = b.reshape(B, k_parts, n) # [B, k_parts, n]
result = torch.bmm(a_reshaped, b_reshaped) # [B, m, n]
return torch.sum(result, dim=0) # [m, n]
@torch.library.custom_op(test_op_name, mutates_args=())
def test_decompose_k_op(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
return decompose_k_implementation(a, b, k_splits)
@test_decompose_k_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
# Use parameter tuning to test different k_splits values
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(decompose_k_implementation, k_splits=2),
CustomOpConfig(decompose_k_implementation, k_splits=32),
CustomOpConfig(decompose_k_implementation, k_splits=64),
CustomOpConfig(decompose_k_implementation, k_splits=128),
CustomOpConfig(decompose_k_implementation, k_splits=256),
],
name="test_decompose_k_autotuned",
input_gen_fns={
"a": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1, # Matrix A
"b": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1, # Matrix B
},
)
a, b = self._create_decompose_k_inputs()
expected = a @ b
self._run_autotune_test(op_object, (a, b), expected, "DecomposeK")
@skipIfXpu
def test_multi_parameter_tuning(self):
"""Test autotuning with multiple parameters using scale_mode and chunk_size."""
op_name = f"test_lib::multi_param_{id(self)}"
def multi_param_scaling(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
) -> torch.Tensor:
"""Different scaling approaches controlled by scale_mode parameter."""
if scale_mode == 1:
# Simple broadcasting
return x * factor
elif scale_mode == 2:
# Process in chunks
batch_size, seq_len = x.shape[:2]
chunks = []
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
chunk = x[:, start:end]
chunks.append(chunk * factor)
return torch.cat(chunks, dim=1)
elif scale_mode == 3:
# Using einsum for scaling
return torch.einsum("...i,i->...i", x, factor)
@torch.library.custom_op(op_name, mutates_args=())
def multi_param_op(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
) -> torch.Tensor:
return multi_param_scaling(x, factor, scale_mode, chunk_size)
@multi_param_op.register_fake
def _(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
):
return torch.empty_like(x)
lib_name, op_suffix = op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_suffix)
# Use explicit configs with scale_mode and chunk_size parameters as tuning knobs
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(multi_param_scaling, scale_mode=1), # Broadcast
CustomOpConfig(
multi_param_scaling, scale_mode=2, chunk_size=16
), # Chunked 16
CustomOpConfig(
multi_param_scaling, scale_mode=2, chunk_size=32
), # Chunked 32
CustomOpConfig(multi_param_scaling, scale_mode=3), # Einsum
],
name="multi_param_autotuned",
input_gen_fns={
"x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
"factor": lambda t: torch.ones(
t.shape[-1], device=self.device, dtype=t.dtype
),
},
)
# Create test inputs
test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype)
test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0
# Verify numerical equivalence across all approaches
expected_result = test_x * test_factor
# Test each scale_mode variant
configs = [
(1, 16), # broadcast, chunk_size ignored
(2, 16), # chunked with size 16
(2, 32), # chunked with size 32
(3, 16), # einsum, chunk_size ignored
]
for i, (scale_mode, chunk_size) in enumerate(configs):
result = multi_param_scaling(
test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size
)
torch.testing.assert_close(
result,
expected_result,
rtol=1e-5,
atol=1e-5,
msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected",
)
# Test autotuning
self._run_autotune_test(
op_object, (test_x, test_factor), expected_result, "MultiParam"
)
if __name__ == "__main__":
run_tests()
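For orientation, the tests above all follow the same registration pattern. The sketch below condenses it into a minimal example; it assumes the register_custom_op_autotuning and CustomOpConfig APIs added in this PR, and the library name mylib::my_norm, the two norm variants, and the CUDA device are illustrative choices, not part of the PR.

import torch
from torch._inductor import config
from torch._inductor.kernel.custom_op import CustomOpConfig, register_custom_op_autotuning

def norm_v1(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # illustrative decomposition: fused normalize-and-scale
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

def norm_v2(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # illustrative decomposition: normalize first, then scale
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * rstd) * weight

@torch.library.custom_op("mylib::my_norm", mutates_args=())
def my_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    return norm_v1(x, weight, eps)

@my_norm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    return torch.empty_like(x)

# Register the competing decompositions; Inductor benchmarks them under max-autotune.
register_custom_op_autotuning(
    torch.ops.mylib.my_norm.default,
    configs=[CustomOpConfig(norm_v1), CustomOpConfig(norm_v2)],
    input_gen_fns={"x": lambda fake: torch.randn_like(fake, device="cuda") * 0.02},
)

@torch.compile
def f(x, weight):
    return torch.ops.mylib.my_norm(x, weight)

with config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON"):
    out = f(
        torch.randn(8, 1024, 512, device="cuda", dtype=torch.float16),
        torch.randn(512, device="cuda", dtype=torch.float16),
    )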

View File

@@ -1,6 +1,6 @@
import itertools
import logging
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union
import torch
import torch._inductor.config as config
@@ -8,6 +8,7 @@ from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
Buffer,
FixedLayout,
get_free_symbols,
get_symbolic_inputs,
gm_original_output_strides,
@@ -120,7 +121,12 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
bm_func([*sym_inputs, *args])
if config.profile_bandwidth_with_do_bench_using_profiling:
return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
# Use appropriate benchmarker based on layout device type
if self.layout.device.type == "cpu":
return benchmarker.benchmark_cpu(lambda: bm_func([*sym_inputs, *args]))
else:
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
def hash_key(self) -> str:
return "-".join(
@@ -207,3 +213,152 @@ class SubgraphTemplate(KernelTemplate):
description=description,
make_fx_graph=make_fx_graph,
)
def generate_custom_op_choices(
self,
name: str,
decompositions: list[Callable[..., Any]],
input_nodes: list[Buffer],
kwargs: Optional[dict[str, Any]] = None,
default_impl: Optional[Callable[..., Any]] = None,
) -> list[SubgraphChoiceCaller]:
"""
Generate multiple SubgraphChoiceCaller instances for custom op autotuning.
This method extends SubgraphTemplate to support custom op decompositions,
allowing multiple implementations to compete in autotuning.
Args:
name: Base name for the choices
decompositions: List of decomposition functions to compare
input_nodes: Input nodes for the operation
kwargs: Additional arguments for decomposition functions
default_impl: Default implementation for layout inference
Returns:
List of SubgraphChoiceCaller instances for autotuning
"""
if not decompositions:
return []
kwargs = kwargs or {}
# Infer layouts and ensure stride consistency for fair autotuning comparison
layouts = [
self._infer_custom_op_layout(input_nodes, [decomp], kwargs, default_impl)
for decomp in decompositions
]
self._validate_stride_consistency(name, decompositions, layouts)
# Assert single output layout - assumes custom ops have one output tensor
assert len(layouts) > 0, f"No layouts inferred for custom op '{name}'"
assert all(
layout.device == layouts[0].device
and layout.dtype == layouts[0].dtype
and layout.size == layouts[0].size
for layout in layouts
), f"All decompositions for '{name}' must produce equivalent output layouts"
layout = layouts[0] # All layouts have equivalent stride/shape/dtype now
choices = []
for decomp in decompositions:
# Create make_fx_graph function for this decomposition
def make_fx_graph(*args: Any, decomp: Callable[..., Any] = decomp) -> Any:
import functools
from torch.fx.experimental.proxy_tensor import make_fx
# Ensure kwargs is not None for unpacking
decomp_kwargs = kwargs if kwargs is not None else {}
return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)
choice = self.generate(
name=f"{name}_{decomp.__name__}",
input_nodes=input_nodes,
layout=layout,
make_fx_graph=make_fx_graph,
description=f"CustomOp {decomp.__name__}",
)
choices.append(choice)
return choices
def _validate_stride_consistency(
self,
op_name: str,
decompositions: list[Callable[..., Any]],
layouts: list[Layout],
) -> None:
"""Ensure all decompositions produce compatible strides for fair autotuning."""
if not layouts:
return
strides = [layout.stride for layout in layouts]
reference = strides[0]
for i, stride in enumerate(strides[1:], start=1):
if stride != reference:
raise AssertionError(
f"Stride mismatch in custom op '{op_name}' autotuning: "
f"'{decompositions[i].__name__}' produces stride {stride}, "
f"but '{decompositions[0].__name__}' produces {reference}. "
f"All decompositions must have identical output strides."
)
def _infer_custom_op_layout(
self,
input_nodes: list[Buffer],
decompositions: list[Callable[..., Any]],
kwargs: dict[str, Any],
default_impl: Optional[Callable[..., Any]] = None,
) -> Layout:
"""Infer output layout for custom ops using the default implementation when available.
Note: the Subgraph currently assumes custom ops return exactly one tensor.
TODO: Add support for multiple output custom ops.
"""
import functools
from torch._inductor.virtualized import V
# Assert kwargs contain only non-tensor arguments for functools.partial
for key, value in kwargs.items():
assert not isinstance(value, (torch.Tensor, Buffer)), (
f"kwargs['{key}'] contains tensor {type(value)}. "
f"Tensor arguments should be in input_nodes, not kwargs. "
f"Only scalar/non-tensor parameters should be in kwargs."
)
# Use default_impl if available, otherwise use first decomposition
impl = default_impl if default_impl is not None else decompositions[0]
with V.fake_mode:
example_inputs = []
for inp in input_nodes:
raw_shape = inp.get_size()
concrete_shape = V.graph.sizevars.size_hints(
raw_shape, fallback=config.unbacked_symint_fallback
)
fake_tensor = torch.empty(
concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
)
example_inputs.append(fake_tensor)
fn = functools.partial(
impl, **kwargs
) # kwargs must be non-tensor for partial
output = fn(*example_inputs)
# Assert single output
assert isinstance(output, torch.Tensor), (
f"Expected single tensor output, got {type(output)}. "
f"Multi-output custom ops not yet supported in autotuning."
)
return FixedLayout(
device=output.device,
dtype=output.dtype,
size=output.shape,
stride=output.stride(),
)
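In essence, _infer_custom_op_layout runs one implementation on shape-only tensors and records the resulting shape, dtype, and stride in a FixedLayout. Below is a simplified, standalone analogue of that inference, using the meta device instead of Inductor's fake mode and size_hints, with a made-up rmsnorm decomposition standing in for a registered one.

import torch

def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # stand-in for one of the registered decompositions
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * weight

# Shape-only inputs: the meta device allocates no storage and runs no kernels.
x = torch.empty(8, 1024, 512, device="meta", dtype=torch.float16)
w = torch.empty(512, device="meta", dtype=torch.float16)
out = rmsnorm(x, w)

# These attributes are exactly what FixedLayout(device, dtype, size, stride) captures.
print(out.shape, out.dtype, out.stride())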

View File

@@ -0,0 +1,397 @@
# Owner(s): ["module: inductor"]
import functools
from typing import Any, Callable, Optional, Union
import torch
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
from torch._inductor.lowering import lowerings, validate_ir
from torch._inductor.select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
)
from torch._inductor.virtualized import V
class CustomOpConfig:
"""Config for custom op autotuning - similar to triton.Config.
Specifies decomposition function with parameter values.
Each config creates exactly one variant (no Cartesian product).
Args:
decomposition: Function to autotune
**params: Parameters passed to the function
Examples:
CustomOpConfig(attention_impl, head_dim=32, method='chunked')
CustomOpConfig(fallback_impl)
"""
def __init__(self, decomposition: Callable[..., Any], **params: Any):
if not callable(decomposition):
raise TypeError(
f"decomposition must be callable, got {type(decomposition)}"
)
self.decomposition = decomposition
self.params = params
# Generate descriptive name
if self.params:
param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(self.params.items()))
self.name = f"{decomposition.__name__}_{param_suffix}"
else:
self.name = decomposition.__name__
def create_variant(self) -> Callable[..., Any]:
"""Create callable with parameters pre-applied using functools.partial."""
if self.params:
variant = functools.partial(self.decomposition, **self.params)
variant.__name__ = self.name # type: ignore[attr-defined]
return variant
return self.decomposition
def __repr__(self) -> str:
if self.params:
params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
return f"CustomOpConfig({self.decomposition.__name__}, {params_str})"
return f"CustomOpConfig({self.decomposition.__name__})"
__all__ = [
"autotune_custom_op",
"register_custom_op_autotuning",
"CustomOpConfig",
]
def _extract_tensor_inputs(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Extract tensor inputs from mixed args/kwargs.
Separates tensors (for autotuning input_nodes) from non-tensor parameters.
Non-tensor kwargs are later functools.partial'd into decomposition functions.
Args:
args: Positional arguments (mix of tensors and scalars)
kwargs: Keyword arguments (mix of tensors and scalars)
Returns:
Tuple of (tensor_inputs_list, non_tensor_kwargs)
"""
tensor_inputs = []
non_tensor_kwargs = {}
# Process args and kwargs: separate tensor inputs from non-tensor args
for i, arg in enumerate(args):
if isinstance(arg, (TensorBox, Buffer)):
tensor_inputs.append(arg)
else:
# Add non-tensor positional args to kwargs with generated names
non_tensor_kwargs[f"arg_{i}"] = arg
for key, value in kwargs.items():
if isinstance(value, (TensorBox, Buffer)):
tensor_inputs.append(value)
else:
non_tensor_kwargs[key] = value
return tensor_inputs, non_tensor_kwargs
def _create_user_input_gen_fns(
inputs: list[Any],
arg_names: list[str],
user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
) -> dict[int, Callable[[Any], torch.Tensor]]:
"""Convert user input generators from name-based to index-based format.
Inductor's autotuning input_gen_fns expects the positional index of each argument (its position in arg_names) as the key.
Uses V.graph.sizevars.size_hints() to pick concrete sizes for dynamic shapes.
"""
from torch._inductor import config
name_to_index = {name: i for i, name in enumerate(arg_names)}
index_based_fns = {}
for name, gen_fn in user_input_gen_fns.items():
if name in name_to_index:
index_based_fns[name_to_index[name]] = gen_fn
else:
print(f"Warning: Unknown argument name '{name}' in input_gen_fns")
def create_internal_input_gen_fn(
user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
) -> Callable[[Any], torch.Tensor]:
"""Create internal input generator that converts IR buffer to user's fake tensor."""
def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
raw_shape = ir_buffer.get_size()
concrete_shape = V.graph.sizevars.size_hints(
raw_shape, fallback=config.unbacked_symint_fallback
)
fake_tensor = torch.empty(
concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
)
return user_function(fake_tensor)
return internal_input_gen_fn
return {
i: create_internal_input_gen_fn(
user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
)
for i, user_gen_fn in index_based_fns.items()
if i < len(inputs)
}
def _create_fallback_choice(
name: str,
default_impl: Callable[..., Any],
fake_output: torch.Tensor,
kwargs: dict[str, Any],
) -> ExternKernelChoice:
"""Create fallback choice for default implementation."""
def fallback_wrapper(*args: Any) -> Any:
return default_impl(*args, **kwargs)
return ExternKernelChoice(
kernel=fallback_wrapper,
name=f"{name}_fallback_default",
has_out_variant=False,
op_overload=default_impl,
use_fallback_kernel=True,
)
def _create_parameter_variants(
decompositions: list[Callable[..., Any]],
tuning_knob: dict[str, list[Any]],
) -> list[Any]: # Returns partial objects which are callable
"""Create parameter variants for decompositions using tuning knob.
Args:
decompositions: Base implementation functions
tuning_knob: Parameter tuning dict with parameter names and value lists
Returns:
List of variant functions with all parameter combinations
"""
# Validate parameter values
for param_name, param_values in tuning_knob.items():
if not param_values or not isinstance(param_values, (list, tuple)):
raise TypeError(
f"Parameter values for '{param_name}' must be a list or tuple, got {type(param_values)}"
)
# Generate all combinations of parameter values using Cartesian product
import itertools
param_names = list(tuning_knob.keys())
param_values_lists = list(tuning_knob.values())
param_combinations = list(itertools.product(*param_values_lists))
# Create variants for each decomposition with each parameter combination
variants = []
for decomp_fn in decompositions:
for param_combo in param_combinations:
# Create kwargs dict for this combination
param_kwargs = dict(zip(param_names, param_combo))
# Create partial function with all parameters
variant = functools.partial(decomp_fn, **param_kwargs)
param_suffix = "_".join(
f"{name}_{value}" for name, value in param_kwargs.items()
)
variant.__name__ = f"{decomp_fn.__name__}_{param_suffix}" # type: ignore[attr-defined]
variants.append(variant)
return variants
def autotune_custom_op(
name: str,
decompositions: list[Callable[..., Any]],
inputs: list[Any],
kwargs: Optional[dict[str, Any]] = None,
default_impl: Optional[Callable[..., Any]] = None,
user_input_gen_fns: Optional[
dict[str, Callable[[torch.Tensor], torch.Tensor]]
] = None,
) -> Union[TensorBox, Any]:
"""Autotune custom operations by comparing multiple decomposition implementations.
Currently supports SINGLE OUTPUT custom ops only.
TODO: Add support for multiple output custom ops (tuple/list returns).
This function generates multiple implementation choices for a custom operation and
uses Inductor's autotuning system to select the best performing variant at runtime.
Args:
name: Unique identifier for the autotuning operation
decompositions: List of alternative implementation functions to benchmark
inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects)
kwargs: Non-tensor parameters to pass to decomposition functions
default_impl: Original custom op implementation used as fallback
user_input_gen_fns: Optional custom input generators for benchmarking.
Maps argument names to functions that take fake tensors
and return real tensors for performance measurement.
Returns:
IR node representing the optimized operation result
Raises:
TypeError: If decompositions is not a list/tuple
RuntimeError: If no inputs or no valid choices generated
"""
if kwargs is None:
kwargs = {}
if not isinstance(decompositions, (list, tuple)):
raise TypeError(
f"decompositions must be a list or tuple of callables, got {type(decompositions)}"
)
if not inputs:
raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning")
template = SubgraphTemplate(name=name)
choices = template.generate_custom_op_choices(
name=name,
decompositions=list(decompositions),
input_nodes=list(inputs),
kwargs=kwargs,
)
# Add default implementation as fallback
if default_impl and hasattr(default_impl, "_op"):
fallback_name = f"{name}_fallback_default"
from torch._inductor.select_algorithm import extern_kernels
# Skip if extern_kernel already registered to avoid duplicate registration error
if not hasattr(extern_kernels, fallback_name):
with V.fake_mode:
fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
fake_output = default_impl(*fake_inputs, **kwargs)
fallback_choice = _create_fallback_choice(
name, default_impl, fake_output, kwargs
)
fallback_choice.maybe_append_choice(
choices=choices,
input_nodes=list(inputs),
layout=FixedLayout(
device=fake_output.device,
dtype=fake_output.dtype,
size=fake_output.shape,
stride=fake_output.stride(),
),
)
if not choices:
raise RuntimeError(f"No valid choices generated for {name}")
# Convert user input generation functions to internal format
input_gen_fns = {}
if user_input_gen_fns:
import inspect
arg_names = (
list(inspect.signature(decompositions[0]).parameters.keys())
if decompositions
else []
)
input_gen_fns = _create_user_input_gen_fns(
inputs, arg_names, user_input_gen_fns
)
return autotune_select_algorithm(
name=name,
choices=choices,
input_nodes=list(inputs),
layout=choices[0].layout,
input_gen_fns=input_gen_fns,
)
def register_custom_op_autotuning(
custom_op: torch._ops.OpOverload,
configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
name: Optional[str] = None,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
) -> None:
"""Register custom op for autotuning with explicit configs.
Uses config-based API where each config specifies a decomposition function
with its parameter values.
Args:
custom_op: Custom operation to register
configs: List of CustomOpConfig objects or callable functions
name: Operation name (default: "{op_name}_autotuned")
input_gen_fns: Custom input generators for benchmarking
Examples:
register_custom_op_autotuning(
torch.ops.mylib.attention.default,
configs=[
CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
CustomOpConfig(fallback_impl), # No params
],
input_gen_fns={
"query": lambda fake: torch.randn_like(fake, device='cuda'),
"key": lambda fake: torch.randn_like(fake, device='cuda'),
"value": lambda fake: torch.randn_like(fake, device='cuda'),
}
)
"""
if not isinstance(configs, (list, tuple)):
raise TypeError(f"configs must be a list or tuple, got {type(configs)}")
if not configs:
raise ValueError("At least one config must be provided")
# Convert configs to decomposition functions
final_decompositions = []
for config in configs:
if isinstance(config, CustomOpConfig):
# CustomOpConfig object
final_decompositions.append(config.create_variant())
elif callable(config):
# Direct callable function
final_decompositions.append(config)
else:
raise TypeError(
f"Each config must be a CustomOpConfig object or callable function, "
f"got {type(config)}"
)
if name is None:
name = f"{custom_op._name}_autotuned"
@functools.wraps(custom_op)
def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
"""Inductor lowering function that replaces custom op calls with autotuned versions."""
# Extract tensor inputs and non-tensor parameters
tensor_inputs, non_tensor_kwargs = _extract_tensor_inputs(args, kwargs)
result = autotune_custom_op(
name=name,
decompositions=final_decompositions,
inputs=tensor_inputs,
kwargs=non_tensor_kwargs,
default_impl=custom_op,
user_input_gen_fns=input_gen_fns,
)
validate_ir(result)
return result
lowerings[custom_op] = autotuning_lowering
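Both CustomOpConfig.create_variant and _create_parameter_variants above bind parameters with functools.partial and encode them in the variant name. A small self-contained illustration of that mechanism follows; the scale function and its parameters are invented for the example.

import functools
import itertools

def scale(x, factor: float = 1.0, chunked: bool = False):
    # invented decomposition: the parameters exist only to show the binding mechanism
    return x * factor

# One explicit config -> one pre-bound variant, as in CustomOpConfig.create_variant().
variant = functools.partial(scale, factor=2.0)
variant.__name__ = "scale_factor_2.0"

# A tuning knob expands to the Cartesian product of its value lists,
# as in _create_parameter_variants().
tuning_knob = {"factor": [1.0, 2.0], "chunked": [False, True]}
combos = list(itertools.product(*tuning_knob.values()))
assert len(combos) == 4  # 2 factors x 2 chunked settings -> 4 variants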

View File

@@ -2879,7 +2879,6 @@ class AlgorithmSelectorCache(PersistentCache):
)
timings = do_autotuning(choices, precompile_fn)
# if timings is empty, we really have no choice but to return a semi-random
# choice. returning the first `ExternKernelCaller` is probably the safest bet
# in this case, since it will generally be the ATen kernel. if there are no
@@ -3483,6 +3482,7 @@ class AlgorithmSelectorCache(PersistentCache):
dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes])
if config.autotune_num_choices_displayed == 0:
return
# when autotune_num_choices_displayed is None, [:None] means all
n = config.autotune_num_choices_displayed
top_k = sorted(timings, key=timings.__getitem__)[:n]
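As the comment above notes, slicing with None keeps every entry, so leaving autotune_num_choices_displayed unset prints all choices. A quick illustration of that slicing behavior with made-up choice names and timings:

timings = {"choice_a": 0.12, "choice_b": 0.07, "choice_c": 0.30}
n = None  # autotune_num_choices_displayed left at its default
top_k = sorted(timings, key=timings.__getitem__)[:n]
assert top_k == ["choice_b", "choice_a", "choice_c"]  # all choices, fastest first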