Compare commits

...

7 Commits

Author SHA1 Message Date
d789319a3a remove unused flag 2025-11-04 16:35:51 -08:00
ac3f93d4ae resolve review comment: check gm is None instead of using hasattr; remove MLP tests that were not useful 2025-11-04 16:34:28 -08:00
3e46c7075e resolve comments; refactor; remove mm.py usage 2025-11-04 15:19:48 -08:00
009f83025f refine code and lint 2025-11-03 23:52:08 -08:00
03b3b4b513 use decomposition table during lowering 2025-11-03 23:07:09 -08:00
c522624e02 support inline fusion for tuned_mm 2025-11-02 22:46:42 -08:00
4b59d48f84 Add custom op autotuning support with simplified API
- Implement register_custom_op_autotuning() that accepts CustomOpDef directly
- Support multiple CustomOpConfig variants for autotuning
- Add comprehensive tests for custom op autotuning
- Enable inline fusion for custom ops with epilogue operations
- Update scheduler and config for custom op integration
2025-11-02 21:28:57 -08:00
5 changed files with 150 additions and 150 deletions
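
For orientation, here is a minimal usage sketch of the register_custom_op_autotuning API these commits add, pieced together from the hunks below. The import path is an assumption (this compare view does not show file names), and the op, shapes, and config values are purely illustrative:

import torch

# NOTE: assumed import path; the compare view does not show which module
# exports these symbols.
from torch._inductor.kernel.custom_op import (
    CustomOpConfig,
    register_custom_op_autotuning,
)


@torch.library.custom_op("demo_lib::scaled_mm", mutates_args=())
def scaled_mm(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    # k_splits is a non-tensor tuning knob; the autotuner sweeps it via configs.
    return a @ b


@scaled_mm.register_fake
def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)


register_custom_op_autotuning(
    scaled_mm,
    configs=[CustomOpConfig(k_splits=2), CustomOpConfig(k_splits=4)],
    name="scaled_mm_autotuned",
    input_gen_fns={
        "a": lambda fake: torch.randn_like(fake, device="cuda"),
        "b": lambda fake: torch.randn_like(fake, device="cuda"),
    },
)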

View File

@@ -216,115 +216,6 @@ class TestCustomOpAutoTune(TestCase):
test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}"
)
@skipIfXpu
def test_mlp_custom_op_autotune(self):
"""Test MLP autotuning with method parameter controlling different decomposition variants.
Validates parametric tuning where the same decomposition function uses different
algorithmic approaches based on a method parameter (standard matmul, batched mm, fused weights).
"""
test_op_name = f"test_lib::mlp_{id(self)}"
def mlp_variants(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
"""MLP implementation with different computational approaches controlled by method parameter."""
if method == 0:
gate_proj = torch.matmul(input_tensor, gate_weight)
up_proj = torch.matmul(input_tensor, up_weight)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)
elif method == 1:
batch_shape = input_tensor.shape[:-1]
hidden_dim = input_tensor.shape[-1]
output_dim = down_weight.shape[-1]
input_2d = input_tensor.view(-1, hidden_dim)
gate_proj = torch.mm(input_2d, gate_weight)
up_proj = torch.mm(input_2d, up_weight)
gated = torch.relu(gate_proj) * up_proj
output_2d = torch.mm(gated, down_weight)
return output_2d.view(*batch_shape, output_dim)
@torch.library.custom_op(test_op_name, mutates_args=())
def test_mlp_op(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
return mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=method
)
@test_mlp_op.register_fake
def _(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
):
return torch.empty(
input_tensor.shape[:-1] + (down_weight.shape[-1],),
device=input_tensor.device,
dtype=input_tensor.dtype,
)
# Use explicit config with method parameter as tuning knob
register_custom_op_autotuning(
test_mlp_op,
configs=[
CustomOpConfig(method=0),
CustomOpConfig(method=1),
],
name="test_mlp_autotuned",
input_gen_fns={
"input_tensor": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"gate_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"up_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"down_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
},
)
# Create test inputs
input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()
# Test that all method variants produce numerically equivalent results
expected = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=0
)
# Test autotuning
self._run_autotune_test(
test_mlp_op,
(input_tensor, gate_weight, up_weight, down_weight),
expected,
"MLP",
)
def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
"""Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
# Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
@@ -335,12 +226,12 @@ class TestCustomOpAutoTune(TestCase):
@skipIfXpu
def test_decompose_k_custom_op_autotune(self):
"""Test decompose_k autotuning with parametric tuning for k_splits values.
"""Test decompose_k autotuning with epilogue fusion (matmul + bias + relu + scale).
Validates numerical parameter sweep where k_splits controls how the K dimension
is decomposed for matrix multiplication (k_splits in [32, 64, 128, 256]).
Validates that the custom op encapsulates the entire fused operation with parametric
tuning for k_splits values controlling how the K dimension is decomposed.
"""
test_op_name = f"test_lib::decompose_k_{id(self)}"
test_op_name = f"test_lib::matmul_relu_epilogue_{id(self)}"
def decompose_k_implementation(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
@@ -363,19 +254,23 @@ class TestCustomOpAutoTune(TestCase):
return torch.sum(result, dim=0) # [m, n]
@torch.library.custom_op(test_op_name, mutates_args=())
def test_decompose_k_op(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
def matmul_relu_epilogue_op(
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition - custom op using the decomposition."""
return decompose_k_implementation(a, b, k_splits)
"""Matmul with decompose_k + bias + relu + scale (complete epilogue fusion)."""
matmul_result = decompose_k_implementation(a, b, k_splits)
biased = matmul_result + bias
activated = torch.relu(biased)
scaled = activated * 2.0
return scaled
@test_decompose_k_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
@matmul_relu_epilogue_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4):
return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)
# Register autotuning with different k_splits values using decomposition function
# Register autotuning with different k_splits values
register_custom_op_autotuning(
test_decompose_k_op,
matmul_relu_epilogue_op,
configs=[
CustomOpConfig(k_splits=2),
CustomOpConfig(k_splits=4),
@@ -385,7 +280,7 @@ class TestCustomOpAutoTune(TestCase):
CustomOpConfig(k_splits=64),
CustomOpConfig(k_splits=128),
],
name="test_decompose_k_autotuned",
name="matmul_relu_epilogue_autotuned",
input_gen_fns={
"a": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
@@ -395,12 +290,45 @@ class TestCustomOpAutoTune(TestCase):
fake_tensor, device=self.device
)
* 0.1,
"bias": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
},
)
# Create test inputs
a, b = self._create_decompose_k_inputs()
expected = a @ b
self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK")
bias = torch.randn(b.shape[1], device=self.device, dtype=self.dtype) * 0.1
# Compile the model using the custom op
@torch.compile
def test_model(a, b, bias):
return matmul_relu_epilogue_op(a, b, bias)
torch._dynamo.reset()
with config.patch(
max_autotune=True,
benchmark_fusion=True,
):
compiled_result = test_model(a, b, bias)
def reference_model(a, b, bias):
matmul_result = a @ b
biased = matmul_result + bias
activated = torch.relu(biased)
scaled = activated * 2.0
return scaled
expected = reference_model(a, b, bias)
torch.testing.assert_close(
compiled_result,
expected,
rtol=2e-1,
atol=5e-1,
)
@skipIfXpu
def test_multi_parameter_tuning(self):
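
The hunk above keeps only the tail of decompose_k_implementation (the final torch.sum over the split dimension). For readers unfamiliar with the technique, here is a self-contained sketch of what such a k-way decomposition typically looks like, assuming k is divisible by k_splits as the test's input helper guarantees; the real implementation in the test may differ in detail:

import torch


def decompose_k_sketch(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    # Split the shared K dimension into k_splits chunks, matmul each chunk,
    # then sum the partial products (cf. the hunk's torch.sum(result, dim=0)).
    m, k = a.shape
    _, n = b.shape
    assert k % k_splits == 0
    a_chunks = a.reshape(m, k_splits, k // k_splits).permute(1, 0, 2)  # [k_splits, m, k/k_splits]
    b_chunks = b.reshape(k_splits, k // k_splits, n)  # [k_splits, k/k_splits, n]
    partial = torch.bmm(a_chunks, b_chunks)  # [k_splits, m, n]
    return torch.sum(partial, dim=0)  # [m, n]


a, b = torch.randn(256, 4096), torch.randn(4096, 1024)
torch.testing.assert_close(decompose_k_sketch(a, b, k_splits=8), a @ b, rtol=1e-3, atol=1e-3)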

View File

@@ -23,6 +23,22 @@ from torch._inductor.virtualized import V
log = logging.getLogger(__name__)
def inline_subgraph_to_ir_nodes(
gm: torch.fx.GraphModule, inputs: list[Any], name: str
) -> Any:
"""Inline a subgraph by converting its FX operations to individual IR nodes.
This converts the subgraph into multiple fusable ComputedBuffer nodes,
enabling epilogue fusion with subsequent operations.
Returns:
TensorBox containing the final operation result as individual IR nodes
"""
from torch._inductor.lowering import process_subgraph_nodes
return process_subgraph_nodes(gm, inputs)
class SubgraphChoiceCaller(ir.ChoiceCaller):
"""
Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary
@@ -260,7 +276,14 @@ class SubgraphTemplate(KernelTemplate):
# decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs
from torch.fx.experimental.proxy_tensor import make_fx
return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)
from ..decomposition import select_decomp_table
decomposition_table = select_decomp_table()
return make_fx(
functools.partial(decomp, **decomp_kwargs),
decomposition_table=decomposition_table,
)(*args)
# Generate descriptive name for this variant
variant_name = self._generate_variant_name(decomp, decomp_kwargs)
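
The change above threads Inductor's decomposition table into make_fx, so composite ops inside a tuning decomposition are broken down into simpler aten/prim ops before the subgraph is built. A small standalone sketch of that tracing step (toy function; make_fx and select_decomp_table are used as in the hunk):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.decomposition import select_decomp_table


def decomp(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
    # A composite op that the decomposition table typically splits into simpler ops.
    return torch.nn.functional.gelu(x) * scale


x = torch.randn(8, 8)
gm = make_fx(decomp, decomposition_table=select_decomp_table())(x)
print(gm.graph)  # shows the decomposed ops rather than a single gelu call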

View File

@@ -5,6 +5,7 @@ import logging
from typing import Any, Callable, Optional, Union
import torch
from torch._inductor import config
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
from torch._inductor.lowering import lowerings, validate_ir
@@ -157,7 +158,6 @@ def _adapt_user_input_gen_fns(
Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
"""
from torch._inductor import config
name_to_index = {name: i for i, name in enumerate(arg_names)}
index_based_fns = {}
@@ -237,6 +237,7 @@ def autotune_custom_op(
This function generates multiple implementation choices for a custom operation and
uses Inductor's autotuning system to select the best performing variant at runtime.
After selecting the best choice, applies inline fusion if the winning choice has a graph.
Args:
name: Unique identifier for the autotuning operation
@@ -318,14 +319,34 @@ def autotune_custom_op(
)
input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns)
return autotune_select_algorithm(
# Run autotuning and get both result and winning choice
selected_result, winning_choice = autotune_select_algorithm(
name=name,
choices=choices,
input_nodes=list(inputs),
layout=choices[0].layout,
input_gen_fns=input_gen_fns,
return_choice=True,
)
# Apply inlining for fusion if the winning choice has a graph; otherwise return the result as-is (default fallback impl)
if winning_choice.gm is not None:
log.debug(
"Inlining winning choice: %s (name=%s)",
getattr(winning_choice, "name", type(winning_choice).__name__),
name,
)
from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes
return inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name)
log.debug(
"Winning choice does not support inlining: %s (name=%s)",
getattr(winning_choice, "name", type(winning_choice).__name__),
name,
)
return selected_result
def register_custom_op_autotuning(
custom_op: torch._library.custom_ops.CustomOpDef,
@@ -358,7 +379,7 @@ def register_custom_op_autotuning(
"query": lambda fake: torch.randn_like(fake, device='cuda'),
"key": lambda fake: torch.randn_like(fake, device='cuda'),
"value": lambda fake: torch.randn_like(fake, device='cuda'),
}
},
)
"""
from torch._library.custom_ops import CustomOpDef
@@ -376,12 +397,12 @@ def register_custom_op_autotuning(
raise TypeError(f"configs must be a list or tuple, got {type(configs)}")
processed_configs = []
for config in configs:
if isinstance(config, CustomOpConfig):
processed_configs.append(config)
for cfg in configs:
if isinstance(cfg, CustomOpConfig):
processed_configs.append(cfg)
else:
raise TypeError(
f"Each config must be a CustomOpConfig object, got {type(config)}"
f"Each config must be a CustomOpConfig object, got {type(cfg)}"
)
if not processed_configs:
@@ -400,14 +421,12 @@ def register_custom_op_autotuning(
decompositions = []
non_tensor_args = []
for config in processed_configs:
decomp = config.get_decomposition(default_impl=default_impl)
for cfg in processed_configs:
decomp = cfg.get_decomposition(default_impl=default_impl)
decompositions.append(decomp)
# Merge config params with runtime kwargs (runtime takes precedence)
merged_kwargs = _merge_config_and_runtime_kwargs(
config.params, runtime_kwargs
)
merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
non_tensor_args.append(merged_kwargs)
result = autotune_custom_op(
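
_merge_config_and_runtime_kwargs is called above but its body is not part of this diff; the following is only a minimal sketch of the merge its comment describes ("runtime takes precedence"), not the actual helper:

from typing import Any


def merge_config_and_runtime_kwargs_sketch(
    config_params: dict[str, Any], runtime_kwargs: dict[str, Any]
) -> dict[str, Any]:
    # Start from the CustomOpConfig's tuning parameters, then let any runtime
    # kwarg with the same name override them. The real helper may also do
    # validation that this compare view does not show.
    merged = dict(config_params)
    merged.update(runtime_kwargs)
    return merged


# e.g. CustomOpConfig(k_splits=4) combined with a call-site k_splits=8
assert merge_config_and_runtime_kwargs_sketch({"k_splits": 4}, {"k_splits": 8}) == {"k_splits": 8}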

View File

@@ -7307,6 +7307,35 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
return list(map(TensorBox.create, result)) # type: ignore[call-overload]
def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]):
"""Process nodes from a FX graph by executing them through V.graph.
This is a common pattern for executing a subgraph's nodes:
- Placeholder nodes are mapped to the provided args
- Output nodes return their result
- Other nodes are executed via V.graph.run_node
"""
output = None
for i, node in enumerate(graph_module.graph.nodes):
if node.op == "placeholder":
assert node not in V.graph.env
V.graph.env[node] = args[i]
continue
elif node.op == "output":
output_args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
output = torch.fx.Interpreter.output(V.graph, node, output_args, kwargs)
else:
assert node not in V.graph.env
V.graph.env[node] = V.graph.run_node(node)
if output is None:
raise RuntimeError("No output node found in graph")
return output
# Import the control_deps_op HOP for lowering
from torch._inductor.fx_passes.control_dependencies import control_deps
@@ -7334,21 +7363,11 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
arg_offset = 2 # first two args (additional_deps, subgraph)
assert len(args) + arg_offset == len(original_args)
output = None
operation_len = len(V.graph.operations)
assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args)
for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
if node.op == "placeholder":
assert node not in V.graph.env
V.graph.env[node] = args[i]
continue
elif node.op == "output":
args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
else:
assert node not in V.graph.env
V.graph.env[node] = V.graph.run_node(node)
# Process subgraph nodes using the shared helper
output = process_subgraph_nodes(subgraph_fn.graph_module, list(args))
assert output is not None and additional_deps
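
process_subgraph_nodes above follows the standard FX interpretation pattern: bind placeholders to the incoming args, evaluate every other node, and capture the output node's value. Here is a self-contained sketch of the same walk against a plain environment dict instead of V.graph, for illustration only:

import torch
import torch.fx
from torch.fx.node import map_arg


def run_graph_module(gm: torch.fx.GraphModule, args: list):
    env: dict[torch.fx.Node, object] = {}
    output = None
    for i, node in enumerate(gm.graph.nodes):
        if node.op == "placeholder":
            env[node] = args[i]  # placeholders come first, so index i lines up
        elif node.op == "output":
            output = map_arg(node.args[0], lambda n: env[n])
        else:
            fn_args = map_arg(node.args, lambda n: env[n])
            fn_kwargs = map_arg(node.kwargs, lambda n: env[n])
            env[node] = node.target(*fn_args, **fn_kwargs)  # assumes call_function nodes only
    if output is None:
        raise RuntimeError("No output node found in graph")
    return output


def f(x, y):
    return torch.relu(x) + y


gm = torch.fx.symbolic_trace(f)
out = run_graph_module(gm, [torch.randn(4), torch.randn(4)])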

View File

@@ -2145,6 +2145,8 @@ class ExternKernelChoice:
# There is no src hash for ExternKernelChoice in the traditional sense
# so we indicate this by returning None
self.src_hash = None
# The GraphModule defaults to None for extern kernels unless explicitly set
self.gm = None
def to_callable(self):
return getattr(extern_kernels, self.name)
@@ -2317,6 +2319,7 @@ class ExternKernelCaller(ChoiceCaller):
self.choice = choice
self.kwargs = kwargs or {}
self.has_out_variant = has_out_variant
self.gm = choice.gm
def __str__(self) -> str:
return f"ExternKernelCaller({self.choice.call_name()})"
@@ -2700,6 +2703,7 @@ class AlgorithmSelectorCache(PersistentCache):
precompilation_timeout_seconds: int = 60 * 60,
return_multi_template=False,
best_config_future=None,
return_choice=False,
):
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
@@ -2971,18 +2975,25 @@ class AlgorithmSelectorCache(PersistentCache):
"Autotuning returned empty timings, falling back to first `ExternKernelCaller`: %s",
node,
)
if return_choice:
return node, choice
return node
node = choices[0].output_node()
choice = choices[0]
log.debug(
"Autotuning returned empty timings, falling back to first choice: %s",
node,
)
if return_choice:
return node, choice
return node
# if we got any timings at all, pick the best of those
choice = min(timings, key=timings.__getitem__)
node = choice.output_node()
log.debug("Autotuning selected choice: %s", node)
if return_choice:
return node, choice
return node
def make_precompile_fn(
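
The selection-cache changes above amount to: after picking the fastest timing, optionally hand the winning ChoiceCaller back alongside its output node so callers such as autotune_custom_op can inspect choice.gm (which now defaults to None on extern-kernel choices). A toy illustration of that return_choice shape, with strings standing in for real ChoiceCaller and IR objects:

def select(timings: dict[str, float], return_choice: bool = False):
    choice = min(timings, key=timings.__getitem__)  # same selection as the hunk above
    node = f"output_node_for_{choice}"  # stand-in for choice.output_node()
    return (node, choice) if return_choice else node


node, choice = select({"subgraph_k_splits_4": 0.31, "extern_mm": 0.55}, return_choice=True)
assert choice == "subgraph_k_splits_4"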