Compare commits

...

26 Commits

Author SHA1 Message Date
d3df658752 clean up code 2025-10-26 18:22:30 -07:00
25db928d8b refactor subgraph and fix lint 2025-10-26 17:08:18 -07:00
f81e55d2be rebase + support optional decomposition; refine customopconfig 2025-10-25 22:52:02 -07:00
237bfde80e param suffix and test refine 2025-10-20 21:38:39 -07:00
b01c23ea73 update API morepscificiely and use config 2025-10-20 21:38:39 -07:00
84457f6f10 resolve comments:fix input_gen_fns API; tensor shape infer with siz_hints instead of ir_node_to_tensor; seperated non_tensor_args for Functools.partial for kwargs ; added Single output - layout assertion; lint 2025-10-20 21:38:39 -07:00
ef2654092c simplify API and improve code readability 2025-10-20 21:38:39 -07:00
af01078a3a fix cpu tests and lint 2025-10-20 21:38:39 -07:00
191163a1d5 simplify logic 2025-10-20 21:38:39 -07:00
16a85e83ae support custom op tuning for parameters; passing tuning knob as an arg 2025-10-20 21:38:39 -07:00
ba3a90b26c resolve input_gen_fn call and register faketensor input 2025-10-20 21:38:39 -07:00
bd36441960 remove input_gen_fns from user's interface; moving to inductor internal 2025-10-20 21:38:39 -07:00
cf46d7064d revert changes and clean up code 2025-10-20 21:38:39 -07:00
d183d9f36e Tianren/custom op autotune fix (#164689)
* remove redundant catch and callable list; refine decorator to avoid using register_lowering; fix document

* wip fallback to default

* fix lint

* refine test case

* include default implemenation to the choices

* ensure test passed for default implementation correctly; fix lint

* clean up test

* refine test structure

* refine test, fix lint, remove new tempalte

* clean up code and refine

* clean up code

* simply code and lint
2025-10-20 21:38:39 -07:00
3aab26a012 fix lint issue and deal with device properly 2025-10-20 21:38:39 -07:00
a8b92f8627 add fallback and seperate cpu and gpu backend 2025-10-20 21:38:39 -07:00
b6a1f1dcaf fix lint issues 2025-10-20 21:38:38 -07:00
c4d56497f4 fix lint 2025-10-20 21:38:38 -07:00
540038b2a6 linter fix 2025-10-20 21:38:38 -07:00
25ef7dc9a5 lint 2025-10-20 21:38:38 -07:00
8c7614f086 refine tests and update customop 2025-10-20 21:38:38 -07:00
6475254b78 refine test for rmsnorm variants and custom op, clean up code, test passed for 3 decompositions 2025-10-20 21:38:38 -07:00
55c2d5ed7a update test 2025-10-20 21:38:38 -07:00
cabc6be017 add test 2025-10-20 21:38:38 -07:00
ba623bf246 add tests 2025-10-20 21:38:38 -07:00
c9ede96ac5 initial tests and modified custom op kernel template 2025-10-20 21:38:38 -07:00
4 changed files with 1171 additions and 4 deletions

View File

@ -0,0 +1,564 @@
# Owner(s): ["module: inductor"]
"""
Tests for custom operation autotuning with PyTorch Inductor.
Users can register custom ops with multiple decomposition implementations and let
Inductor automatically select the best performing variant. Key features tested:
- Name-based input generators (use argument names instead of indices)
- Dynamic shape handling across multiple compilations
- Parametric tuning with tuning_knob for combinatorial parameter exploration
- Numerical correctness and performance validation
"""
import torch
from torch._inductor import config
from torch._inductor.kernel.custom_op import (
CustomOpConfig,
register_custom_op_autotuning,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU
torch.set_float32_matmul_precision("high")
class TestCustomOpAutoTune(TestCase):
"""Test custom operation autotuning functionality."""
def setUp(self) -> None:
"""Set up test environment with appropriate device and dtype."""
super().setUp()
self.device = "cuda" if HAS_GPU else "cpu"
self.dtype = torch.float16 if self.device == "cuda" else torch.float32
def _create_test_configs(self):
"""Create common test configurations for different sizes."""
return [
{"batch_size": 1, "seq_len": 32, "hidden_dim": 128},
{"batch_size": 2, "seq_len": 64, "hidden_dim": 256},
]
def _run_autotune_test(self, op_object, inputs, expected, test_name):
"""Shared test infrastructure for autotuning tests."""
@torch.compile
def test_model(*args):
return op_object(*args)
torch._dynamo.reset()
autotune_backends = "TRITON" if self.device == "cuda" else "ATEN"
with config.patch(
max_autotune=True,
max_autotune_gemm_backends=autotune_backends,
fx_graph_cache=False,
benchmark_kernel=True,
):
compiled_result = test_model(*inputs)
self.assertEqual(
compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
)
torch.testing.assert_close(
compiled_result,
expected,
rtol=2e-1,
atol=5e-1,
msg=f"{test_name} numerical mismatch",
)
def _assert_implementations_equivalent(self, decompositions, inputs, op_name):
"""Utility to assert that all implementations produce equivalent results."""
implementations = [(func.__name__, func) for func in decompositions]
results = {}
for name, impl in implementations:
result = impl(*inputs)
results[name] = result
# Basic sanity checks
self.assertTrue(
torch.isfinite(result).all(),
f"{op_name} {name} produced non-finite values",
)
# Verify numerical equivalence
reference_name, reference_result = next(iter(results.items()))
for name, result in results.items():
if name != reference_name:
rtol = 1e-1 if "Approximated" in name else 1e-2
atol = 1e-1 if "Approximated" in name else 1e-2
torch.testing.assert_close(
result,
reference_result,
rtol=rtol,
atol=atol,
msg=f"{op_name} {name} differs from {reference_name}",
)
def _create_rmsnorm_inputs(self, batch_size=32, seq_len=2048, hidden_dim=512):
"""Create test inputs for RMSNorm operations."""
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
weight = torch.randn(
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
)
return input_tensor, weight
def _create_mlp_inputs(
self,
batch_size=2,
seq_len=32,
hidden_dim=512,
intermediate_dim=1024,
output_dim=256,
):
"""Create test inputs for MLP operations."""
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
gate_weight = torch.randn(
hidden_dim,
intermediate_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
up_weight = torch.randn(
hidden_dim,
intermediate_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
down_weight = torch.randn(
intermediate_dim,
output_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
return input_tensor, gate_weight, up_weight, down_weight
@skipIfXpu
def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
"""Test RMSNorm autotuning decomposition variants compared to fallback default with dynamic shapes."""
test_op_name = f"test_lib::rmsnorm_{id(self)}"
def rmsnorm_decomposition1(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""Variance-based approach: compute variance then rsqrt."""
variance = x.pow(2).mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(variance + eps)
return x * rstd * weight
def rmsnorm_decomposition2(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
x_var = x
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
if weight is not None:
x = x * weight
return x
def rmsnorm_decomposition3(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
x_squared = x.pow(2)
variance = x_squared.mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(variance + eps)
normalized = x * rstd
if weight is not None:
normalized = normalized * weight
return normalized
@torch.library.custom_op(test_op_name, mutates_args=())
def test_rmsnorm_op(
input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
return torch.nn.functional.rms_norm(
input_tensor, input_tensor.shape[-1:], weight, eps=eps
)
@test_rmsnorm_op.register_fake
def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
return torch.empty_like(input_tensor)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
decompositions = [
rmsnorm_decomposition1,
rmsnorm_decomposition2,
rmsnorm_decomposition3,
]
register_custom_op_autotuning(
op_object.default,
configs=[CustomOpConfig(decomp) for decomp in decompositions],
name="test_rmsnorm_autotuned",
input_gen_fns={
"x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
"weight": lambda weight: torch.ones_like(weight, device=self.device),
},
)
# Test multiple shapes to verify dynamic shape handling
test_shapes = [(2, 16, 128), (8, 32, 256)]
for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
weight = torch.randn(
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
)
# Test numerical equivalence for all decompositions
self._assert_implementations_equivalent(
decompositions, (input_tensor, weight), f"RMSNorm_{i}"
)
# Test autotuning
expected = rmsnorm_decomposition1(input_tensor, weight)
self._run_autotune_test(
op_object, (input_tensor, weight), expected, f"RMSNorm_{i}"
)
@skipIfXpu
def test_mlp_custom_op_autotune(self):
"""Test MLP autotuning with method parameter controlling different decomposition variants"""
test_op_name = f"test_lib::mlp_{id(self)}"
def mlp_variants(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
"""MLP implementation with different computational approaches controlled by method parameter."""
if method == 0:
# Separate matmuls: standard implementation with torch.matmul
gate_proj = torch.matmul(input_tensor, gate_weight)
up_proj = torch.matmul(input_tensor, up_weight)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)
elif method == 1:
# Batched approach: uses torch.mm with reshaped tensors
batch_shape = input_tensor.shape[:-1]
hidden_dim = input_tensor.shape[-1]
output_dim = down_weight.shape[-1]
input_2d = input_tensor.view(-1, hidden_dim)
gate_proj = torch.mm(input_2d, gate_weight)
up_proj = torch.mm(input_2d, up_weight)
gated = torch.relu(gate_proj) * up_proj
output_2d = torch.mm(gated, down_weight)
return output_2d.view(*batch_shape, output_dim)
elif method == 2:
# Fused weights approach: concatenate then split weights
# Concatenate gate and up weights for one matrix multiply
fused_weight = torch.cat([gate_weight, up_weight], dim=1)
fused_proj = torch.matmul(input_tensor, fused_weight)
intermediate_dim = gate_weight.shape[1]
gate_proj, up_proj = fused_proj.split(
[intermediate_dim, intermediate_dim], dim=-1
)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)
@torch.library.custom_op(test_op_name, mutates_args=())
def test_mlp_op(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
return mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=method
)
@test_mlp_op.register_fake
def _(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
):
return torch.empty(
input_tensor.shape[:-1] + (down_weight.shape[-1],),
device=input_tensor.device,
dtype=input_tensor.dtype,
)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
# Use explicit configs with method parameter as tuning knob
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(method=1), # Batched approach
CustomOpConfig(method=2), # Fused weights
],
name="test_mlp_autotuned",
input_gen_fns={
"input_tensor": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"gate_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"up_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"down_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
},
)
# Create test inputs using the original helper method
input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()
# Test that all method variants produce numerically equivalent results
expected = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=0
)
for method in [1, 2]:
result = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=method
)
torch.testing.assert_close(
result,
expected,
rtol=1e-5,
atol=1e-5,
msg=f"Method {method} not equivalent to method 0",
)
# Test autotuning - all should be mathematically equivalent
self._run_autotune_test(
op_object,
(input_tensor, gate_weight, up_weight, down_weight),
expected,
"MLP",
)
def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
"""Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
# Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256
a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False)
b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False)
return a, b
@skipIfXpu
def test_decompose_k_custom_op_autotune(self):
"""Test decompose_k autotuning with parameter tuning for k_splits values using decomposition functions."""
test_op_name = f"test_lib::decompose_k_{id(self)}"
def decompose_k_implementation(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition - Python implementation."""
m = a.shape[0]
n = b.shape[1]
k = a.shape[1]
k_parts = k // k_splits
B = k_splits
a_reshaped = torch.permute(
a.reshape(m, B, k_parts), (1, 0, 2)
) # [B, m, k_parts]
b_reshaped = b.reshape(B, k_parts, n) # [B, k_parts, n]
result = torch.bmm(a_reshaped, b_reshaped) # [B, m, n]
return torch.sum(result, dim=0) # [m, n]
@torch.library.custom_op(test_op_name, mutates_args=())
def test_decompose_k_op(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition - custom op using the decomposition."""
return decompose_k_implementation(a, b, k_splits)
@test_decompose_k_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)
lib_name, op_name = test_op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_name)
# Register autotuning with different k_splits values using decomposition function
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(k_splits=32),
CustomOpConfig(k_splits=64),
CustomOpConfig(k_splits=128),
CustomOpConfig(k_splits=256),
],
name="test_decompose_k_autotuned",
input_gen_fns={
"a": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"b": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
},
)
a, b = self._create_decompose_k_inputs()
expected = a @ b
self._run_autotune_test(op_object, (a, b), expected, "DecomposeK")
@skipIfXpu
def test_multi_parameter_tuning(self):
"""Test autotuning with multiple parameters using scale_mode and chunk_size."""
op_name = f"test_lib::multi_param_{id(self)}"
def multi_param_scaling(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
) -> torch.Tensor:
"""Different scaling approaches controlled by scale_mode parameter."""
if scale_mode == 1:
# Simple broadcasting
return x * factor
elif scale_mode == 2:
# Process in chunks
batch_size, seq_len = x.shape[:2]
chunks = []
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
chunk = x[:, start:end]
chunks.append(chunk * factor)
return torch.cat(chunks, dim=1)
elif scale_mode == 3:
# Using einsum for scaling
return torch.einsum("...i,i->...i", x, factor)
@torch.library.custom_op(op_name, mutates_args=())
def multi_param_op(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
) -> torch.Tensor:
return multi_param_scaling(x, factor, scale_mode, chunk_size)
@multi_param_op.register_fake
def _(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
):
return torch.empty_like(x)
lib_name, op_suffix = op_name.split("::")
op_object = getattr(getattr(torch.ops, lib_name), op_suffix)
# Use explicit configs with scale_mode and chunk_size parameters as tuning knobs
register_custom_op_autotuning(
op_object.default,
configs=[
CustomOpConfig(scale_mode=1), # Broadcast
CustomOpConfig(scale_mode=2, chunk_size=16), # Chunked 16
CustomOpConfig(scale_mode=2, chunk_size=32), # Chunked 32
CustomOpConfig(scale_mode=3), # Einsum
],
name="multi_param_autotuned",
input_gen_fns={
"x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
"factor": lambda t: torch.ones(
t.shape[-1], device=self.device, dtype=t.dtype
),
},
)
# Create test inputs
test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype)
test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0
# Verify numerical equivalence across all approaches
expected_result = test_x * test_factor
# Test each scale_mode variant
configs = [
(1, 16), # broadcast, chunk_size ignored
(2, 16), # chunked with size 16
(2, 32), # chunked with size 32
(3, 16), # einsum, chunk_size ignored
]
for i, (scale_mode, chunk_size) in enumerate(configs):
result = multi_param_scaling(
test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size
)
torch.testing.assert_close(
result,
expected_result,
rtol=1e-5,
atol=1e-5,
msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected",
)
# Test autotuning
self._run_autotune_test(
op_object, (test_x, test_factor), expected_result, "MultiParam"
)
if __name__ == "__main__":
run_tests()

View File

@ -1,6 +1,6 @@
import itertools
import logging
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union
import torch
import torch._inductor.config as config
@ -8,6 +8,7 @@ from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
Buffer,
FixedLayout,
get_free_symbols,
get_symbolic_inputs,
gm_original_output_strides,
@ -110,7 +111,11 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
bm_func([*sym_inputs, *args])
if config.profile_bandwidth_with_do_bench_using_profiling:
return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
if self.layout.device.type == "cpu":
return benchmarker.benchmark_cpu(lambda: bm_func([*sym_inputs, *args]))
else:
return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
def hash_key(self) -> str:
return "-".join(
@ -181,9 +186,11 @@ class SubgraphTemplate(KernelTemplate):
Generate a SubgraphChoiceCaller instance for autotuning.
Args:
name: The name for this subgraph choice
input_nodes: List of input nodes to the subgraph
layout: Memory layout information for the output
example_inputs: Example tensor inputs used to trace and benchmark the subgraph
make_fx_graph: Callable that creates the FX graph for this subgraph
description: Optional description of this choice
**kwargs: Additional keyword arguments
Returns:
@ -197,3 +204,165 @@ class SubgraphTemplate(KernelTemplate):
description=description,
make_fx_graph=make_fx_graph,
)
def generate_custom_op_choices(
self,
name: str,
decompositions: list[Callable[..., Any]],
input_nodes: list[Buffer],
non_tensor_args: list[dict[str, Any]],
default_impl: Optional[Callable[..., Any]] = None,
) -> list[SubgraphChoiceCaller]:
"""
Generate multiple SubgraphChoiceCaller instances for custom op autotuning.
This method extends SubgraphTemplate to support custom op decompositions,
allowing multiple implementations to compete in autotuning.
Args:
name: Base name for the choices
decompositions: List of decomposition functions to compete in autotuning
input_nodes: List of tensor inputs. All tensor arguments must be passed here.
non_tensor_args: List of non-tensor kwargs only, one dict per corresponding decomposition.
default_impl: Default implementation for layout inference
Returns:
List of SubgraphChoiceCaller instances for autotuning
"""
if not decompositions:
return []
assert len(decompositions) == len(non_tensor_args), (
f"decompositions and non_tensor_args must have same length, "
f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
)
# Infer layouts and ensure layout consistency for fair autotuning comparison
layouts = [
self._infer_custom_op_layout(input_nodes, decomp, kwargs, default_impl)
for decomp, kwargs in zip(decompositions, non_tensor_args)
]
# Validate all decompositions produce equivalent layouts for fair comparison
self._validate_layout_equivalence(name, decompositions, layouts)
layout = layouts[0] # All layouts are now validated to be equivalent
choices: list[SubgraphChoiceCaller] = []
for decomp, decomp_kwargs in zip(decompositions, non_tensor_args):
# Create make_fx_graph function for this decomposition
import functools
def make_fx_graph(
*args: Any,
decomp: Callable[..., Any] = decomp,
decomp_kwargs: dict[str, Any] = decomp_kwargs,
) -> Any:
# decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs
from torch.fx.experimental.proxy_tensor import make_fx
return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)
# Generate descriptive name for this variant
variant_name = self._generate_variant_name(decomp, decomp_kwargs)
choice = self.generate(
name=f"{name}_{variant_name}",
input_nodes=input_nodes,
layout=layout,
make_fx_graph=make_fx_graph,
description=f"CustomOp {decomp.__name__}",
)
choices.append(choice)
return choices
def _generate_variant_name(
self, decomp: Callable[..., Any], kwargs: dict[str, Any]
) -> str:
"""Generate a descriptive name for a decomposition variant with its parameters."""
base_name = decomp.__name__
if not kwargs:
return base_name
param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(kwargs.items()))
return f"{base_name}_{param_suffix}"
def _validate_non_tensor_kwargs(self, kwargs: dict[str, Any]) -> None:
"""Validate that kwargs contains only non-tensor arguments."""
for key, value in kwargs.items():
assert not isinstance(value, (torch.Tensor, Buffer)), (
f"kwargs['{key}'] contains tensor {type(value)}. "
f"Tensor arguments should be in input_nodes, not kwargs. "
f"Only scalar/non-tensor parameters should be in kwargs."
)
def _validate_layout_equivalence(
self,
op_name: str,
decompositions: list[Callable[..., Any]],
layouts: list[Layout],
) -> None:
"""Ensure all layouts have consistent stride, device, dtype, and sizes for fair autotuning."""
if not layouts:
return
reference = layouts[0]
for i, layout in enumerate(layouts[1:], start=1):
if (layout.device, layout.dtype, layout.size, layout.stride) != (
reference.device,
reference.dtype,
reference.size,
reference.stride,
):
raise AssertionError(
f"Layout mismatch in custom op '{op_name}': "
f"decomposition '{decompositions[i].__name__}' produces "
f"({layout.device}, {layout.dtype}, {layout.size}, {layout.stride}) "
f"but '{decompositions[0].__name__}' produces "
f"({reference.device}, {reference.dtype}, {reference.size}, {reference.stride})"
)
def _infer_custom_op_layout(
self,
input_nodes: list[Buffer],
function_decomposition: Callable[..., Any],
kwargs: dict[str, Any],
default_impl: Optional[Callable[..., Any]] = None,
) -> Layout:
"""Infer output layout for custom ops using the default implementation when available.
Note that the Subgraph assumes custom ops return exactly one tensor output.
TODO: Add support for multiple output custom ops.
"""
import functools
from torch._inductor.virtualized import V
# Assert kwargs contain only non-tensor arguments
self._validate_non_tensor_kwargs(kwargs)
with V.fake_mode:
example_inputs = []
for inp in input_nodes:
raw_shape = inp.get_size()
concrete_shape = V.graph.sizevars.size_hints(
raw_shape, fallback=config.unbacked_symint_fallback
)
fake_tensor = torch.empty(
concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
)
example_inputs.append(fake_tensor)
fn = functools.partial(function_decomposition, **kwargs)
output = fn(*example_inputs)
# Assert single output
assert isinstance(output, torch.Tensor), (
f"Expected single tensor output, got {type(output)}. "
f"Multi-output custom ops not yet supported in autotuning."
)
return FixedLayout(
device=output.device,
dtype=output.dtype,
size=output.shape,
stride=output.stride(),
)

View File

@ -0,0 +1,434 @@
# Owner(s): ["module: inductor"]
import functools
import logging
from typing import Any, Callable, Optional, Union
import torch
from torch import _ops
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
from torch._inductor.lowering import lowerings, validate_ir
from torch._inductor.select_algorithm import (
autotune_select_algorithm,
ExternKernelChoice,
)
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
log = logging.getLogger(__name__)
class CustomOpConfig:
"""Config for custom op autotuning.
Specifies optional decomposition function with parameter values.
Each config creates exactly one variant.
Args:
decomposition: Optional functions to autotune. If not provided, default will be used.
**params: Parameters passed to the function
Examples:
CustomOpConfig(attention_impl, head_dim=32, method='chunked')
CustomOpConfig(head_dim=32, method='chunked')
"""
def __init__(
self,
decomposition: Optional[Callable[..., Any]] = None,
**params: Any,
):
if decomposition is not None and not callable(decomposition):
raise TypeError(
f"decomposition must be callable, got {type(decomposition)}"
)
self.decomposition = decomposition
self.params = params
def get_decomposition(
self, default_impl: Optional[Callable[..., Any]] = None
) -> Callable[..., Any]:
"""Return the decomposition function for this config.
When decomposition is not specified, return the default implementation
from the custom op's registration.
"""
if self.decomposition is not None:
return self.decomposition
# If no decomposition specified, get Python implementation from custom op
if default_impl and isinstance(default_impl, (_ops.OpOverload, str)):
from torch._library.custom_ops import _maybe_get_opdef
op_def = _maybe_get_opdef(default_impl)
if op_def is not None and hasattr(op_def, "_init_fn"):
return op_def._init_fn
raise TypeError(
f"Could not extract Python implementation from {default_impl}. "
f"Please register customop or provide a decomposition function."
)
def get_name(self) -> str:
"""Get the name for this config variant."""
param_suffix = (
"_".join(f"{k}_{v}" for k, v in sorted(self.params.items()))
if self.params
else ""
)
base_name = self.decomposition.__name__ if self.decomposition else "default"
return f"{base_name}_{param_suffix}" if param_suffix else base_name
def __repr__(self) -> str:
if self.params:
params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
decomp_name = (
self.decomposition.__name__ if self.decomposition else "default"
)
return f"CustomOpConfig({decomp_name}, {params_str})"
decomp_name = self.decomposition.__name__ if self.decomposition else "default"
return f"CustomOpConfig({decomp_name})"
__all__ = [
"autotune_custom_op",
"register_custom_op_autotuning",
"CustomOpConfig",
]
def _extract_tensor_inputs(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Extract tensor inputs from mixed args/kwargs.
Separates tensors (for autotuning input_nodes) from non-tensor parameters.
Non-tensor kwargs are later functools.partial'd into decomposition functions.
Args:
args: Positional arguments (mix of tensors and scalars)
kwargs: Keyword arguments (mix of tensors and scalars)
Returns:
Tuple of (tensor_inputs_list, non_tensor_kwargs)
"""
tensor_inputs = []
non_tensor_kwargs = {}
# Process args and kwargs: separate tensor inputs and non tensor args
for i, arg in enumerate(args):
if isinstance(arg, (TensorBox, Buffer)):
tensor_inputs.append(arg)
else:
# Add non-tensor positional args to kwargs with generated names
non_tensor_kwargs[f"arg_{i}"] = arg
for key, value in kwargs.items():
if isinstance(value, (TensorBox, Buffer)):
tensor_inputs.append(value)
else:
non_tensor_kwargs[key] = value
return tensor_inputs, non_tensor_kwargs
def _merge_config_and_runtime_kwargs(
config_params: dict[str, Any],
runtime_kwargs: dict[str, Any],
) -> dict[str, Any]:
"""Merge config parameters with runtime kwargs. Runtime kwargs take precedence.
If there are conflicts, log a warning and use runtime value.
Args:
config_params: Parameters from CustomOpConfig
runtime_kwargs: Runtime non-tensor kwargs from _extract_tensor_inputs
Returns:
Merged kwargs dictionary with runtime values taking precedence
"""
merged_kwargs = config_params.copy()
# Check for conflicts and let runtime kwargs dominate
conflicts = OrderedSet(config_params.keys()).intersection(runtime_kwargs.keys())
for key in conflicts:
log.warning(
"Parameter '%s' specified both in CustomOpConfig (%s) "
"and at runtime (%s). Using runtime value.",
key,
config_params[key],
runtime_kwargs[key],
)
# Runtime kwargs override config params
merged_kwargs.update(runtime_kwargs)
return merged_kwargs
def _create_user_input_gen_fns(
inputs: list[Any],
arg_names: list[str],
user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
) -> dict[int, Callable[[Any], torch.Tensor]]:
"""Convert user input generators from name-based to index-based format.
Inductor autotune's input_gen_fns expects index of arg_names as key.
Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
"""
from torch._inductor import config
name_to_index = {name: i for i, name in enumerate(arg_names)}
index_based_fns = {}
for name, gen_fn in user_input_gen_fns.items():
if name in name_to_index:
index_based_fns[name_to_index[name]] = gen_fn
else:
log.warning(
"Unknown argument name '%s' in input_gen_fns. "
"Available argument names: %s",
name,
list(name_to_index.keys()),
)
def create_internal_input_gen_fn(
user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
) -> Callable[[Any], torch.Tensor]:
"""Create internal input generator that converts IR buffer to user's fake tensor."""
def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
raw_shape = ir_buffer.get_size()
concrete_shape = V.graph.sizevars.size_hints(
raw_shape, fallback=config.unbacked_symint_fallback
)
fake_tensor = torch.empty(
concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
)
return user_function(fake_tensor)
return internal_input_gen_fn
return {
i: create_internal_input_gen_fn(
user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
)
for i, user_gen_fn in index_based_fns.items()
if i < len(inputs)
}
def _create_fallback_choice(
name: str,
default_impl: Callable[..., Any],
fake_output: torch.Tensor,
kwargs: dict[str, Any],
) -> ExternKernelChoice:
"""Create fallback choice for default implementation."""
def fallback_wrapper(*args: Any) -> Any:
return default_impl(*args, **kwargs)
return ExternKernelChoice(
kernel=fallback_wrapper,
name=f"{name}_fallback_default",
has_out_variant=False,
op_overload=default_impl,
use_fallback_kernel=True,
)
def autotune_custom_op(
name: str,
decompositions: list[Callable[..., Any]],
inputs: list[Any],
non_tensor_args: list[dict[str, Any]],
default_impl: Optional[Callable[..., Any]] = None,
user_input_gen_fns: Optional[
dict[str, Callable[[torch.Tensor], torch.Tensor]]
] = None,
) -> Union[TensorBox, Any]:
"""Autotune custom operations by comparing multiple decomposition implementations.
Currently supports SINGLE OUTPUT custom ops only.
TODO: Add support for multiple output custom ops (tuple/list returns).
This function generates multiple implementation choices for a custom operation and
uses Inductor's autotuning system to select the best performing variant at runtime.
Args:
name: Unique identifier for the autotuning operation
decompositions: List of alternative implementation functions to benchmark
inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects)
non_tensor_args: List of kwargs dicts, paired with corresponding decompositions arg
default_impl: Original custom op implementation used as fallback
user_input_gen_fns: Optional custom input generators for benchmarking.
Maps input indices to functions that take fake tensors
and return real tensors for performance measurement.
Returns:
IR node representing the optimized operation result
Raises:
TypeError: If decompositions is not a list/tuple
RuntimeError: If no inputs or no valid choices generated
"""
if not isinstance(decompositions, (list, tuple)):
raise TypeError(
f"decompositions must be a list or tuple of callables, got {type(decompositions)}"
)
if not inputs:
raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning")
if len(decompositions) != len(non_tensor_args):
raise ValueError(
f"decompositions and non_tensor_args must have same length, "
f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
)
template = SubgraphTemplate(name=name)
choices = template.generate_custom_op_choices(
name=name,
decompositions=decompositions,
input_nodes=list(inputs),
non_tensor_args=non_tensor_args,
)
# Add default implementation as fallback
if default_impl and hasattr(default_impl, "_op"):
fallback_name = f"{name}_fallback_default"
from torch._inductor.select_algorithm import extern_kernels
# Skip if extern_kernel already registered to avoid duplicate registration error
if not hasattr(extern_kernels, fallback_name):
with V.fake_mode:
fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
fallback_kwargs = non_tensor_args[0] if non_tensor_args else {}
fake_output = default_impl(*fake_inputs, **fallback_kwargs)
fallback_choice = _create_fallback_choice(
name, default_impl, fake_output, fallback_kwargs
)
fallback_choice.maybe_append_choice(
choices=choices,
input_nodes=list(inputs),
layout=FixedLayout(
device=fake_output.device,
dtype=fake_output.dtype,
size=fake_output.shape,
stride=fake_output.stride(),
),
)
if not choices:
raise RuntimeError(f"No valid choices generated for {name}")
# Convert user input generation functions to internal format
input_gen_fns = {}
if user_input_gen_fns:
import inspect
arg_names = (
list(inspect.signature(decompositions[0]).parameters.keys())
if decompositions
else []
)
input_gen_fns = _create_user_input_gen_fns(
inputs, arg_names, user_input_gen_fns
)
return autotune_select_algorithm(
name=name,
choices=choices,
input_nodes=list(inputs),
layout=choices[0].layout,
input_gen_fns=input_gen_fns,
)
def register_custom_op_autotuning(
custom_op: torch._ops.OpOverload,
configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
name: Optional[str] = None,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
) -> None:
"""Register custom op for autotuning with custom_op configs where each config
specifies a decomposition implementation function with its parameter values.
Args:
custom_op: Custom operation to register
configs: List of CustomOpConfig objects or callable functions
name: Operation name (default: "{op_name}_autotuned")
input_gen_fns: Custom input generators for benchmarking
Examples:
register_custom_op_autotuning(
torch.ops.mylib.attention.default,
configs=[
CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
CustomOpConfig(fallback_impl), # No params
],
input_gen_fns={
"query": lambda fake: torch.randn_like(fake, device='cuda'),
"key": lambda fake: torch.randn_like(fake, device='cuda'),
"value": lambda fake: torch.randn_like(fake, device='cuda'),
}
)
"""
if not isinstance(configs, (list, tuple)):
raise TypeError(f"configs must be a list or tuple, got {type(configs)}")
processed_configs = []
for config in configs:
if isinstance(config, CustomOpConfig):
processed_configs.append(config)
else:
raise TypeError(
f"Each config must be a CustomOpConfig object, got {type(config)}"
)
if not processed_configs:
raise ValueError("At least one config must be provided")
if name is None:
name = f"{custom_op._name}_autotuned"
@functools.wraps(custom_op)
def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
"""Inductor lowering function that replaces custom op calls with autotuned versions."""
# Extract tensor inputs and non-tensor parameters (runtime kwargs)
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
# Prepare decompositions and kwargs by merging customop config params with runtime kwargs
decompositions = []
non_tensor_args = []
for config in processed_configs:
decomp = config.get_decomposition(default_impl=custom_op)
decompositions.append(decomp)
# Merge config params with runtime kwargs (runtime takes precedence)
merged_kwargs = _merge_config_and_runtime_kwargs(
config.params, runtime_kwargs
)
non_tensor_args.append(merged_kwargs)
result = autotune_custom_op(
name=name,
decompositions=decompositions,
inputs=tensor_inputs,
non_tensor_args=non_tensor_args,
default_impl=custom_op,
user_input_gen_fns=input_gen_fns,
)
validate_ir(result)
return result
lowerings[custom_op] = autotuning_lowering

View File

@ -2919,7 +2919,6 @@ class AlgorithmSelectorCache(PersistentCache):
)
timings = do_autotuning(choices, precompile_fn)
# if timings is empty, we really have no choice but to return a semi-random
# choice. returning the first `ExternKernelCaller` is probably the safest bet
# in this case, since it will generally be the ATen kernel. if there are no
@ -3524,6 +3523,7 @@ class AlgorithmSelectorCache(PersistentCache):
dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes])
if config.autotune_num_choices_displayed == 0:
return
# when autotune_num_choices_displayed is None, [:None] means all
n = config.autotune_num_choices_displayed
top_k = sorted(timings, key=timings.__getitem__)[:n]