resolve comments: fix input_gen_fns API; tensor shape inference with size_hints instead of ir_node_to_tensor; separate non_tensor_args for functools.partial kwargs; add single-output layout assertion; lint

Tianren Gao
2025-10-16 00:48:02 -07:00
parent dd83ee3762
commit f194b740f7
3 changed files with 188 additions and 116 deletions


@@ -1,17 +1,14 @@
# Owner(s): ["module: inductor"]
"""
Test suite for custom operation autotuning integration with PyTorch Inductor.
Tests for custom operation autotuning with PyTorch Inductor.
This module tests the custom op autotuning system, which allows users to provide
multiple decomposition implementations of custom operations and automatically
select the best performing one through Inductor's autotuning system.
Users can register custom ops with multiple decomposition implementations and let
Inductor automatically select the best performing variant. Key features tested:
The tests cover:
1. Custom op registration and autotuning for different operations (RMSNorm, MLP, Attention)
2. Numerical equivalence between different decomposition implementations
3. End-to-end compilation and performance validation
4. Fallback behavior when decompositions fail
5. Max-autotune integration with extended configuration sets
- Name-based input generators (use argument names instead of indices)
- Dynamic shape handling across multiple compilations
- Parametric tuning with tuning_knob for combinatorial parameter exploration
- Numerical correctness and performance validation
"""
import torch
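For orientation, here is a minimal registration sketch in the shape these tests exercise. The import path for register_custom_op_autotuning and the my_lib::rmsnorm op are illustrative stand-ins, not taken from this diff:

# Hedged sketch: assumes `register_custom_op_autotuning` is importable from
# the Inductor module under test, and that `my_lib::rmsnorm` is an
# already-registered custom op overload (both names are placeholders).
def rmsnorm_v1(x, weight, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def rmsnorm_v2(x, weight, eps=1e-6):
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

register_custom_op_autotuning(
    my_lib.rmsnorm.default,
    decompositions=[rmsnorm_v1, rmsnorm_v2],
    name="rmsnorm_autotuned",
    # Name-based generators: keys match the decompositions' argument names.
    input_gen_fns={
        "x": lambda t: torch.randn_like(t, device="cuda") * 0.02,
        "weight": lambda t: torch.ones_like(t, device="cuda"),
    },
)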
@@ -154,8 +151,8 @@ class TestCustomOpAutoTune(TestCase):
return input_tensor, gate_weight, up_weight, down_weight
@skipIfXpu
def test_rmsnorm_custom_op_autotune(self):
"""Test RMSNorm autotuning with multiple decomposition variants showcasing different performance characteristics."""
def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
"""Test RMSNorm autotuning decomposition variants compared to fallback default with dynamic shapes."""
test_op_name = f"test_lib::rmsnorm_{id(self)}"
def rmsnorm_decomposition1(
@@ -216,31 +213,42 @@
rmsnorm_decomposition3,
]
# Example of user-friendly input generation functions
register_custom_op_autotuning(
op_object.default,
decompositions=decompositions,
name="test_rmsnorm_autotuned",
input_gen_fns={
0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
* 0.02, # Small values for input
1: lambda fake_tensor: torch.ones_like(
fake_tensor, device=self.device
), # Ones for weight
"x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
"weight": lambda weight: torch.ones_like(weight, device=self.device),
},
)
# Test inputs
input_tensor, weight = self._create_rmsnorm_inputs()
# Test multiple shapes to verify dynamic shape handling
test_shapes = [(2, 16, 128), (8, 32, 256)]
# Test numerical equivalence for all decompositions
self._assert_implementations_equivalent(
decompositions, (input_tensor, weight), "RMSNorm"
)
for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
weight = torch.randn(
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
)
# Test autotuning
expected = rmsnorm_decomposition1(input_tensor, weight)
self._run_autotune_test(op_object, (input_tensor, weight), expected, "RMSNorm")
# Test numerical equivalence for all decompositions
self._assert_implementations_equivalent(
decompositions, (input_tensor, weight), f"RMSNorm_{i}"
)
# Test autotuning
expected = rmsnorm_decomposition1(input_tensor, weight)
self._run_autotune_test(
op_object, (input_tensor, weight), expected, f"RMSNorm_{i}"
)
@skipIfXpu
def test_mlp_custom_op_autotune(self):
@@ -327,14 +335,22 @@ class TestCustomOpAutoTune(TestCase):
decompositions=decompositions,
name="test_mlp_autotuned",
input_gen_fns={
0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
* 0.1, # Input tensor
1: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
* 0.05, # Gate weight
2: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
* 0.05, # Up weight
3: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
* 0.05, # Down weight
"input_tensor": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"gate_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"up_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"down_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
},
)
@@ -408,10 +424,14 @@ class TestCustomOpAutoTune(TestCase):
max_autotune_configs={"k_splits": [8, 16, 128, 256, 512]},
name="test_decompose_k_autotuned",
input_gen_fns={
0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
* 0.1,
1: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
* 0.1,
"a": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1, # Matrix A
"b": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1, # Matrix B
},
)
@@ -498,8 +518,8 @@ class TestCustomOpAutoTune(TestCase):
tuning_knob={"method": [0, 1, 2, 3, 4]},
name="parametric_norm_autotuned",
input_gen_fns={
0: lambda t: torch.randn_like(t, device=self.device) * 0.1,
1: lambda t: torch.ones_like(t, device=self.device),
"x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
"weight": lambda t: torch.ones_like(t, device=self.device),
},
)
@@ -579,8 +599,10 @@ class TestCustomOpAutoTune(TestCase):
tuning_knob={"scale_mode": [1, 2, 3], "chunk_size": [16, 32]},
name="multi_param_autotuned",
input_gen_fns={
0: lambda t: torch.randn_like(t, device=self.device) * 0.1,
1: lambda t: torch.ones(t.shape[-1], device=self.device, dtype=t.dtype),
"x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
"factor": lambda t: torch.ones(
t.shape[-1], device=self.device, dtype=t.dtype
),
},
)


@@ -8,6 +8,7 @@ from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
Buffer,
FixedLayout,
get_free_symbols,
get_symbolic_inputs,
gm_original_output_strides,
@@ -249,7 +250,17 @@ class SubgraphTemplate(KernelTemplate):
]
self._validate_stride_consistency(name, decompositions, layouts)
layout = layouts[0] # All layouts have equivalent stride now
# Assert single output layout - assumes custom ops have one output tensor
assert len(layouts) > 0, f"No layouts inferred for custom op '{name}'"
assert all(
layout.device == layouts[0].device
and layout.dtype == layouts[0].dtype
and layout.size == layouts[0].size
for layout in layouts
), f"All decompositions for '{name}' must produce equivalent output layouts"
layout = layouts[0] # All layouts have equivalent stride/shape/dtype now
choices = []
for decomp in decompositions:
@@ -302,20 +313,49 @@ class SubgraphTemplate(KernelTemplate):
kwargs: dict[str, Any],
default_impl: Optional[Callable[..., Any]] = None,
) -> Layout:
"""Infer output layout for custom ops using the default implementation when available."""
"""Infer output layout for custom ops using the default implementation when available.
Note: the subgraph currently assumes custom ops return exactly one tensor.
TODO: Add support for multiple output custom ops.
"""
import functools
from torch._inductor.ir import FixedLayout, ir_node_to_tensor
from torch._inductor.virtualized import V
# Assert kwargs contain only non-tensor arguments for functools.partial
for key, value in kwargs.items():
assert not isinstance(value, (torch.Tensor, Buffer)), (
f"kwargs['{key}'] contains tensor {type(value)}. "
f"Tensor arguments should be in input_nodes, not kwargs. "
f"Only scalar/non-tensor parameters should be in kwargs."
)
# Use default_impl if available, otherwise use first decomposition
impl = default_impl if default_impl is not None else decompositions[0]
with V.fake_mode:
example_inputs = [ir_node_to_tensor(inp) for inp in input_nodes]
fn = functools.partial(impl, **kwargs)
example_inputs = []
for inp in input_nodes:
raw_shape = inp.get_size()
concrete_shape = V.graph.sizevars.size_hints(
raw_shape, fallback=config.unbacked_symint_fallback
)
fake_tensor = torch.empty(
concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
)
example_inputs.append(fake_tensor)
fn = functools.partial(
impl, **kwargs
) # kwargs must be non-tensor for partial
output = fn(*example_inputs)
# Assert single output
assert isinstance(output, torch.Tensor), (
f"Expected single tensor output, got {type(output)}. "
f"Multi-output custom ops not yet supported in autotuning."
)
return FixedLayout(
device=output.device,
dtype=output.dtype,

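To see why the shape inference above switched from ir_node_to_tensor to size_hints: under dynamic shapes an IR node's size may contain symbolic dims, which cannot be materialized into a benchmark tensor directly; size_hints resolves each dim to a concrete integer, falling back to config.unbacked_symint_fallback when no hint exists. A self-contained sketch of that resolution step, with a hypothetical hint_or helper standing in for V.graph.sizevars.size_hints:

import torch

UNBACKED_FALLBACK = 8192  # plays the role of config.unbacked_symint_fallback

def hint_or(dim, fallback=UNBACKED_FALLBACK):
    # Hypothetical stand-in: concrete ints pass through unchanged; a
    # symbolic dim with no usable hint falls back to a fixed size.
    return dim if isinstance(dim, int) else fallback

raw_shape = [4, "s0", 128]  # "s0" stands in for a symbolic dimension
concrete_shape = [hint_or(d) for d in raw_shape]
fake = torch.empty(concrete_shape, dtype=torch.float16, device="meta")
print(fake.shape)  # torch.Size([4, 8192, 128])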

@@ -25,6 +25,8 @@ def _extract_tensor_inputs(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Extract tensor inputs from mixed args/kwargs.
Separates tensors (for autotuning input_nodes) from non-tensor parameters.
Non-tensor kwargs are later functools.partial'd into decomposition functions.
Args:
args: Positional arguments (mix of tensors and scalars)
@@ -36,16 +38,16 @@ def _extract_tensor_inputs(
tensor_inputs = []
non_tensor_kwargs = {}
for arg in args:
if isinstance(arg, (TensorBox, Buffer)) or (
hasattr(arg, "dtype") and hasattr(arg, "shape")
):
# Process args and kwargs: separate tensor inputs from non-tensor args
for i, arg in enumerate(args):
if isinstance(arg, (TensorBox, Buffer)):
tensor_inputs.append(arg)
else:
# Add non-tensor positional args to kwargs with generated names
non_tensor_kwargs[f"arg_{i}"] = arg
for key, value in kwargs.items():
if isinstance(value, (TensorBox, Buffer)) or (
hasattr(value, "dtype") and hasattr(value, "shape")
):
if isinstance(value, (TensorBox, Buffer)):
tensor_inputs.append(value)
else:
non_tensor_kwargs[key] = value
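As a rough illustration of this separation rule (using plain torch.Tensor in place of TensorBox/Buffer IR nodes, so the sketch is runnable standalone):

import torch

def split_inputs(args, kwargs):
    # Mirrors _extract_tensor_inputs: tensors become autotuning input
    # nodes; everything else is kept for functools.partial later.
    tensors, non_tensor = [], {}
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            tensors.append(arg)
        else:
            non_tensor[f"arg_{i}"] = arg  # positional scalars get generated names
    for key, value in kwargs.items():
        if isinstance(value, torch.Tensor):
            tensors.append(value)
        else:
            non_tensor[key] = value
    return tensors, non_tensor

x = torch.randn(4, 8)
tensors, extras = split_inputs((x, 0.5), {"weight": torch.ones(8), "eps": 1e-6})
print(len(tensors), extras)  # 2 {'arg_1': 0.5, 'eps': 1e-06}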
@@ -55,46 +57,50 @@
def _create_user_input_gen_fns(
inputs: list[Any],
user_input_gen_fns: dict[int, Callable[[torch.Tensor], torch.Tensor]],
arg_names: list[str],
user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
) -> dict[int, Callable[[Any], torch.Tensor]]:
"""Convert user input generators to internal format.
"""Convert user input generators from name-based to index-based format.
Inductor's autotuning input_gen_fns expects argument indices (into arg_names) as keys.
Args:
inputs: List of input IR nodes from compilation
user_input_gen_fns: User-provided input generation functions
Returns:
Dict mapping indices to internal input generation functions
Uses V.graph.sizevars.size_hints() to pick concrete sizes for dynamic shapes.
"""
internal_input_gen_fns = {}
from torch._inductor import config
with V.fake_mode:
fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
name_to_index = {name: i for i, name in enumerate(arg_names)}
index_based_fns = {}
for name, gen_fn in user_input_gen_fns.items():
if name in name_to_index:
index_based_fns[name_to_index[name]] = gen_fn
else:
print(f"Warning: Unknown argument name '{name}' in input_gen_fns")
def create_internal_input_gen_fn(
user_function: Callable[[torch.Tensor], torch.Tensor],
template: torch.Tensor,
user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
) -> Callable[[Any], torch.Tensor]:
"""Create internal input generator that converts IR buffer to user's fake tensor."""
def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
fake_tensor_for_user = torch.empty(
template.shape,
dtype=template.dtype,
device="meta",
raw_shape = ir_buffer.get_size()
concrete_shape = V.graph.sizevars.size_hints(
raw_shape, fallback=config.unbacked_symint_fallback
)
return user_function(fake_tensor_for_user)
fake_tensor = torch.empty(
concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
)
return user_function(fake_tensor)
return internal_input_gen_fn
for i, user_gen_fn in user_input_gen_fns.items():
if i >= len(fake_inputs):
continue
fake_template = fake_inputs[i]
internal_input_gen_fns[i] = create_internal_input_gen_fn(
user_gen_fn, fake_template
return {
i: create_internal_input_gen_fn(
user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
)
return internal_input_gen_fns
for i, user_gen_fn in index_based_fns.items()
if i < len(inputs)
}
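The name-to-index mapping comes straight from the first decomposition's signature (see the inspect.signature call later in this diff); a minimal sketch of the conversion, with the function and generator names illustrative:

import inspect
import torch

def rmsnorm(x, weight, eps=1e-6):
    return x  # placeholder body; only the signature matters here

user_input_gen_fns = {
    "x": lambda t: torch.randn_like(t) * 0.1,
    "weight": lambda t: torch.ones_like(t),
}

arg_names = list(inspect.signature(rmsnorm).parameters.keys())
name_to_index = {name: i for i, name in enumerate(arg_names)}
index_based = {name_to_index[n]: fn for n, fn in user_input_gen_fns.items()
               if n in name_to_index}
print(sorted(index_based))  # [0, 1]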
# Global cache for fallback choices to avoid duplicate creation
@@ -107,13 +113,7 @@ def _get_or_create_fallback_choice(
fake_output: torch.Tensor,
kwargs: dict[str, Any],
) -> ExternKernelChoice:
"""Get or create fallback choice for default implementation."""
cache_key = (id(default_impl), name, tuple(sorted(kwargs.items())))
if cache_key not in _fallback_choice_cache:
def fallback_wrapper(*args: Any) -> Any:
return default_impl(*args, **kwargs)
"""Create fallback choice for default implementation."""
fallback_name = f"{name}_fallback_{default_impl._name}"
_fallback_choice_cache[cache_key] = ExternKernelChoice(
@@ -142,14 +142,10 @@ def _create_parameter_variants(
"""
# Validate parameter values
for param_name, param_values in tuning_knob.items():
if not isinstance(param_values, (list, tuple)):
if not param_values or not isinstance(param_values, (list, tuple)):
raise TypeError(
f"Parameter values for '{param_name}' must be a list or tuple, got {type(param_values)}"
)
if not param_values:
raise ValueError(
f"At least one parameter value must be provided for '{param_name}'"
)
# Generate all combinations of parameter values using Cartesian product
import itertools
@@ -167,8 +163,6 @@
# Create partial function with all parameters
variant = functools.partial(decomp_fn, **param_kwargs)
# Generate descriptive name
param_suffix = "_".join(
f"{name}_{value}" for name, value in param_kwargs.items()
)
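The variant expansion is a plain Cartesian product over the knob values; a minimal runnable sketch of the same idea (decomposition and knob names illustrative):

import functools
import itertools

def scaled_norm(x, scale_mode=1, chunk_size=16):
    return x * scale_mode / chunk_size  # stand-in decomposition

tuning_knob = {"scale_mode": [1, 2, 3], "chunk_size": [16, 32]}
names, value_lists = zip(*tuning_knob.items())

variants = []
for combo in itertools.product(*value_lists):
    param_kwargs = dict(zip(names, combo))
    variant = functools.partial(scaled_norm, **param_kwargs)
    suffix = "_".join(f"{n}_{v}" for n, v in param_kwargs.items())
    variants.append((f"scaled_norm_{suffix}", variant))

print(len(variants))   # 6 = 3 scale modes x 2 chunk sizes
print(variants[0][0])  # scaled_norm_scale_mode_1_chunk_size_16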
@@ -185,11 +179,14 @@ def autotune_custom_op(
kwargs: Optional[dict[str, Any]] = None,
default_impl: Optional[Callable[..., Any]] = None,
user_input_gen_fns: Optional[
dict[int, Callable[[torch.Tensor], torch.Tensor]]
dict[str, Callable[[torch.Tensor], torch.Tensor]]
] = None,
) -> Union[TensorBox, Any]:
"""Autotune custom operations by comparing multiple decomposition implementations.
Currently supports SINGLE OUTPUT custom ops only.
TODO: Add support for multiple output custom ops (tuple/list returns).
This function generates multiple implementation choices for a custom operation and
uses Inductor's autotuning system to select the best performing variant at runtime.
@@ -231,24 +228,28 @@
# Add default implementation as fallback
if default_impl and hasattr(default_impl, "_op"):
# Get output shape/dtype by calling default implementation with fake inputs
with V.fake_mode:
fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
fake_output = default_impl(*fake_inputs, **kwargs)
fallback_name = f"{name}_fallback_default"
from torch._inductor.select_algorithm import extern_kernels
fallback_choice = _get_or_create_fallback_choice(
name, default_impl, fake_output, kwargs
)
fallback_choice.maybe_append_choice(
choices=choices,
input_nodes=list(inputs),
layout=FixedLayout(
device=fake_output.device,
dtype=fake_output.dtype,
size=fake_output.shape,
stride=fake_output.stride(),
),
)
# Skip if extern_kernel already registered to avoid duplicate registration error
if not hasattr(extern_kernels, fallback_name):
with V.fake_mode:
fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
fake_output = default_impl(*fake_inputs, **kwargs)
fallback_choice = _create_fallback_choice(
name, default_impl, fake_output, kwargs
)
fallback_choice.maybe_append_choice(
choices=choices,
input_nodes=list(inputs),
layout=FixedLayout(
device=fake_output.device,
dtype=fake_output.dtype,
size=fake_output.shape,
stride=fake_output.stride(),
),
)
if not choices:
raise RuntimeError(f"No valid choices generated for {name}")
@@ -256,7 +257,16 @@
# Convert user input generation functions to internal format
input_gen_fns = {}
if user_input_gen_fns:
input_gen_fns = _create_user_input_gen_fns(inputs, user_input_gen_fns)
import inspect
arg_names = (
list(inspect.signature(decompositions[0]).parameters.keys())
if decompositions
else []
)
input_gen_fns = _create_user_input_gen_fns(
inputs, arg_names, user_input_gen_fns
)
return autotune_select_algorithm(
name=name,
@@ -271,7 +281,7 @@ def register_custom_op_autotuning(
custom_op: torch._ops.OpOverload,
decompositions: list[Callable[..., Any]],
name: Optional[str] = None,
input_gen_fns: Optional[dict[int, Callable[[torch.Tensor], torch.Tensor]]] = None,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
tuning_knob: Optional[dict[str, list[Any]]] = None,
max_autotune_configs: Optional[dict[str, list[Any]]] = None,
) -> None: