Resolve review comments: fix input_gen_fns API; infer tensor shapes with size_hints instead of ir_node_to_tensor; separate non-tensor args into functools.partial kwargs; add single-output layout assertion; lint
@@ -1,17 +1,14 @@
 # Owner(s): ["module: inductor"]
 """
-Test suite for custom operation autotuning integration with PyTorch Inductor.
+Tests for custom operation autotuning with PyTorch Inductor.
 
-This module tests the custom op autotuning system, which allows users to provide
-multiple decomposition implementations of custom operations and automatically
-select the best performing one through Inductor's autotuning system.
+Users can register custom ops with multiple decomposition implementations and let
+Inductor automatically select the best performing variant. Key features tested:
 
-The tests cover:
-1. Custom op registration and autotuning for different operations (RMSNorm, MLP, Attention)
-2. Numerical equivalence between different decomposition implementations
-3. End-to-end compilation and performance validation
-4. Fallback behavior when decompositions fail
-5. Max-autotune integration with extended configuration sets
+- Name-based input generators (use argument names instead of indices)
+- Dynamic shape handling across multiple compilations
+- Parametric tuning with tuning_knob for combinatorial parameter exploration
+- Numerical correctness and performance validation
 """
 
 import torch
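Note: the key change throughout this commit is that input_gen_fns is now keyed by argument name rather than positional index. A minimal sketch of the intended usage, distilled from the tests below (the import path and the RMSNorm body are assumptions, not part of this commit; the real test registers an op named f"test_lib::rmsnorm_{id(self)}"):

```python
# Sketch of the name-based registration API (paths/op names illustrative).
import torch
from torch._inductor.kernel.custom_op import register_custom_op_autotuning  # assumed module path

def rmsnorm_decomposition1(x, weight):
    # Standard RMSNorm formulation, used here only as a plausible decomposition.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + 1e-6) * weight

register_custom_op_autotuning(
    torch.ops.test_lib.rmsnorm.default,  # hypothetical op handle
    decompositions=[rmsnorm_decomposition1],
    name="test_rmsnorm_autotuned",
    input_gen_fns={
        # Keys are the decompositions' argument names, not positional indices.
        "x": lambda fake: torch.randn_like(fake) * 0.02,
        "weight": lambda fake: torch.ones_like(fake),
    },
)
```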
@@ -154,8 +151,8 @@ class TestCustomOpAutoTune(TestCase):
         return input_tensor, gate_weight, up_weight, down_weight
 
     @skipIfXpu
-    def test_rmsnorm_custom_op_autotune(self):
-        """Test RMSNorm autotuning with multiple decomposition variants showcasing different performance characteristics."""
+    def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
+        """Test RMSNorm autotuning decomposition variants compared to fallback default with dynamic shapes."""
         test_op_name = f"test_lib::rmsnorm_{id(self)}"
 
         def rmsnorm_decomposition1(
@@ -216,31 +213,42 @@ class TestCustomOpAutoTune(TestCase):
             rmsnorm_decomposition3,
         ]
 
-        # Example of user-friendly input generation functions
         register_custom_op_autotuning(
             op_object.default,
             decompositions=decompositions,
             name="test_rmsnorm_autotuned",
             input_gen_fns={
-                0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
-                * 0.02,  # Small values for input
-                1: lambda fake_tensor: torch.ones_like(
-                    fake_tensor, device=self.device
-                ),  # Ones for weight
+                "x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
+                "weight": lambda weight: torch.ones_like(weight, device=self.device),
             },
         )
 
-        # Test inputs
-        input_tensor, weight = self._create_rmsnorm_inputs()
+        # Test multiple shapes to verify dynamic shape handling
+        test_shapes = [(2, 16, 128), (8, 32, 256)]
 
-        # Test numerical equivalence for all decompositions
-        self._assert_implementations_equivalent(
-            decompositions, (input_tensor, weight), "RMSNorm"
-        )
+        for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
+            input_tensor = torch.randn(
+                batch_size,
+                seq_len,
+                hidden_dim,
+                device=self.device,
+                dtype=self.dtype,
+                requires_grad=False,
+            )
+            weight = torch.randn(
+                hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
+            )
 
-        # Test autotuning
-        expected = rmsnorm_decomposition1(input_tensor, weight)
-        self._run_autotune_test(op_object, (input_tensor, weight), expected, "RMSNorm")
+            # Test numerical equivalence for all decompositions
+            self._assert_implementations_equivalent(
+                decompositions, (input_tensor, weight), f"RMSNorm_{i}"
+            )
+
+            # Test autotuning
+            expected = rmsnorm_decomposition1(input_tensor, weight)
+            self._run_autotune_test(
+                op_object, (input_tensor, weight), expected, f"RMSNorm_{i}"
+            )
 
     @skipIfXpu
     def test_mlp_custom_op_autotune(self):
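Note: each tuple in test_shapes exercises a different (batch, seq, hidden) size against the same registered op. Roughly, the loop behaves like the sketch below (device, dtype, tolerances, and the compile call are illustrative; the test delegates the details to _run_autotune_test, which is not shown in this hunk):

```python
# Hypothetical shape-loop pattern; shapes taken from test_shapes above.
compiled = torch.compile(lambda x, w: torch.ops.test_lib.rmsnorm(x, w))
for batch_size, seq_len, hidden_dim in [(2, 16, 128), (8, 32, 256)]:
    x = torch.randn(batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16)
    w = torch.randn(hidden_dim, device="cuda", dtype=torch.float16)
    # Each shape must match the eager decomposition numerically.
    torch.testing.assert_close(
        compiled(x, w), rmsnorm_decomposition1(x, w), rtol=1e-2, atol=1e-2
    )
```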
@@ -327,14 +335,22 @@ class TestCustomOpAutoTune(TestCase):
             decompositions=decompositions,
             name="test_mlp_autotuned",
             input_gen_fns={
-                0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
-                * 0.1,  # Input tensor
-                1: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
-                * 0.05,  # Gate weight
-                2: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
-                * 0.05,  # Up weight
-                3: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
-                * 0.05,  # Down weight
+                "input_tensor": lambda fake_tensor: torch.randn_like(
+                    fake_tensor, device=self.device
+                )
+                * 0.1,
+                "gate_weight": lambda fake_tensor: torch.randn_like(
+                    fake_tensor, device=self.device
+                )
+                * 0.05,
+                "up_weight": lambda fake_tensor: torch.randn_like(
+                    fake_tensor, device=self.device
+                )
+                * 0.05,
+                "down_weight": lambda fake_tensor: torch.randn_like(
+                    fake_tensor, device=self.device
+                )
+                * 0.05,
             },
         )
 
@@ -408,10 +424,14 @@ class TestCustomOpAutoTune(TestCase):
             max_autotune_configs={"k_splits": [8, 16, 128, 256, 512]},
             name="test_decompose_k_autotuned",
             input_gen_fns={
-                0: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
-                * 0.1,
-                1: lambda fake_tensor: torch.randn_like(fake_tensor, device=self.device)
-                * 0.1,
+                "a": lambda fake_tensor: torch.randn_like(
+                    fake_tensor, device=self.device
+                )
+                * 0.1,  # Matrix A
+                "b": lambda fake_tensor: torch.randn_like(
+                    fake_tensor, device=self.device
+                )
+                * 0.1,  # Matrix B
             },
         )
 
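Note: k_splits controls how many chunks the matmul's reduction dimension is divided into before the partial products are summed. A rough standalone illustration of the idea, not this commit's kernel (function name and shapes are illustrative; assumes K divides evenly):

```python
import torch

# Illustrative k-split matmul: partial products over K-chunks, then a sum.
# The tuned k_splits values above range from 8 to 512.
def matmul_decompose_k(a: torch.Tensor, b: torch.Tensor, k_splits: int = 8) -> torch.Tensor:
    k = a.shape[1]
    a_chunks = a.split(k // k_splits, dim=1)  # [M, K/k_splits] pieces
    b_chunks = b.split(k // k_splits, dim=0)  # [K/k_splits, N] pieces
    partials = [ac @ bc for ac, bc in zip(a_chunks, b_chunks)]
    return torch.stack(partials).sum(dim=0)
```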
@@ -498,8 +518,8 @@ class TestCustomOpAutoTune(TestCase):
             tuning_knob={"method": [0, 1, 2, 3, 4]},
             name="parametric_norm_autotuned",
             input_gen_fns={
-                0: lambda t: torch.randn_like(t, device=self.device) * 0.1,
-                1: lambda t: torch.ones_like(t, device=self.device),
+                "x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
+                "weight": lambda t: torch.ones_like(t, device=self.device),
             },
         )
 
@@ -579,8 +599,10 @@ class TestCustomOpAutoTune(TestCase):
             tuning_knob={"scale_mode": [1, 2, 3], "chunk_size": [16, 32]},
             name="multi_param_autotuned",
             input_gen_fns={
-                0: lambda t: torch.randn_like(t, device=self.device) * 0.1,
-                1: lambda t: torch.ones(t.shape[-1], device=self.device, dtype=t.dtype),
+                "x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
+                "factor": lambda t: torch.ones(
+                    t.shape[-1], device=self.device, dtype=t.dtype
+                ),
             },
         )
 
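Note: tuning_knob expands into the Cartesian product of its value lists, so the registration above yields 3 x 2 = 6 candidate variants. A standalone sketch of that expansion, matching the logic of _create_parameter_variants further down (the decomposition body here is a placeholder):

```python
import functools
import itertools

def scaled_op(x, factor, scale_mode=1, chunk_size=16):
    ...  # placeholder decomposition body

knob = {"scale_mode": [1, 2, 3], "chunk_size": [16, 32]}
names, value_lists = zip(*knob.items())
variants = [
    functools.partial(scaled_op, **dict(zip(names, combo)))
    for combo in itertools.product(*value_lists)
]
assert len(variants) == 6  # 3 scale modes x 2 chunk sizes
```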
@@ -8,6 +8,7 @@ from torch._inductor import ir
 from torch._inductor.codegen.common import KernelTemplate
 from torch._inductor.ir import (
+    Buffer,
     FixedLayout,
     get_free_symbols,
     get_symbolic_inputs,
     gm_original_output_strides,
@@ -249,7 +250,17 @@ class SubgraphTemplate(KernelTemplate):
         ]
 
         self._validate_stride_consistency(name, decompositions, layouts)
-        layout = layouts[0]  # All layouts have equivalent stride now
+
+        # Assert single output layout - assumes custom ops have one output tensor
+        assert len(layouts) > 0, f"No layouts inferred for custom op '{name}'"
+        assert all(
+            layout.device == layouts[0].device
+            and layout.dtype == layouts[0].dtype
+            and layout.size == layouts[0].size
+            for layout in layouts
+        ), f"All decompositions for '{name}' must produce equivalent output layouts"
+
+        layout = layouts[0]  # All layouts have equivalent stride/shape/dtype now
 
         choices = []
         for decomp in decompositions:
@@ -302,20 +313,49 @@ class SubgraphTemplate(KernelTemplate):
         kwargs: dict[str, Any],
         default_impl: Optional[Callable[..., Any]] = None,
     ) -> Layout:
-        """Infer output layout for custom ops using the default implementation when available."""
+        """Infer output layout for custom ops using the default implementation when available.
+
+        Note that the Subgraph assumes custom ops return exactly one tensor so far.
+        TODO: Add support for multiple output custom ops.
+        """
         import functools
 
         from torch._inductor.ir import FixedLayout, ir_node_to_tensor
         from torch._inductor.virtualized import V
 
+        # Assert kwargs contain only non-tensor arguments for functools.partial
+        for key, value in kwargs.items():
+            assert not isinstance(value, (torch.Tensor, Buffer)), (
+                f"kwargs['{key}'] contains tensor {type(value)}. "
+                f"Tensor arguments should be in input_nodes, not kwargs. "
+                f"Only scalar/non-tensor parameters should be in kwargs."
+            )
+
         # Use default_impl if available, otherwise use first decomposition
         impl = default_impl if default_impl is not None else decompositions[0]
 
         with V.fake_mode:
-            example_inputs = [ir_node_to_tensor(inp) for inp in input_nodes]
-            fn = functools.partial(impl, **kwargs)
+            example_inputs = []
+            for inp in input_nodes:
+                raw_shape = inp.get_size()
+                concrete_shape = V.graph.sizevars.size_hints(
+                    raw_shape, fallback=config.unbacked_symint_fallback
+                )
+                fake_tensor = torch.empty(
+                    concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
+                )
+                example_inputs.append(fake_tensor)
+
+            fn = functools.partial(
+                impl, **kwargs
+            )  # kwargs must be non-tensor for partial
             output = fn(*example_inputs)
 
+            # Assert single output
+            assert isinstance(output, torch.Tensor), (
+                f"Expected single tensor output, got {type(output)}. "
+                f"Multi-output custom ops not yet supported in autotuning."
+            )
+
             return FixedLayout(
                 device=output.device,
                 dtype=output.dtype,
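Note: size_hints resolves possibly-symbolic sizes to concrete integers so a real tensor can be allocated for layout inference, whereas the old ir_node_to_tensor path could hand the implementation tensors with unbacked symbolic dimensions. A standalone restatement of the concretization step in the loop above (the helper name is hypothetical; V, config, and the IR-node accessors are used exactly as in the diff):

```python
import torch
from torch._inductor import config
from torch._inductor.virtualized import V

def concrete_example_input(inp):  # hypothetical helper; `inp` is an Inductor IR node
    raw_shape = inp.get_size()  # may contain symbolic expressions for dynamic dims
    concrete_shape = V.graph.sizevars.size_hints(
        raw_shape, fallback=config.unbacked_symint_fallback
    )  # best-effort integer hints; the fallback covers unbacked symbols
    return torch.empty(concrete_shape, dtype=inp.get_dtype(), device=inp.get_device())
```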
@@ -25,6 +25,8 @@ def _extract_tensor_inputs(
     args: tuple[Any, ...], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
     """Extract tensor inputs from mixed args/kwargs.
+    Separates tensors (for autotuning input_nodes) from non-tensor parameters.
+    Non-tensor kwargs are later functools.partial'd into decomposition functions.
 
     Args:
         args: Positional arguments (mix of tensors and scalars)
@@ -36,16 +38,16 @@ def _extract_tensor_inputs(
     tensor_inputs = []
     non_tensor_kwargs = {}
 
-    for arg in args:
-        if isinstance(arg, (TensorBox, Buffer)) or (
-            hasattr(arg, "dtype") and hasattr(arg, "shape")
-        ):
+    # Process args and kwargs: separate tensor inputs and non tensor args
+    for i, arg in enumerate(args):
+        if isinstance(arg, (TensorBox, Buffer)):
             tensor_inputs.append(arg)
+        else:
+            # Add non-tensor positional args to kwargs with generated names
+            non_tensor_kwargs[f"arg_{i}"] = arg
 
     for key, value in kwargs.items():
-        if isinstance(value, (TensorBox, Buffer)) or (
-            hasattr(value, "dtype") and hasattr(value, "shape")
-        ):
+        if isinstance(value, (TensorBox, Buffer)):
             tensor_inputs.append(value)
         else:
             non_tensor_kwargs[key] = value
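Note: the separation is purely type-based, and positional scalars now get synthetic arg_{i} names. A runnable miniature of the logic above (FakeTensorBox is a stand-in; the real check is isinstance(arg, (TensorBox, Buffer))):

```python
# Standalone mini-version of _extract_tensor_inputs' separation logic.
class FakeTensorBox:  # stand-in for torch._inductor.ir.TensorBox
    pass

def extract(args, kwargs):
    tensor_inputs, non_tensor_kwargs = [], {}
    for i, arg in enumerate(args):
        if isinstance(arg, FakeTensorBox):
            tensor_inputs.append(arg)
        else:
            non_tensor_kwargs[f"arg_{i}"] = arg  # positional scalars get synthetic names
    for key, value in kwargs.items():
        if isinstance(value, FakeTensorBox):
            tensor_inputs.append(value)
        else:
            non_tensor_kwargs[key] = value
    return tensor_inputs, non_tensor_kwargs

x, w = FakeTensorBox(), FakeTensorBox()
tensors, scalars = extract((x, 1e-6), {"weight": w, "scale_mode": 2})
assert tensors == [x, w] and scalars == {"arg_1": 1e-6, "scale_mode": 2}
```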
@@ -55,46 +57,50 @@
 
 def _create_user_input_gen_fns(
     inputs: list[Any],
-    user_input_gen_fns: dict[int, Callable[[torch.Tensor], torch.Tensor]],
+    arg_names: list[str],
+    user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
 ) -> dict[int, Callable[[Any], torch.Tensor]]:
-    """Convert user input generators to internal format.
+    """Convert user input generators from name-based to index-based format.
+    Inductor autotune's input_gen_fns expects index of arg_names as key.
 
     Args:
         inputs: List of input IR nodes from compilation
         user_input_gen_fns: User-provided input generation functions
 
     Returns:
-        Dict mapping indices to internal input generation functions
+        Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
     """
-    internal_input_gen_fns = {}
+    from torch._inductor import config
 
-    with V.fake_mode:
-        fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
+    name_to_index = {name: i for i, name in enumerate(arg_names)}
+    index_based_fns = {}
+
+    for name, gen_fn in user_input_gen_fns.items():
+        if name in name_to_index:
+            index_based_fns[name_to_index[name]] = gen_fn
+        else:
+            print(f"Warning: Unknown argument name '{name}' in input_gen_fns")
 
     def create_internal_input_gen_fn(
-        user_function: Callable[[torch.Tensor], torch.Tensor],
-        template: torch.Tensor,
+        user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
     ) -> Callable[[Any], torch.Tensor]:
         """Create internal input generator that converts IR buffer to user's fake tensor."""
 
         def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
-            fake_tensor_for_user = torch.empty(
-                template.shape,
-                dtype=template.dtype,
-                device="meta",
+            raw_shape = ir_buffer.get_size()
+            concrete_shape = V.graph.sizevars.size_hints(
+                raw_shape, fallback=config.unbacked_symint_fallback
             )
-            return user_function(fake_tensor_for_user)
+
+            fake_tensor = torch.empty(
+                concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
+            )
+            return user_function(fake_tensor)
 
         return internal_input_gen_fn
 
-    for i, user_gen_fn in user_input_gen_fns.items():
-        if i >= len(fake_inputs):
-            continue
-
-        fake_template = fake_inputs[i]
-        internal_input_gen_fns[i] = create_internal_input_gen_fn(
-            user_gen_fn, fake_template
+    return {
+        i: create_internal_input_gen_fn(
+            user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
         )
-
-    return internal_input_gen_fns
+        for i, user_gen_fn in index_based_fns.items()
+        if i < len(inputs)
+    }
 
 
 # Global cache for fallback choices to avoid duplicate creation
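Note: the core of the new API is this name-to-index translation; unknown names are dropped with a warning rather than raising. Distilled to a runnable fragment:

```python
# Name keys -> positional indices, as in _create_user_input_gen_fns above.
import torch

arg_names = ["x", "weight"]  # derived from the first decomposition's signature
user_fns = {"weight": torch.ones_like, "bogus": torch.zeros_like}
name_to_index = {name: i for i, name in enumerate(arg_names)}
index_based = {name_to_index[n]: f for n, f in user_fns.items() if n in name_to_index}
assert list(index_based) == [1]  # "bogus" is dropped (the real code prints a warning)
```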
@@ -107,13 +113,7 @@ def _get_or_create_fallback_choice(
     fake_output: torch.Tensor,
     kwargs: dict[str, Any],
 ) -> ExternKernelChoice:
-    """Get or create fallback choice for default implementation."""
-    cache_key = (id(default_impl), name, tuple(sorted(kwargs.items())))
-
-    if cache_key not in _fallback_choice_cache:
-
-        def fallback_wrapper(*args: Any) -> Any:
-            return default_impl(*args, **kwargs)
+    """Create fallback choice for default implementation."""
 
     fallback_name = f"{name}_fallback_{default_impl._name}"
     _fallback_choice_cache[cache_key] = ExternKernelChoice(
@@ -142,14 +142,10 @@ def _create_parameter_variants(
     """
     # Validate parameter values
     for param_name, param_values in tuning_knob.items():
-        if not isinstance(param_values, (list, tuple)):
+        if not param_values or not isinstance(param_values, (list, tuple)):
             raise TypeError(
                 f"Parameter values for '{param_name}' must be a list or tuple, got {type(param_values)}"
             )
-        if not param_values:
-            raise ValueError(
-                f"At least one parameter value must be provided for '{param_name}'"
-            )
 
     # Generate all combinations of parameter values using Cartesian product
     import itertools
@@ -167,8 +163,6 @@ def _create_parameter_variants(
 
         # Create partial function with all parameters
         variant = functools.partial(decomp_fn, **param_kwargs)
-
-        # Generate descriptive name
         param_suffix = "_".join(
             f"{name}_{value}" for name, value in param_kwargs.items()
         )
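Note: each variant gets a descriptive suffix built from its parameter combination, which is how the six tuning_knob variants above stay distinguishable in autotuning logs. Concretely:

```python
# Variant naming scheme from the loop above.
param_kwargs = {"scale_mode": 2, "chunk_size": 16}
param_suffix = "_".join(f"{name}_{value}" for name, value in param_kwargs.items())
assert param_suffix == "scale_mode_2_chunk_size_16"
```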
@@ -185,11 +179,14 @@ def autotune_custom_op(
     kwargs: Optional[dict[str, Any]] = None,
     default_impl: Optional[Callable[..., Any]] = None,
     user_input_gen_fns: Optional[
-        dict[int, Callable[[torch.Tensor], torch.Tensor]]
+        dict[str, Callable[[torch.Tensor], torch.Tensor]]
     ] = None,
 ) -> Union[TensorBox, Any]:
     """Autotune custom operations by comparing multiple decomposition implementations.
 
+    Currently supports SINGLE OUTPUT custom ops only.
+    TODO: Add support for multiple output custom ops (tuple/list returns).
+
     This function generates multiple implementation choices for a custom operation and
     uses Inductor's autotuning system to select the best performing variant at runtime.
 
@@ -231,24 +228,28 @@ def autotune_custom_op(
 
     # Add default implementation as fallback
     if default_impl and hasattr(default_impl, "_op"):
-        # Get output shape/dtype by calling default implementation with fake inputs
-        with V.fake_mode:
-            fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
-            fake_output = default_impl(*fake_inputs, **kwargs)
-
-        fallback_choice = _get_or_create_fallback_choice(
-            name, default_impl, fake_output, kwargs
-        )
-        fallback_choice.maybe_append_choice(
-            choices=choices,
-            input_nodes=list(inputs),
-            layout=FixedLayout(
-                device=fake_output.device,
-                dtype=fake_output.dtype,
-                size=fake_output.shape,
-                stride=fake_output.stride(),
-            ),
-        )
+        fallback_name = f"{name}_fallback_default"
+        from torch._inductor.select_algorithm import extern_kernels
+
+        # Skip if extern_kernel already registered to avoid duplicate registration error
+        if not hasattr(extern_kernels, fallback_name):
+            with V.fake_mode:
+                fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
+                fake_output = default_impl(*fake_inputs, **kwargs)
+
+            fallback_choice = _create_fallback_choice(
+                name, default_impl, fake_output, kwargs
+            )
+            fallback_choice.maybe_append_choice(
+                choices=choices,
+                input_nodes=list(inputs),
+                layout=FixedLayout(
+                    device=fake_output.device,
+                    dtype=fake_output.dtype,
+                    size=fake_output.shape,
+                    stride=fake_output.stride(),
+                ),
+            )
 
     if not choices:
         raise RuntimeError(f"No valid choices generated for {name}")
@@ -256,7 +257,16 @@ def autotune_custom_op(
     # Convert user input generation functions to internal format
     input_gen_fns = {}
     if user_input_gen_fns:
-        input_gen_fns = _create_user_input_gen_fns(inputs, user_input_gen_fns)
+        import inspect
+
+        arg_names = (
+            list(inspect.signature(decompositions[0]).parameters.keys())
+            if decompositions
+            else []
+        )
+        input_gen_fns = _create_user_input_gen_fns(
+            inputs, arg_names, user_input_gen_fns
+        )
 
     return autotune_select_algorithm(
         name=name,
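Note: argument names are recovered from the first decomposition's Python signature, so the keys users pass to input_gen_fns must match the decompositions' parameter names. A small runnable illustration:

```python
# How arg_names is derived above: parameter names of the first decomposition
# become the legal keys for the user-facing input_gen_fns dict.
import inspect

def rmsnorm_decomposition1(x, weight):
    ...

arg_names = list(inspect.signature(rmsnorm_decomposition1).parameters.keys())
assert arg_names == ["x", "weight"]
```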
@@ -271,7 +281,7 @@ def register_custom_op_autotuning(
     custom_op: torch._ops.OpOverload,
     decompositions: list[Callable[..., Any]],
     name: Optional[str] = None,
-    input_gen_fns: Optional[dict[int, Callable[[torch.Tensor], torch.Tensor]]] = None,
+    input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
     tuning_knob: Optional[dict[str, list[Any]]] = None,
     max_autotune_configs: Optional[dict[str, list[Any]]] = None,
 ) -> None: