Compare commits

...

5 Commits

SHA1 Message Date
0a95232b30 clean up a bit 2025-11-18 11:54:39 -08:00
8d55dbf8c6 cleanup 2025-11-18 10:20:20 -08:00
33c74a1e36 dispatch only happens when more than 1 impl;merge cond for same impl 2025-11-17 22:40:25 -08:00
5ef1a13f5f update code 2025-11-16 23:48:37 -08:00
ca37f9cf39 add changes for dynamic range tuning 2025-11-13 00:00:43 -08:00
4 changed files with 716 additions and 51 deletions


@@ -430,6 +430,93 @@ class TestCustomOpAutoTune(TestCase):
multi_param_op, (test_x, test_factor), expected_result, "MultiParam"
)
@skipIfXpu
def test_dynamic_range_tuning(self):
"""Test dynamic input range-based autotuning.
Validates that:
- All implementations produce equivalent results
- Autotuning selects best implementation per range
- A runtime dispatch function across ranges is generated correctly
"""
test_op_name = f"test_lib::dynamic_range_{id(self)}"
def short_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
return torch.einsum("bsh,h->bsh", x, weight)
def medium_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, hidden_dim = x.shape
chunk_size = 256
chunks = []
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
chunk = x[:, start:end, :]
chunks.append(chunk * weight)
return torch.cat(chunks, dim=1)
def long_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
return x * weight.view(1, 1, -1)
@torch.library.custom_op(test_op_name, mutates_args=())
def dynamic_range_op(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
return x * weight
@dynamic_range_op.register_fake
def _(x: torch.Tensor, weight: torch.Tensor):
return torch.empty_like(x)
register_custom_op_autotuning(
dynamic_range_op,
configs=[
CustomOpConfig(short_sequence_impl),
CustomOpConfig(medium_sequence_impl),
CustomOpConfig(long_sequence_impl),
],
name="dynamic_range_autotuned",
dispatch_on=("x", 1),
split_points=[512, 2048],
input_gen_fns={
"x": lambda fake: torch.randn_like(fake, device=self.device) * 0.1,
"weight": lambda fake: torch.ones_like(fake, device=self.device),
},
)
# Verify all implementations produce equivalent results
test_cases = [
(2, 256, 128),
(2, 1024, 128),
(2, 4096, 128),
]
for batch_size, seq_len, hidden_dim in test_cases:
test_x = torch.randn(
batch_size, seq_len, hidden_dim, device=self.device, dtype=self.dtype
)
test_weight = torch.ones(hidden_dim, device=self.device, dtype=self.dtype)
expected = test_x * test_weight
for impl_name, impl_fn in [
("short", short_sequence_impl),
("medium", medium_sequence_impl),
("long", long_sequence_impl),
]:
result = impl_fn(test_x, test_weight)
torch.testing.assert_close(
result,
expected,
rtol=1e-5,
atol=1e-5,
msg=f"{impl_name} implementation differs for seq_len={seq_len}",
)
# Test autotuning with compilation
self._run_autotune_test(
dynamic_range_op,
(test_x, test_weight),
expected,
f"DynamicRange_seq{seq_len}",
)
if __name__ == "__main__":
run_tests()
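# Editor's illustrative sketch (not part of the diff): which autotuning range each
# test shape above falls into, given split_points=[512, 2048] and the inclusive
# ranges this feature derives from them; _range_for is a hypothetical helper.
def _range_for(seq_len: int) -> tuple:
    if seq_len <= 512:
        return (1, 512)
    if seq_len <= 2048:
        return (513, 2048)
    return (2049, float("inf"))

assert _range_for(256) == (1, 512)                # test case (2, 256, 128)
assert _range_for(1024) == (513, 2048)            # test case (2, 1024, 128)
assert _range_for(4096) == (2049, float("inf"))   # test case (2, 4096, 128)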


@@ -6900,43 +6900,83 @@ class TMADescriptorStable(TMADescriptor):
class SubgraphBuffer(ExternKernel):
"""Represents a subgraph with optional multi-range dispatch."""
def __init__(
self,
layout: Layout,
input_nodes: list[Buffer],
gm: torch.fx.GraphModule,
gm: torch.fx.GraphModule
| list[tuple[tuple[int, int | float], torch.fx.GraphModule]],
example_inputs: list[Any],
subgraph_name: str,
dispatch_dim_index: int | None = None,
):
super().__init__(None, layout, input_nodes)
self.gm = gm
self.example_inputs = example_inputs
self.name = V.graph.register_buffer(self)
V.graph.register_operation(self)
if isinstance(gm, list):
    self.is_multi_range = True
    self.dispatch_dim_index = dispatch_dim_index
    assert dispatch_dim_index is not None
    self.range_subgraphs: list[
        tuple[tuple[int, int | float], SubgraphBuffer]
    ] = []
    for (range_start, range_end), range_gm in gm:
        range_subgraph = SubgraphBuffer(
            layout=layout,
            input_nodes=input_nodes,
            gm=range_gm,
            example_inputs=example_inputs,
            subgraph_name=f"{subgraph_name}_range_{range_start}_{range_end}",
            dispatch_dim_index=None,
        )
        self.range_subgraphs.append(((range_start, range_end), range_subgraph))
    self.subgraph = None
    self.sym_inputs = []
else:
    self.is_multi_range = False
    self.gm = gm
    self.dispatch_dim_index = None
    self.subgraph = V.graph.make_subgraph(
        self.gm, example_inputs, subgraph_name
    )
    assert is_node_sequence(self.inputs)
    sym_inputs = get_symbolic_inputs(self.inputs)
    for sym_inp in sym_inputs:
        self.subgraph.graph_inputs[sym_inp.name] = sym_inp
        self.subgraph.graph_input_names.append(sym_inp.name)
    self.sym_inputs = [sym_var.name for sym_var in sym_inputs]
    import torch._inductor.config as inductor_config
    with V.set_graph_handler(self.subgraph):
        # Don't bother autotuning on Triton here
        with inductor_config.patch(
            max_autotune=False,
            max_autotune_gemm=False,
            max_autotune_gemm_backends="ATEN",
        ):
            self.subgraph.run(*self.example_inputs)
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
if self.is_multi_range:
self._codegen_multi_range_dispatch(wrapper)
else:
self._codegen_single_subgraph(wrapper)
def _codegen_single_subgraph(self, wrapper: PythonWrapperCodegen) -> None:
"""Generate code for single subgraph."""
class CodegenGraph:
def __init__(self, graph: GraphLowering):
self.graph = graph
@@ -6950,6 +6990,48 @@ class SubgraphBuffer(ExternKernel):
[self.name],
)
def _codegen_multi_range_dispatch(self, wrapper: PythonWrapperCodegen) -> None:
"""Generate Python runtime dispatch for range-specific subgraphs."""
for (range_start, range_end), range_subgraph in self.range_subgraphs:
range_subgraph._codegen_single_subgraph(wrapper)
dispatch_fn_name = f"{self.name}_runtime_dispatch"
assert is_node_sequence(self.inputs)
input_refs = [t.get_name() for t in self.inputs]
wrapper.writeline(f"def {dispatch_fn_name}(args):")
wrapper.writeline(
f" {', '.join([f'arg{i}' for i in range(len(input_refs))])} = args"
)
wrapper.writeline(f" dispatch_size = arg0.size({self.dispatch_dim_index})")
wrapper.writeline(" ")
for i, ((range_start, range_end), range_subgraph) in enumerate(
self.range_subgraphs
):
subgraph_name = range_subgraph.subgraph.name
if i == len(self.range_subgraphs) - 1:
wrapper.writeline(f" else:")
else:
if range_end == float("inf"):
condition = f"dispatch_size >= {range_start}"
else:
condition = f"{range_start} <= dispatch_size <= {range_end}"
if i == 0:
wrapper.writeline(f" if {condition}:")
else:
wrapper.writeline(f" elif {condition}:")
wrapper.writeline(f" return {subgraph_name}(args)")
wrapper.writeline("")
wrapper.writeline(f"{dispatch_fn_name}_args = [{', '.join(input_refs)}]")
wrapper.writeline(
f"({self.name},) = {dispatch_fn_name}({dispatch_fn_name}_args)"
)
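# Editor's illustrative sketch (not part of the diff): the shape of the wrapper code
# emitted by _codegen_multi_range_dispatch, assuming split_points=[512, 2048],
# dispatch dim 1, and two tensor inputs; buf0 and the _subgraph_range_* callables
# are hypothetical stand-ins for the generated buffer and per-range subgraphs.
def _subgraph_range_1_512(args): ...
def _subgraph_range_513_2048(args): ...
def _subgraph_range_2049_inf(args): ...

def buf0_runtime_dispatch(args):
    arg0, arg1 = args
    dispatch_size = arg0.size(1)

    if 1 <= dispatch_size <= 512:
        return _subgraph_range_1_512(args)
    elif 513 <= dispatch_size <= 2048:
        return _subgraph_range_513_2048(args)
    else:
        return _subgraph_range_2049_inf(args)

# buf0_runtime_dispatch_args = [arg0, arg1]
# (buf0,) = buf0_runtime_dispatch(buf0_runtime_dispatch_args)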
class UserDefinedTritonKernel(ExternKernel):
def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]:


@@ -201,6 +201,215 @@ def _adapt_user_input_gen_fns(
}
def _merge_identical_implementations(
range_to_best_impl: dict[tuple[int, Union[int, float]], tuple[Callable, dict, str]],
) -> dict[tuple[int, Union[int, float]], tuple[Callable, dict, str]]:
"""Merge consecutive ranges using the same implementation."""
if not range_to_best_impl:
return {}
sorted_ranges = sorted(range_to_best_impl.items(), key=lambda x: x[0][0])
merged = {}
current_range_start, current_range_end = sorted_ranges[0][0]
current_impl, current_kwargs, current_name = sorted_ranges[0][1]
for i in range(1, len(sorted_ranges)):
(next_start, next_end), (next_impl, next_kwargs, next_name) = sorted_ranges[i]
if (
current_impl == next_impl
and current_kwargs == next_kwargs
and current_name == next_name
and next_start == current_range_end + 1
):
current_range_end = next_end
else:
merged[(current_range_start, current_range_end)] = (
current_impl,
current_kwargs,
current_name,
)
current_range_start, current_range_end = next_start, next_end
current_impl, current_kwargs, current_name = (
next_impl,
next_kwargs,
next_name,
)
merged[(current_range_start, current_range_end)] = (
current_impl,
current_kwargs,
current_name,
)
if len(merged) < len(range_to_best_impl):
log.info(
f"Range merging: reduced from {len(range_to_best_impl)} to {len(merged)} ranges"
)
return merged
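# Editor's illustrative sketch (not part of the diff): the merge behaviour on a toy
# mapping; impl_a and impl_b are hypothetical implementation functions.
def impl_a(x): ...
def impl_b(x): ...

_before = {
    (1, 512): (impl_a, {}, "impl_a"),
    (513, 2048): (impl_a, {}, "impl_a"),          # consecutive and identical -> merged
    (2049, float("inf")): (impl_b, {}, "impl_b"),
}
_after = _merge_identical_implementations(_before)
# _after == {(1, 2048): (impl_a, {}, "impl_a"),
#            (2049, float("inf")): (impl_b, {}, "impl_b")}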
def _split_points_to_ranges(
split_points: list[int],
) -> list[tuple[int, Union[int, float]]]:
"""Convert split points to inclusive-inclusive ranges.
Example: split_points=[512, 2048] ->
[(1, 512), (513, 2048), (2049, float('inf'))]
"""
ranges = []
start = 1
for split_point in split_points:
ranges.append((start, split_point))
start = split_point + 1
ranges.append((start, float("inf")))
return ranges
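# Editor's illustrative sketch (not part of the diff): the ranges produced for the
# split points used in the examples above.
assert _split_points_to_ranges([512, 2048]) == [
    (1, 512),
    (513, 2048),
    (2049, float("inf")),
]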
def _create_range_input_gen_fn(
base_gen_fn: Callable[[torch.Tensor], torch.Tensor],
dim_index: int,
range_start: int,
range_end: Union[int, float],
) -> Callable[[torch.Tensor], torch.Tensor]:
"""Create input generator that produces tensor with dimension in range."""
def constrained_gen_fn(fake_tensor: torch.Tensor) -> torch.Tensor:
result = base_gen_fn(fake_tensor)
shape = list(result.shape)
# Pick middle of range
if range_end == float("inf"):
target_dim = int(range_start + 100)
else:
target_dim = (int(range_start) + int(range_end)) // 2
target_dim = max(
int(range_start),
min(
target_dim,
int(range_end) - 1 if range_end != float("inf") else target_dim,
),
)
shape[dim_index] = target_dim
return torch.randn(*shape, dtype=result.dtype, device=result.device)
return constrained_gen_fn
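# Editor's illustrative sketch (not part of the diff): for the middle range
# (513, 2048) on dim 1, the constrained generator resizes that dimension to the
# midpoint of the range; torch.randn_like is only a stand-in base generator.
_gen = _create_range_input_gen_fn(
    torch.randn_like, dim_index=1, range_start=513, range_end=2048
)
_sample = _gen(torch.empty(2, 64, 128))
assert _sample.shape == (2, (513 + 2048) // 2, 128)  # dim 1 becomes 1280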
def _extract_winning_decomposition_index(
choice_name: str,
decompositions: list[Callable],
) -> int:
"""Extract the decomposition index from winning SubgraphChoiceCaller's name.
The choice name format is: "{op_name}_range_{start}_{end}_{decomp_name}_{counter}"
We parse it to find which decomposition won by matching decomp_name.
Args:
choice_name: Name of the winning SubgraphChoiceCaller
decompositions: List of decomposition functions
Returns:
Index into decompositions list (0-based)
"""
if not choice_name:
log.warning("Empty choice name, defaulting to first decomposition")
return 0
# Try to match decomposition by name
for i, decomp in enumerate(decompositions):
decomp_name = decomp.__name__
# Check if decomposition name appears in choice name
if decomp_name in choice_name:
log.debug(
f"Matched choice '{choice_name}' to decomposition[{i}] '{decomp_name}'"
)
return i
# Fallback: could not determine, use first
log.warning(
f"Could not determine winning decomposition from choice name '{choice_name}', "
f"defaulting to first decomposition"
)
return 0
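# Editor's illustrative sketch (not part of the diff): mapping a winning choice name
# back to its decomposition; the names below are hypothetical.
def short_sequence_impl(x, w): ...
def long_sequence_impl(x, w): ...

assert (
    _extract_winning_decomposition_index(
        "my_op_range_513_2048_long_sequence_impl_7",
        [short_sequence_impl, long_sequence_impl],
    )
    == 1
)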
def _extract_tensor_by_name(
args: tuple[Any, ...],
kwargs: dict[str, Any],
tensor_name: str,
op_overload: torch._ops.OpOverload,
) -> Optional[Any]:
"""Extract a tensor from args/kwargs by parameter name.
Args:
args: Positional arguments
kwargs: Keyword arguments
tensor_name: Name of the parameter to extract
op_overload: OpOverload to get parameter names
Returns:
The tensor (TensorBox/Buffer) if found, None otherwise
"""
import inspect
# Get parameter names from the op's signature
try:
sig = inspect.signature(op_overload)
param_names = list(sig.parameters.keys())
except Exception:
log.warning("Could not get signature for %s, using fallback", op_overload)
# Fallback: assume tensor_name matches position or kwargs
if tensor_name in kwargs:
return kwargs[tensor_name]
return None
# Check if tensor_name is in kwargs
if tensor_name in kwargs:
return kwargs[tensor_name]
# Check if tensor_name is in positional args
if tensor_name in param_names:
param_index = param_names.index(tensor_name)
if param_index < len(args):
return args[param_index]
return None
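# Editor's illustrative sketch (not part of the diff): name-based lookup with a plain
# function standing in for the OpOverload; inspect.signature is used the same way.
def _toy_op(x, weight, scale=1.0): ...

assert _extract_tensor_by_name(("x_buf", "w_buf"), {}, "weight", _toy_op) == "w_buf"
assert _extract_tensor_by_name((), {"weight": "kw_buf"}, "weight", _toy_op) == "kw_buf"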
def _get_dimension_value(tensor: Any, dim_index: int) -> Any:
"""Get the dimension value from a tensor IR node.
Args:
tensor: TensorBox or Buffer IR node
dim_index: Dimension index to extract
Returns:
Dimension value (may be symbolic or concrete)
"""
if hasattr(tensor, "get_size"):
# Buffer has get_size()
shape = tensor.get_size()
elif hasattr(tensor, "data") and hasattr(tensor.data, "get_size"):
# TensorBox wraps data
shape = tensor.data.get_size()
else:
raise RuntimeError(f"Cannot extract shape from {type(tensor)}")
if dim_index >= len(shape):
raise IndexError(
f"dim_index {dim_index} out of range for tensor with {len(shape)} dimensions"
)
return shape[dim_index]
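# Editor's illustrative sketch (not part of the diff): a minimal stub exposing
# get_size(), as an Inductor Buffer does, to show the dimension lookup.
class _StubBuffer:
    def get_size(self):
        return [2, 1024, 128]

assert _get_dimension_value(_StubBuffer(), 1) == 1024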
def _create_fallback_choice(
name: str,
default_impl: Callable[..., Any],
@@ -224,13 +433,14 @@ def _create_fallback_choice(
def autotune_custom_op(
name: str,
decompositions: list[Callable[..., Any]],
inputs: list[torch.fx.Node],
non_tensor_args: list[dict[str, Any]],
op_overload: torch._ops.OpOverload,
user_input_gen_fns: Optional[
dict[str, Callable[[torch.Tensor], torch.Tensor]]
] = None,
return_choice: bool = False,
) -> Union[TensorBox, Any, tuple[Any, Any]]:
"""Autotune custom operations by comparing multiple decomposition implementations.
Currently supports SINGLE OUTPUT custom ops only.
@@ -340,24 +550,305 @@ def autotune_custom_op(
)
from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes
result = inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name)
if return_choice:
return result, winning_choice
return result
log.debug(
"Winning choice does not support inlining: %s (name=%s)",
getattr(winning_choice, "name", type(winning_choice).__name__),
name,
)
if return_choice:
return selected_result, winning_choice
return selected_result
def _standard_lowering_fn(
processed_configs: list[CustomOpConfig],
default_impl: Callable[..., Any],
name: str,
op_overload: torch._ops.OpOverload,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
args: Any,
kwargs: Any,
) -> Any:
"""Standard autotuning lowering function."""
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
decompositions = []
non_tensor_args = []
for cfg in processed_configs:
decomp = cfg.get_decomposition(default_impl=default_impl)
decompositions.append(decomp)
merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
non_tensor_args.append(merged_kwargs)
result = autotune_custom_op(
name=name,
decompositions=decompositions,
inputs=tensor_inputs,
non_tensor_args=non_tensor_args,
op_overload=op_overload,
user_input_gen_fns=input_gen_fns,
)
validate_ir(result)
return result
def _lower_single_impl(
impl: Callable[..., Any],
impl_kwargs: dict[str, Any],
runtime_kwargs: dict[str, Any],
tensor_inputs: list[Any],
name: str,
) -> Any:
"""Lower a single implementation by tracing and inlining it."""
from torch.fx.experimental.proxy_tensor import make_fx
from ..decomposition import select_decomp_table
from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes
def impl_wrapper(*tensors):
return impl(*tensors, **{**runtime_kwargs, **impl_kwargs})
with V.fake_mode:
fake_inputs = tuple(ir_node_to_tensor(inp) for inp in tensor_inputs)
decomposition_table = select_decomp_table()
impl_gm = make_fx(
impl_wrapper,
decomposition_table=decomposition_table,
tracing_mode="symbolic",
)(*fake_inputs)
log.info("Inlining implementation: %s", impl.__name__)
result = inline_subgraph_to_ir_nodes(impl_gm, tensor_inputs, name)
validate_ir(result)
return result
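# Editor's illustrative sketch (not part of the diff): the make_fx tracing step used
# above, applied to a toy impl outside Inductor's fake mode and IR plumbing; the real
# code traces under V.fake_mode with tracing_mode="symbolic" and then inlines the
# resulting GraphModule via inline_subgraph_to_ir_nodes.
from torch.fx.experimental.proxy_tensor import make_fx

def _toy_impl(x, scale=2.0):
    return x * scale

_toy_gm = make_fx(lambda x: _toy_impl(x, scale=3.0), tracing_mode="fake")(torch.randn(2, 4))
# _toy_gm is a torch.fx.GraphModule whose graph multiplies its input by 3.0.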
def _range_based_lowering_fn(
processed_configs: list[CustomOpConfig],
default_impl: Callable[..., Any],
name: str,
op_overload: torch._ops.OpOverload,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
tensor_name: str,
dim_index: int,
ranges: list[tuple[int, Union[int, float]]],
args: Any,
kwargs: Any,
) -> Any:
"""Range-based autotuning lowering function."""
log.info("=== Range-based Autotuning for %s ===", name)
log.info("Dispatch on: %s[%d], Ranges: %s", tensor_name, dim_index, ranges)
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
# Benchmark each range and collect winning implementations
range_to_best_impl = {}
decompositions = []
non_tensor_args = []
for cfg in processed_configs:
decomp = cfg.get_decomposition(default_impl=default_impl)
decompositions.append(decomp)
merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
non_tensor_args.append(merged_kwargs)
for range_start, range_end in ranges:
# Create range-specific input generator
range_input_gen_fns = None
if input_gen_fns and tensor_name in input_gen_fns:
base_gen_fn = input_gen_fns[tensor_name]
range_gen_fn = _create_range_input_gen_fn(
base_gen_fn, dim_index, range_start, range_end
)
range_input_gen_fns = {**input_gen_fns, tensor_name: range_gen_fn}
range_name = f"{name}_range_{int(range_start)}_{int(range_end) if range_end != float('inf') else 'inf'}"
# Run autotuning for this range
autotuned_result, winning_choice = autotune_custom_op(
name=range_name,
decompositions=decompositions,
inputs=tensor_inputs,
non_tensor_args=non_tensor_args,
op_overload=op_overload,
user_input_gen_fns=range_input_gen_fns,
return_choice=True,
)
# Extract winning implementation
choice_name = getattr(winning_choice, "name", "")
winning_idx = _extract_winning_decomposition_index(choice_name, decompositions)
impl = decompositions[winning_idx]
impl_kwargs = non_tensor_args[winning_idx]
range_to_best_impl[(range_start, range_end)] = (
impl,
impl_kwargs,
impl.__name__,
)
log.info(
"Range [%s, %s]: Selected %s",
range_start,
range_end if range_end != float("inf") else "inf",
impl.__name__,
)
log.info("Completed autotuning for %d ranges", len(range_to_best_impl))
# Step 2: Merge consecutive ranges with identical implementations
merged_range_to_best_impl = _merge_identical_implementations(range_to_best_impl)
log.info(
"After merging: %d unique implementations across %d ranges",
len({impl_name for _, _, impl_name in merged_range_to_best_impl.values()}),
len(merged_range_to_best_impl),
)
# Step 3: Check if all ranges merged into one (all ranges use same implementation)
# Since ranges are consecutive and merge function combines consecutive identical impls,
# len == 1 means all ranges use the same impl+kwargs
if len(merged_range_to_best_impl) == 1:
log.info(
"All ranges selected the same implementation - skipping dispatch, using direct inline"
)
single_impl, single_kwargs, _ = next(iter(merged_range_to_best_impl.values()))
return _lower_single_impl(
single_impl, single_kwargs, runtime_kwargs, tensor_inputs, name
)
# Step 4: Create runtime dispatch for multiple implementations
log.info("Creating runtime dispatch for %d ranges", len(merged_range_to_best_impl))
from torch.fx.experimental.proxy_tensor import make_fx
from ..decomposition import select_decomp_table
from ..ir import FixedLayout, SubgraphBuffer, TensorBox
sorted_ranges = sorted(merged_range_to_best_impl.items())
# Trace each implementation independently
range_gms = []
with V.fake_mode:
fake_inputs = tuple(ir_node_to_tensor(inp) for inp in tensor_inputs)
decomposition_table = select_decomp_table()
for (range_start, range_end), (impl_fn, impl_kwargs, _) in sorted_ranges:
log.debug(
"Compiling range [%s, %s]: %s",
range_start,
range_end if range_end != float("inf") else "inf",
impl_fn.__name__,
)
def impl_wrapper(*tensors):
return impl_fn(*tensors, **{**runtime_kwargs, **impl_kwargs})
impl_gm = make_fx(
impl_wrapper,
decomposition_table=decomposition_table,
tracing_mode="symbolic",
)(*fake_inputs)
range_gms.append(((range_start, range_end), impl_gm))
# Get output layout from any implementation (they all have same output shape)
fake_output = impl_gm(*fake_inputs)
output_layout = FixedLayout(
device=fake_output.device,
dtype=fake_output.dtype,
size=fake_output.shape,
stride=fake_output.stride(),
)
log.info("Compiled %d range implementations", len(range_gms))
# Create SubgraphBuffer with multi-range dispatch
result = TensorBox.create(
SubgraphBuffer(
layout=output_layout,
input_nodes=tensor_inputs,
gm=range_gms, # List of (range, gm) tuples triggers multi-range dispatch
example_inputs=list(fake_inputs),
subgraph_name=f"{name}_autotuned",
dispatch_dim_index=dim_index,
)
)
log.info(
"Created SubgraphBuffer with multi-range dispatch (%d ranges)", len(range_gms)
)
validate_ir(result)
return result
def _create_autotuning_lowering(
processed_configs: list[CustomOpConfig],
default_impl: Callable[..., Any],
name: str,
op_overload: torch._ops.OpOverload,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
is_range_based: bool = False,
dispatch_on: Optional[tuple[str, int]] = None,
split_points: Optional[list[int]] = None,
) -> Callable[..., Any]:
"""Create the lowering function for autotuning."""
if not is_range_based:
# Standard autotuning path
@functools.wraps(op_overload)
def standard_lowering_wrapper(*args: Any, **kwargs: Any) -> Any:
return _standard_lowering_fn(
processed_configs=processed_configs,
default_impl=default_impl,
name=name,
op_overload=op_overload,
input_gen_fns=input_gen_fns,
args=args,
kwargs=kwargs,
)
return standard_lowering_wrapper
# Range-based autotuning path
tensor_name, dim_index = dispatch_on
ranges = _split_points_to_ranges(split_points)
@functools.wraps(op_overload)
def range_based_lowering_wrapper(*args: Any, **kwargs: Any) -> Any:
return _range_based_lowering_fn(
processed_configs=processed_configs,
default_impl=default_impl,
name=name,
op_overload=op_overload,
input_gen_fns=input_gen_fns,
tensor_name=tensor_name,
dim_index=dim_index,
ranges=ranges,
args=args,
kwargs=kwargs,
)
return range_based_lowering_wrapper
def register_custom_op_autotuning(
custom_op: torch._library.custom_ops.CustomOpDef,
configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
name: Optional[str] = None,
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
dispatch_on: Optional[tuple[str, int]] = None,
split_points: Optional[list[int]] = None,
) -> None:
"""Register custom op for autotuning with custom_op configs where each config
specifies a decomposition implementation function with its parameter values.
It also supports Range-based autotuning to benchmark per range and generate
runtime dispatch.
Args:
custom_op: Custom operation (decorated function from @torch.library.custom_op)
@@ -370,6 +861,7 @@ def register_custom_op_autotuning(
def my_attention(query, key, value, head_dim=32):
...
Standard Example:
register_custom_op_autotuning(
my_attention,
configs=[
@@ -383,6 +875,14 @@ def register_custom_op_autotuning(
"value": lambda fake: torch.randn_like(fake, device='cuda'),
},
)
Range-based Example:
register_custom_op_autotuning(
my_op,
configs=[CustomOpConfig(impl1), CustomOpConfig(impl2), CustomOpConfig(impl3)],
dispatch_on=("x", 1), # Dispatch on x[1]
split_points=[512, 2048], # Creates ranges: [1,512], [513,2048], [2049,inf]
)
"""
from torch._library.custom_ops import CustomOpDef
@@ -413,34 +913,30 @@ def register_custom_op_autotuning(
if name is None:
name = f"{op_overload._name}_autotuned"
# Validate range-based parameters
is_range_based = dispatch_on is not None or split_points is not None
if is_range_based:
    if dispatch_on is None or split_points is None:
        raise ValueError(
            "Both dispatch_on and split_points must be specified for range-based autotuning"
        )
    if not isinstance(dispatch_on, tuple) or len(dispatch_on) != 2:
        raise ValueError("dispatch_on must be a tuple of (tensor_name, dim_index)")
    if not isinstance(split_points, list) or len(split_points) == 0:
        raise ValueError("split_points must be a non-empty list of integers")
    if sorted(split_points) != split_points:
        raise ValueError("split_points must be sorted in ascending order")
# Create and register the lowering function
lowering_fn = _create_autotuning_lowering(
    processed_configs=processed_configs,
    default_impl=default_impl,
    name=name,
    op_overload=op_overload,
    input_gen_fns=input_gen_fns,
    is_range_based=is_range_based,
    dispatch_on=dispatch_on,
    split_points=split_points,
)
lowerings[op_overload] = lowering_fn
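# Editor's illustrative sketch (not part of the diff): a minimal range-based
# registration using the API above; "mylib::range_demo" and the impl functions are
# hypothetical placeholders modelled on the test file in this PR.
@torch.library.custom_op("mylib::range_demo", mutates_args=())
def _range_demo(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return x * weight

@_range_demo.register_fake
def _(x, weight):
    return torch.empty_like(x)

def _impl_small(x, weight):
    return x * weight

def _impl_large(x, weight):
    return x * weight.view(1, 1, -1)

register_custom_op_autotuning(
    _range_demo,
    configs=[CustomOpConfig(_impl_small), CustomOpConfig(_impl_large)],
    dispatch_on=("x", 1),   # benchmark and dispatch on x's dim 1
    split_points=[1024],    # ranges: (1, 1024) and (1025, inf)
)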