Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-30 19:54:53 +08:00

Compare commits: codex/add- ... ciflow/ind (26 commits)

| SHA1 |
|---|
| d3df658752 |
| 25db928d8b |
| f81e55d2be |
| 237bfde80e |
| b01c23ea73 |
| 84457f6f10 |
| ef2654092c |
| af01078a3a |
| 191163a1d5 |
| 16a85e83ae |
| ba3a90b26c |
| bd36441960 |
| cf46d7064d |
| d183d9f36e |
| 3aab26a012 |
| a8b92f8627 |
| b6a1f1dcaf |
| c4d56497f4 |
| 540038b2a6 |
| 25ef7dc9a5 |
| 8c7614f086 |
| 6475254b78 |
| 55c2d5ed7a |
| cabc6be017 |
| ba623bf246 |
| c9ede96ac5 |
test/inductor/test_custom_op_autotune.py (new file, 564 lines)
@@ -0,0 +1,564 @@
# Owner(s): ["module: inductor"]
"""
Tests for custom operation autotuning with PyTorch Inductor.

Users can register custom ops with multiple decomposition implementations and let
Inductor automatically select the best performing variant. Key features tested:

- Name-based input generators (use argument names instead of indices)
- Dynamic shape handling across multiple compilations
- Parametric tuning with tuning_knob for combinatorial parameter exploration
- Numerical correctness and performance validation
"""

import torch
from torch._inductor import config
from torch._inductor.kernel.custom_op import (
    CustomOpConfig,
    register_custom_op_autotuning,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU


torch.set_float32_matmul_precision("high")


class TestCustomOpAutoTune(TestCase):
    """Test custom operation autotuning functionality."""

    def setUp(self) -> None:
        """Set up test environment with appropriate device and dtype."""
        super().setUp()
        self.device = "cuda" if HAS_GPU else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

    def _create_test_configs(self):
        """Create common test configurations for different sizes."""
        return [
            {"batch_size": 1, "seq_len": 32, "hidden_dim": 128},
            {"batch_size": 2, "seq_len": 64, "hidden_dim": 256},
        ]

    def _run_autotune_test(self, op_object, inputs, expected, test_name):
        """Shared test infrastructure for autotuning tests."""

        @torch.compile
        def test_model(*args):
            return op_object(*args)

        torch._dynamo.reset()
        autotune_backends = "TRITON" if self.device == "cuda" else "ATEN"

        with config.patch(
            max_autotune=True,
            max_autotune_gemm_backends=autotune_backends,
            fx_graph_cache=False,
            benchmark_kernel=True,
        ):
            compiled_result = test_model(*inputs)

        self.assertEqual(
            compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
        )
        torch.testing.assert_close(
            compiled_result,
            expected,
            rtol=2e-1,
            atol=5e-1,
            msg=f"{test_name} numerical mismatch",
        )

    def _assert_implementations_equivalent(self, decompositions, inputs, op_name):
        """Utility to assert that all implementations produce equivalent results."""
        implementations = [(func.__name__, func) for func in decompositions]
        results = {}
        for name, impl in implementations:
            result = impl(*inputs)
            results[name] = result

            # Basic sanity checks
            self.assertTrue(
                torch.isfinite(result).all(),
                f"{op_name} {name} produced non-finite values",
            )

        # Verify numerical equivalence
        reference_name, reference_result = next(iter(results.items()))
        for name, result in results.items():
            if name != reference_name:
                rtol = 1e-1 if "Approximated" in name else 1e-2
                atol = 1e-1 if "Approximated" in name else 1e-2
                torch.testing.assert_close(
                    result,
                    reference_result,
                    rtol=rtol,
                    atol=atol,
                    msg=f"{op_name} {name} differs from {reference_name}",
                )

    def _create_rmsnorm_inputs(self, batch_size=32, seq_len=2048, hidden_dim=512):
        """Create test inputs for RMSNorm operations."""
        input_tensor = torch.randn(
            batch_size,
            seq_len,
            hidden_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        weight = torch.randn(
            hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
        )
        return input_tensor, weight

    def _create_mlp_inputs(
        self,
        batch_size=2,
        seq_len=32,
        hidden_dim=512,
        intermediate_dim=1024,
        output_dim=256,
    ):
        """Create test inputs for MLP operations."""
        input_tensor = torch.randn(
            batch_size,
            seq_len,
            hidden_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        gate_weight = torch.randn(
            hidden_dim,
            intermediate_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        up_weight = torch.randn(
            hidden_dim,
            intermediate_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        down_weight = torch.randn(
            intermediate_dim,
            output_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        return input_tensor, gate_weight, up_weight, down_weight

    @skipIfXpu
    def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
        """Test RMSNorm autotuning decomposition variants compared to fallback default with dynamic shapes."""
        test_op_name = f"test_lib::rmsnorm_{id(self)}"

        def rmsnorm_decomposition1(
            x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            """Variance-based approach: compute variance then rsqrt."""
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            rstd = torch.rsqrt(variance + eps)
            return x * rstd * weight

        def rmsnorm_decomposition2(
            x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            x_var = x

            variance = x_var.pow(2).mean(dim=-1, keepdim=True)

            x = x * torch.rsqrt(variance + eps)

            if weight is not None:
                x = x * weight
            return x

        def rmsnorm_decomposition3(
            x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            x_squared = x.pow(2)
            variance = x_squared.mean(dim=-1, keepdim=True)

            rstd = torch.rsqrt(variance + eps)
            normalized = x * rstd

            if weight is not None:
                normalized = normalized * weight
            return normalized

        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_rmsnorm_op(
            input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            return torch.nn.functional.rms_norm(
                input_tensor, input_tensor.shape[-1:], weight, eps=eps
            )

        @test_rmsnorm_op.register_fake
        def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
            return torch.empty_like(input_tensor)

        lib_name, op_name = test_op_name.split("::")
        op_object = getattr(getattr(torch.ops, lib_name), op_name)

        decompositions = [
            rmsnorm_decomposition1,
            rmsnorm_decomposition2,
            rmsnorm_decomposition3,
        ]

        register_custom_op_autotuning(
            op_object.default,
            configs=[CustomOpConfig(decomp) for decomp in decompositions],
            name="test_rmsnorm_autotuned",
            input_gen_fns={
                "x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
                "weight": lambda weight: torch.ones_like(weight, device=self.device),
            },
        )

        # Test multiple shapes to verify dynamic shape handling
        test_shapes = [(2, 16, 128), (8, 32, 256)]

        for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
            input_tensor = torch.randn(
                batch_size,
                seq_len,
                hidden_dim,
                device=self.device,
                dtype=self.dtype,
                requires_grad=False,
            )
            weight = torch.randn(
                hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
            )

            # Test numerical equivalence for all decompositions
            self._assert_implementations_equivalent(
                decompositions, (input_tensor, weight), f"RMSNorm_{i}"
            )

            # Test autotuning
            expected = rmsnorm_decomposition1(input_tensor, weight)
            self._run_autotune_test(
                op_object, (input_tensor, weight), expected, f"RMSNorm_{i}"
            )

    @skipIfXpu
    def test_mlp_custom_op_autotune(self):
        """Test MLP autotuning with method parameter controlling different decomposition variants"""
        test_op_name = f"test_lib::mlp_{id(self)}"

        def mlp_variants(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int = 0,
        ) -> torch.Tensor:
            """MLP implementation with different computational approaches controlled by method parameter."""

            if method == 0:
                # Separate matmuls: standard implementation with torch.matmul
                gate_proj = torch.matmul(input_tensor, gate_weight)
                up_proj = torch.matmul(input_tensor, up_weight)
                gated = torch.relu(gate_proj) * up_proj
                return torch.matmul(gated, down_weight)

            elif method == 1:
                # Batched approach: uses torch.mm with reshaped tensors
                batch_shape = input_tensor.shape[:-1]
                hidden_dim = input_tensor.shape[-1]
                output_dim = down_weight.shape[-1]

                input_2d = input_tensor.view(-1, hidden_dim)

                gate_proj = torch.mm(input_2d, gate_weight)
                up_proj = torch.mm(input_2d, up_weight)

                gated = torch.relu(gate_proj) * up_proj
                output_2d = torch.mm(gated, down_weight)

                return output_2d.view(*batch_shape, output_dim)

            elif method == 2:
                # Fused weights approach: concatenate then split weights
                # Concatenate gate and up weights for one matrix multiply
                fused_weight = torch.cat([gate_weight, up_weight], dim=1)
                fused_proj = torch.matmul(input_tensor, fused_weight)

                intermediate_dim = gate_weight.shape[1]
                gate_proj, up_proj = fused_proj.split(
                    [intermediate_dim, intermediate_dim], dim=-1
                )

                gated = torch.relu(gate_proj) * up_proj

                return torch.matmul(gated, down_weight)

        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_mlp_op(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int = 0,
        ) -> torch.Tensor:
            return mlp_variants(
                input_tensor, gate_weight, up_weight, down_weight, method=method
            )

        @test_mlp_op.register_fake
        def _(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int = 0,
        ):
            return torch.empty(
                input_tensor.shape[:-1] + (down_weight.shape[-1],),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            )

        lib_name, op_name = test_op_name.split("::")
        op_object = getattr(getattr(torch.ops, lib_name), op_name)

        # Use explicit configs with method parameter as tuning knob
        register_custom_op_autotuning(
            op_object.default,
            configs=[
                CustomOpConfig(method=1),  # Batched approach
                CustomOpConfig(method=2),  # Fused weights
            ],
            name="test_mlp_autotuned",
            input_gen_fns={
                "input_tensor": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,
                "gate_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
                "up_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
                "down_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
            },
        )

        # Create test inputs using the original helper method
        input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()

        # Test that all method variants produce numerically equivalent results
        expected = mlp_variants(
            input_tensor, gate_weight, up_weight, down_weight, method=0
        )

        for method in [1, 2]:
            result = mlp_variants(
                input_tensor, gate_weight, up_weight, down_weight, method=method
            )
            torch.testing.assert_close(
                result,
                expected,
                rtol=1e-5,
                atol=1e-5,
                msg=f"Method {method} not equivalent to method 0",
            )

        # Test autotuning - all should be mathematically equivalent
        self._run_autotune_test(
            op_object,
            (input_tensor, gate_weight, up_weight, down_weight),
            expected,
            "MLP",
        )

    def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
        """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
        # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
        k = ((k + 255) // 256) * 256  # Round up to nearest multiple of 256
        a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False)
        b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False)
        return a, b

    @skipIfXpu
    def test_decompose_k_custom_op_autotune(self):
        """Test decompose_k autotuning with parameter tuning for k_splits values using decomposition functions."""
        test_op_name = f"test_lib::decompose_k_{id(self)}"

        def decompose_k_implementation(
            a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
        ) -> torch.Tensor:
            """Matrix multiply with k-way decomposition - Python implementation."""
            m = a.shape[0]
            n = b.shape[1]
            k = a.shape[1]

            k_parts = k // k_splits
            B = k_splits

            a_reshaped = torch.permute(
                a.reshape(m, B, k_parts), (1, 0, 2)
            )  # [B, m, k_parts]
            b_reshaped = b.reshape(B, k_parts, n)  # [B, k_parts, n]

            result = torch.bmm(a_reshaped, b_reshaped)  # [B, m, n]

            return torch.sum(result, dim=0)  # [m, n]

        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_decompose_k_op(
            a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
        ) -> torch.Tensor:
            """Matrix multiply with k-way decomposition - custom op using the decomposition."""
            return decompose_k_implementation(a, b, k_splits)

        @test_decompose_k_op.register_fake
        def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
            return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)

        lib_name, op_name = test_op_name.split("::")
        op_object = getattr(getattr(torch.ops, lib_name), op_name)

        # Register autotuning with different k_splits values using decomposition function
        register_custom_op_autotuning(
            op_object.default,
            configs=[
                CustomOpConfig(k_splits=32),
                CustomOpConfig(k_splits=64),
                CustomOpConfig(k_splits=128),
                CustomOpConfig(k_splits=256),
            ],
            name="test_decompose_k_autotuned",
            input_gen_fns={
                "a": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,
                "b": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,
            },
        )

        a, b = self._create_decompose_k_inputs()
        expected = a @ b
        self._run_autotune_test(op_object, (a, b), expected, "DecomposeK")

    @skipIfXpu
    def test_multi_parameter_tuning(self):
        """Test autotuning with multiple parameters using scale_mode and chunk_size."""
        op_name = f"test_lib::multi_param_{id(self)}"

        def multi_param_scaling(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ) -> torch.Tensor:
            """Different scaling approaches controlled by scale_mode parameter."""
            if scale_mode == 1:
                # Simple broadcasting
                return x * factor
            elif scale_mode == 2:
                # Process in chunks
                batch_size, seq_len = x.shape[:2]
                chunks = []
                for start in range(0, seq_len, chunk_size):
                    end = min(start + chunk_size, seq_len)
                    chunk = x[:, start:end]
                    chunks.append(chunk * factor)
                return torch.cat(chunks, dim=1)
            elif scale_mode == 3:
                # Using einsum for scaling
                return torch.einsum("...i,i->...i", x, factor)

        @torch.library.custom_op(op_name, mutates_args=())
        def multi_param_op(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ) -> torch.Tensor:
            return multi_param_scaling(x, factor, scale_mode, chunk_size)

        @multi_param_op.register_fake
        def _(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ):
            return torch.empty_like(x)

        lib_name, op_suffix = op_name.split("::")
        op_object = getattr(getattr(torch.ops, lib_name), op_suffix)

        # Use explicit configs with scale_mode and chunk_size parameters as tuning knobs
        register_custom_op_autotuning(
            op_object.default,
            configs=[
                CustomOpConfig(scale_mode=1),  # Broadcast
                CustomOpConfig(scale_mode=2, chunk_size=16),  # Chunked 16
                CustomOpConfig(scale_mode=2, chunk_size=32),  # Chunked 32
                CustomOpConfig(scale_mode=3),  # Einsum
            ],
            name="multi_param_autotuned",
            input_gen_fns={
                "x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
                "factor": lambda t: torch.ones(
                    t.shape[-1], device=self.device, dtype=t.dtype
                ),
            },
        )

        # Create test inputs
        test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype)
        test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0

        # Verify numerical equivalence across all approaches
        expected_result = test_x * test_factor

        # Test each scale_mode variant
        configs = [
            (1, 16),  # broadcast, chunk_size ignored
            (2, 16),  # chunked with size 16
            (2, 32),  # chunked with size 32
            (3, 16),  # einsum, chunk_size ignored
        ]

        for i, (scale_mode, chunk_size) in enumerate(configs):
            result = multi_param_scaling(
                test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size
            )
            torch.testing.assert_close(
                result,
                expected_result,
                rtol=1e-5,
                atol=1e-5,
                msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected",
            )

        # Test autotuning
        self._run_autotune_test(
            op_object, (test_x, test_factor), expected_result, "MultiParam"
        )


if __name__ == "__main__":
    run_tests()
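For orientation, the registration flow these tests exercise can be sketched as below; the op name `my_lib::scaled_add`, the two toy decompositions, and the input generators are illustrative placeholders, not code from this PR.

import torch
from torch._inductor.kernel.custom_op import (
    CustomOpConfig,
    register_custom_op_autotuning,
)


# A single-output custom op with a default implementation and a fake kernel,
# mirroring the pattern used throughout the tests above.
@torch.library.custom_op("my_lib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + 2.0 * y


@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


# Two numerically equivalent decompositions that compete during autotuning.
def scaled_add_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + 2.0 * y


def scaled_add_alpha(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.add(x, y, alpha=2.0)


register_custom_op_autotuning(
    torch.ops.my_lib.scaled_add.default,
    configs=[CustomOpConfig(scaled_add_mul), CustomOpConfig(scaled_add_alpha)],
    # input_gen_fns are keyed by argument name, as in the tests above.
    input_gen_fns={
        "x": lambda fake: torch.randn_like(fake),
        "y": lambda fake: torch.randn_like(fake),
    },
)


# Autotuning takes effect when the op is compiled with max_autotune enabled,
# e.g. under torch._inductor.config.patch(max_autotune=True).
@torch.compile
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.ops.my_lib.scaled_add(x, y)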
torch/_inductor/codegen/subgraph.py
@@ -1,6 +1,6 @@
import itertools
import logging
-from typing import Any, Callable, Union
+from typing import Any, Callable, Optional, Union

import torch
import torch._inductor.config as config
@@ -8,6 +8,7 @@ from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
    Buffer,
+    FixedLayout,
    get_free_symbols,
    get_symbolic_inputs,
    gm_original_output_strides,
@@ -110,7 +111,11 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
        bm_func([*sym_inputs, *args])
        if config.profile_bandwidth_with_do_bench_using_profiling:
            return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
-        return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
+
+        if self.layout.device.type == "cpu":
+            return benchmarker.benchmark_cpu(lambda: bm_func([*sym_inputs, *args]))
+        else:
+            return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))

    def hash_key(self) -> str:
        return "-".join(
@@ -181,9 +186,11 @@ class SubgraphTemplate(KernelTemplate):
        Generate a SubgraphChoiceCaller instance for autotuning.

        Args:
            name: The name for this subgraph choice
            input_nodes: List of input nodes to the subgraph
            layout: Memory layout information for the output
            example_inputs: Example tensor inputs used to trace and benchmark the subgraph
            make_fx_graph: Callable that creates the FX graph for this subgraph
            description: Optional description of this choice
            **kwargs: Additional keyword arguments

        Returns:
@@ -197,3 +204,165 @@ class SubgraphTemplate(KernelTemplate):
            description=description,
            make_fx_graph=make_fx_graph,
        )

    def generate_custom_op_choices(
        self,
        name: str,
        decompositions: list[Callable[..., Any]],
        input_nodes: list[Buffer],
        non_tensor_args: list[dict[str, Any]],
        default_impl: Optional[Callable[..., Any]] = None,
    ) -> list[SubgraphChoiceCaller]:
        """
        Generate multiple SubgraphChoiceCaller instances for custom op autotuning.

        This method extends SubgraphTemplate to support custom op decompositions,
        allowing multiple implementations to compete in autotuning.

        Args:
            name: Base name for the choices
            decompositions: List of decomposition functions to compete in autotuning
            input_nodes: List of tensor inputs. All tensor arguments must be passed here.
            non_tensor_args: List of non-tensor kwargs only, one dict per corresponding decomposition.
            default_impl: Default implementation for layout inference

        Returns:
            List of SubgraphChoiceCaller instances for autotuning
        """
        if not decompositions:
            return []

        assert len(decompositions) == len(non_tensor_args), (
            f"decompositions and non_tensor_args must have same length, "
            f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
        )

        # Infer layouts and ensure layout consistency for fair autotuning comparison
        layouts = [
            self._infer_custom_op_layout(input_nodes, decomp, kwargs, default_impl)
            for decomp, kwargs in zip(decompositions, non_tensor_args)
        ]

        # Validate all decompositions produce equivalent layouts for fair comparison
        self._validate_layout_equivalence(name, decompositions, layouts)
        layout = layouts[0]  # All layouts are now validated to be equivalent

        choices: list[SubgraphChoiceCaller] = []
        for decomp, decomp_kwargs in zip(decompositions, non_tensor_args):
            # Create make_fx_graph function for this decomposition
            import functools

            def make_fx_graph(
                *args: Any,
                decomp: Callable[..., Any] = decomp,
                decomp_kwargs: dict[str, Any] = decomp_kwargs,
            ) -> Any:
                # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs
                from torch.fx.experimental.proxy_tensor import make_fx

                return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)

            # Generate descriptive name for this variant
            variant_name = self._generate_variant_name(decomp, decomp_kwargs)

            choice = self.generate(
                name=f"{name}_{variant_name}",
                input_nodes=input_nodes,
                layout=layout,
                make_fx_graph=make_fx_graph,
                description=f"CustomOp {decomp.__name__}",
            )
            choices.append(choice)

        return choices

    def _generate_variant_name(
        self, decomp: Callable[..., Any], kwargs: dict[str, Any]
    ) -> str:
        """Generate a descriptive name for a decomposition variant with its parameters."""
        base_name = decomp.__name__
        if not kwargs:
            return base_name
        param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(kwargs.items()))
        return f"{base_name}_{param_suffix}"

    def _validate_non_tensor_kwargs(self, kwargs: dict[str, Any]) -> None:
        """Validate that kwargs contains only non-tensor arguments."""
        for key, value in kwargs.items():
            assert not isinstance(value, (torch.Tensor, Buffer)), (
                f"kwargs['{key}'] contains tensor {type(value)}. "
                f"Tensor arguments should be in input_nodes, not kwargs. "
                f"Only scalar/non-tensor parameters should be in kwargs."
            )

    def _validate_layout_equivalence(
        self,
        op_name: str,
        decompositions: list[Callable[..., Any]],
        layouts: list[Layout],
    ) -> None:
        """Ensure all layouts have consistent stride, device, dtype, and sizes for fair autotuning."""
        if not layouts:
            return

        reference = layouts[0]
        for i, layout in enumerate(layouts[1:], start=1):
            if (layout.device, layout.dtype, layout.size, layout.stride) != (
                reference.device,
                reference.dtype,
                reference.size,
                reference.stride,
            ):
                raise AssertionError(
                    f"Layout mismatch in custom op '{op_name}': "
                    f"decomposition '{decompositions[i].__name__}' produces "
                    f"({layout.device}, {layout.dtype}, {layout.size}, {layout.stride}) "
                    f"but '{decompositions[0].__name__}' produces "
                    f"({reference.device}, {reference.dtype}, {reference.size}, {reference.stride})"
                )

    def _infer_custom_op_layout(
        self,
        input_nodes: list[Buffer],
        function_decomposition: Callable[..., Any],
        kwargs: dict[str, Any],
        default_impl: Optional[Callable[..., Any]] = None,
    ) -> Layout:
        """Infer output layout for custom ops using the default implementation when available.
        Note that the Subgraph assumes custom ops return exactly one tensor output.
        TODO: Add support for multiple output custom ops.
        """
        import functools

        from torch._inductor.virtualized import V

        # Assert kwargs contain only non-tensor arguments
        self._validate_non_tensor_kwargs(kwargs)

        with V.fake_mode:
            example_inputs = []
            for inp in input_nodes:
                raw_shape = inp.get_size()
                concrete_shape = V.graph.sizevars.size_hints(
                    raw_shape, fallback=config.unbacked_symint_fallback
                )
                fake_tensor = torch.empty(
                    concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
                )
                example_inputs.append(fake_tensor)

            fn = functools.partial(function_decomposition, **kwargs)
            output = fn(*example_inputs)

        # Assert single output
        assert isinstance(output, torch.Tensor), (
            f"Expected single tensor output, got {type(output)}. "
            f"Multi-output custom ops not yet supported in autotuning."
        )

        return FixedLayout(
            device=output.device,
            dtype=output.dtype,
            size=output.shape,
            stride=output.stride(),
        )
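Each choice produced by `generate_custom_op_choices` above comes from tracing a decomposition with its non-tensor parameters pre-bound through `functools.partial`. A standalone sketch of that tracing step follows; the `scaled_relu` function and its `scale` knob are illustrative and not part of this PR.

import functools

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def scaled_relu(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # toy decomposition with one non-tensor tuning knob
    return torch.relu(x) * scale


# Bind the knob first, then trace over the tensor arguments only, mirroring
# make_fx(functools.partial(decomp, **decomp_kwargs))(*args) in the code above.
gm = make_fx(functools.partial(scaled_relu, scale=0.5))(torch.randn(8))
print(gm.graph)  # the traced FX graph has the constant 0.5 baked in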
torch/_inductor/kernel/custom_op.py (new file, 434 lines)
@@ -0,0 +1,434 @@
# Owner(s): ["module: inductor"]

import functools
import logging
from typing import Any, Callable, Optional, Union

import torch
from torch import _ops
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
from torch._inductor.lowering import lowerings, validate_ir
from torch._inductor.select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
)
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet


log = logging.getLogger(__name__)


class CustomOpConfig:
    """Config for custom op autotuning.

    Specifies optional decomposition function with parameter values.
    Each config creates exactly one variant.

    Args:
        decomposition: Optional functions to autotune. If not provided, default will be used.
        **params: Parameters passed to the function

    Examples:
        CustomOpConfig(attention_impl, head_dim=32, method='chunked')
        CustomOpConfig(head_dim=32, method='chunked')
    """

    def __init__(
        self,
        decomposition: Optional[Callable[..., Any]] = None,
        **params: Any,
    ):
        if decomposition is not None and not callable(decomposition):
            raise TypeError(
                f"decomposition must be callable, got {type(decomposition)}"
            )

        self.decomposition = decomposition
        self.params = params

    def get_decomposition(
        self, default_impl: Optional[Callable[..., Any]] = None
    ) -> Callable[..., Any]:
        """Return the decomposition function for this config.
        When decomposition is not specified, return the default implementation
        from the custom op's registration.
        """
        if self.decomposition is not None:
            return self.decomposition

        # If no decomposition specified, get Python implementation from custom op
        if default_impl and isinstance(default_impl, (_ops.OpOverload, str)):
            from torch._library.custom_ops import _maybe_get_opdef

            op_def = _maybe_get_opdef(default_impl)
            if op_def is not None and hasattr(op_def, "_init_fn"):
                return op_def._init_fn

        raise TypeError(
            f"Could not extract Python implementation from {default_impl}. "
            f"Please register customop or provide a decomposition function."
        )

    def get_name(self) -> str:
        """Get the name for this config variant."""
        param_suffix = (
            "_".join(f"{k}_{v}" for k, v in sorted(self.params.items()))
            if self.params
            else ""
        )

        base_name = self.decomposition.__name__ if self.decomposition else "default"

        return f"{base_name}_{param_suffix}" if param_suffix else base_name

    def __repr__(self) -> str:
        if self.params:
            params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
            decomp_name = (
                self.decomposition.__name__ if self.decomposition else "default"
            )
            return f"CustomOpConfig({decomp_name}, {params_str})"
        decomp_name = self.decomposition.__name__ if self.decomposition else "default"
        return f"CustomOpConfig({decomp_name})"


__all__ = [
    "autotune_custom_op",
    "register_custom_op_autotuning",
    "CustomOpConfig",
]


def _extract_tensor_inputs(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
    """Extract tensor inputs from mixed args/kwargs.
    Separates tensors (for autotuning input_nodes) from non-tensor parameters.
    Non-tensor kwargs are later functools.partial'd into decomposition functions.

    Args:
        args: Positional arguments (mix of tensors and scalars)
        kwargs: Keyword arguments (mix of tensors and scalars)

    Returns:
        Tuple of (tensor_inputs_list, non_tensor_kwargs)
    """
    tensor_inputs = []
    non_tensor_kwargs = {}

    # Process args and kwargs: separate tensor inputs and non tensor args
    for i, arg in enumerate(args):
        if isinstance(arg, (TensorBox, Buffer)):
            tensor_inputs.append(arg)
        else:
            # Add non-tensor positional args to kwargs with generated names
            non_tensor_kwargs[f"arg_{i}"] = arg

    for key, value in kwargs.items():
        if isinstance(value, (TensorBox, Buffer)):
            tensor_inputs.append(value)
        else:
            non_tensor_kwargs[key] = value

    return tensor_inputs, non_tensor_kwargs


def _merge_config_and_runtime_kwargs(
    config_params: dict[str, Any],
    runtime_kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Merge config parameters with runtime kwargs. Runtime kwargs take precedence.
    If there are conflicts, log a warning and use runtime value.

    Args:
        config_params: Parameters from CustomOpConfig
        runtime_kwargs: Runtime non-tensor kwargs from _extract_tensor_inputs

    Returns:
        Merged kwargs dictionary with runtime values taking precedence
    """
    merged_kwargs = config_params.copy()

    # Check for conflicts and let runtime kwargs dominate
    conflicts = OrderedSet(config_params.keys()).intersection(runtime_kwargs.keys())

    for key in conflicts:
        log.warning(
            "Parameter '%s' specified both in CustomOpConfig (%s) "
            "and at runtime (%s). Using runtime value.",
            key,
            config_params[key],
            runtime_kwargs[key],
        )

    # Runtime kwargs override config params
    merged_kwargs.update(runtime_kwargs)

    return merged_kwargs


def _create_user_input_gen_fns(
    inputs: list[Any],
    arg_names: list[str],
    user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
) -> dict[int, Callable[[Any], torch.Tensor]]:
    """Convert user input generators from name-based to index-based format.
    Inductor autotune's input_gen_fns expects index of arg_names as key.

    Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
    """
    from torch._inductor import config

    name_to_index = {name: i for i, name in enumerate(arg_names)}
    index_based_fns = {}

    for name, gen_fn in user_input_gen_fns.items():
        if name in name_to_index:
            index_based_fns[name_to_index[name]] = gen_fn
        else:
            log.warning(
                "Unknown argument name '%s' in input_gen_fns. "
                "Available argument names: %s",
                name,
                list(name_to_index.keys()),
            )

    def create_internal_input_gen_fn(
        user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
    ) -> Callable[[Any], torch.Tensor]:
        """Create internal input generator that converts IR buffer to user's fake tensor."""

        def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
            raw_shape = ir_buffer.get_size()
            concrete_shape = V.graph.sizevars.size_hints(
                raw_shape, fallback=config.unbacked_symint_fallback
            )

            fake_tensor = torch.empty(
                concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
            )
            return user_function(fake_tensor)

        return internal_input_gen_fn

    return {
        i: create_internal_input_gen_fn(
            user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
        )
        for i, user_gen_fn in index_based_fns.items()
        if i < len(inputs)
    }


def _create_fallback_choice(
    name: str,
    default_impl: Callable[..., Any],
    fake_output: torch.Tensor,
    kwargs: dict[str, Any],
) -> ExternKernelChoice:
    """Create fallback choice for default implementation."""

    def fallback_wrapper(*args: Any) -> Any:
        return default_impl(*args, **kwargs)

    return ExternKernelChoice(
        kernel=fallback_wrapper,
        name=f"{name}_fallback_default",
        has_out_variant=False,
        op_overload=default_impl,
        use_fallback_kernel=True,
    )


def autotune_custom_op(
    name: str,
    decompositions: list[Callable[..., Any]],
    inputs: list[Any],
    non_tensor_args: list[dict[str, Any]],
    default_impl: Optional[Callable[..., Any]] = None,
    user_input_gen_fns: Optional[
        dict[str, Callable[[torch.Tensor], torch.Tensor]]
    ] = None,
) -> Union[TensorBox, Any]:
    """Autotune custom operations by comparing multiple decomposition implementations.

    Currently supports SINGLE OUTPUT custom ops only.
    TODO: Add support for multiple output custom ops (tuple/list returns).

    This function generates multiple implementation choices for a custom operation and
    uses Inductor's autotuning system to select the best performing variant at runtime.

    Args:
        name: Unique identifier for the autotuning operation
        decompositions: List of alternative implementation functions to benchmark
        inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects)
        non_tensor_args: List of kwargs dicts, paired with corresponding decompositions arg
        default_impl: Original custom op implementation used as fallback
        user_input_gen_fns: Optional custom input generators for benchmarking.
            Maps input indices to functions that take fake tensors
            and return real tensors for performance measurement.

    Returns:
        IR node representing the optimized operation result

    Raises:
        TypeError: If decompositions is not a list/tuple
        RuntimeError: If no inputs or no valid choices generated
    """
    if not isinstance(decompositions, (list, tuple)):
        raise TypeError(
            f"decompositions must be a list or tuple of callables, got {type(decompositions)}"
        )

    if not inputs:
        raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning")

    if len(decompositions) != len(non_tensor_args):
        raise ValueError(
            f"decompositions and non_tensor_args must have same length, "
            f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs"
        )

    template = SubgraphTemplate(name=name)
    choices = template.generate_custom_op_choices(
        name=name,
        decompositions=decompositions,
        input_nodes=list(inputs),
        non_tensor_args=non_tensor_args,
    )

    # Add default implementation as fallback
    if default_impl and hasattr(default_impl, "_op"):
        fallback_name = f"{name}_fallback_default"
        from torch._inductor.select_algorithm import extern_kernels

        # Skip if extern_kernel already registered to avoid duplicate registration error
        if not hasattr(extern_kernels, fallback_name):
            with V.fake_mode:
                fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
                fallback_kwargs = non_tensor_args[0] if non_tensor_args else {}
                fake_output = default_impl(*fake_inputs, **fallback_kwargs)

            fallback_choice = _create_fallback_choice(
                name, default_impl, fake_output, fallback_kwargs
            )
            fallback_choice.maybe_append_choice(
                choices=choices,
                input_nodes=list(inputs),
                layout=FixedLayout(
                    device=fake_output.device,
                    dtype=fake_output.dtype,
                    size=fake_output.shape,
                    stride=fake_output.stride(),
                ),
            )

    if not choices:
        raise RuntimeError(f"No valid choices generated for {name}")

    # Convert user input generation functions to internal format
    input_gen_fns = {}
    if user_input_gen_fns:
        import inspect

        arg_names = (
            list(inspect.signature(decompositions[0]).parameters.keys())
            if decompositions
            else []
        )
        input_gen_fns = _create_user_input_gen_fns(
            inputs, arg_names, user_input_gen_fns
        )

    return autotune_select_algorithm(
        name=name,
        choices=choices,
        input_nodes=list(inputs),
        layout=choices[0].layout,
        input_gen_fns=input_gen_fns,
    )


def register_custom_op_autotuning(
    custom_op: torch._ops.OpOverload,
    configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
    name: Optional[str] = None,
    input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
) -> None:
    """Register custom op for autotuning with custom_op configs where each config
    specifies a decomposition implementation function with its parameter values.

    Args:
        custom_op: Custom operation to register
        configs: List of CustomOpConfig objects or callable functions
        name: Operation name (default: "{op_name}_autotuned")
        input_gen_fns: Custom input generators for benchmarking

    Examples:
        register_custom_op_autotuning(
            torch.ops.mylib.attention.default,
            configs=[
                CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
                CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
                CustomOpConfig(fallback_impl),  # No params
            ],
            input_gen_fns={
                "query": lambda fake: torch.randn_like(fake, device='cuda'),
                "key": lambda fake: torch.randn_like(fake, device='cuda'),
                "value": lambda fake: torch.randn_like(fake, device='cuda'),
            }
        )
    """
    if not isinstance(configs, (list, tuple)):
        raise TypeError(f"configs must be a list or tuple, got {type(configs)}")

    processed_configs = []
    for config in configs:
        if isinstance(config, CustomOpConfig):
            processed_configs.append(config)
        else:
            raise TypeError(
                f"Each config must be a CustomOpConfig object, got {type(config)}"
            )

    if not processed_configs:
        raise ValueError("At least one config must be provided")

    if name is None:
        name = f"{custom_op._name}_autotuned"

    @functools.wraps(custom_op)
    def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
        """Inductor lowering function that replaces custom op calls with autotuned versions."""
        # Extract tensor inputs and non-tensor parameters (runtime kwargs)
        tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)

        # Prepare decompositions and kwargs by merging customop config params with runtime kwargs
        decompositions = []
        non_tensor_args = []

        for config in processed_configs:
            decomp = config.get_decomposition(default_impl=custom_op)
            decompositions.append(decomp)

            # Merge config params with runtime kwargs (runtime takes precedence)
            merged_kwargs = _merge_config_and_runtime_kwargs(
                config.params, runtime_kwargs
            )
            non_tensor_args.append(merged_kwargs)

        result = autotune_custom_op(
            name=name,
            decompositions=decompositions,
            inputs=tensor_inputs,
            non_tensor_args=non_tensor_args,
            default_impl=custom_op,
            user_input_gen_fns=input_gen_fns,
        )

        validate_ir(result)
        return result

    lowerings[custom_op] = autotuning_lowering
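`_create_user_input_gen_fns` above converts the name-keyed `input_gen_fns` that users supply into the index-keyed generators Inductor's autotuner consumes; the mapping step reduces to the sketch below, where the `rmsnorm_decomp` signature is purely illustrative.

import inspect


def rmsnorm_decomp(x, weight, eps: float = 1e-8):
    ...


arg_names = list(inspect.signature(rmsnorm_decomp).parameters.keys())
name_to_index = {name: i for i, name in enumerate(arg_names)}
print(name_to_index)  # {'x': 0, 'weight': 1, 'eps': 2}
# A generator registered under "weight" is attached to positional input 1
# before being passed on to autotune_select_algorithm.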
torch/_inductor/select_algorithm.py
@@ -2919,7 +2919,6 @@ class AlgorithmSelectorCache(PersistentCache):
        )

        timings = do_autotuning(choices, precompile_fn)

        # if timings is empty, we really have no choice but to return a semi-random
        # choice. returning the first `ExternKernelCaller` is probably the safest bet
        # in this case, since it will generally be the ATen kernel. if there are no
@@ -3524,6 +3523,7 @@ class AlgorithmSelectorCache(PersistentCache):
        dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes])
        if config.autotune_num_choices_displayed == 0:
            return

        # when autotune_num_choices_displayed is None, [:None] means all
        n = config.autotune_num_choices_displayed
        top_k = sorted(timings, key=timings.__getitem__)[:n]