Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-23 14:59:34 +08:00

Compare commits: 23 commits (cpp-docs-d ... ciflow/ind)

Commit SHAs:
c992543eca
1868744409
06890019d2
8daee127db
58590fb37f
69688d49a9
584bd31a10
1c77a09da5
60cd4b5730
fd6938766a
89283b4fb9
13bedfdfd3
ea6d1ff025
579ff95850
6efa559a0e
aae722c5a8
ffc17077c9
75bf74d926
ce751dcb45
46759ac0d2
7206224dc8
807e35f76c
9303113015
572
test/inductor/test_custom_op_autotune.py
Normal file
@ -0,0 +1,572 @@
# Owner(s): ["module: inductor"]
"""
Tests for custom operation autotuning with PyTorch Inductor.

Users can register custom ops with multiple decomposition implementations and let
Inductor automatically select the best performing variant. Key features tested:

- Name-based input generators (use argument names instead of indices)
- Dynamic shape handling across multiple compilations
- Parametric tuning with tuning_knob for combinatorial parameter exploration
- Numerical correctness and performance validation
"""

import torch
from torch._inductor import config
from torch._inductor.kernel.custom_op import (
    CustomOpConfig,
    register_custom_op_autotuning,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU


torch.set_float32_matmul_precision("high")

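# Minimal sketch of the registration flow exercised by these tests (illustrative
# only; `my_impl_a`, `my_impl_b`, and `mylib::my_op` are placeholder names, not
# part of this suite):
#
#     @torch.library.custom_op("mylib::my_op", mutates_args=())
#     def my_op(x: torch.Tensor) -> torch.Tensor:
#         return my_impl_a(x)
#
#     register_custom_op_autotuning(
#         torch.ops.mylib.my_op.default,
#         configs=[CustomOpConfig(my_impl_a), CustomOpConfig(my_impl_b, chunk_size=64)],
#         input_gen_fns={"x": lambda fake: torch.randn_like(fake)},
#     )
#
# During torch.compile with max_autotune, Inductor benchmarks each config and
# lowers the op to the fastest variant.
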
class TestCustomOpAutoTune(TestCase):
    """Test custom operation autotuning functionality."""

    def setUp(self) -> None:
        """Set up test environment with appropriate device and dtype."""
        super().setUp()
        self.device = "cuda" if HAS_GPU else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

    def _create_test_configs(self):
        """Create common test configurations for different sizes."""
        return [
            {"batch_size": 1, "seq_len": 32, "hidden_dim": 128},
            {"batch_size": 2, "seq_len": 64, "hidden_dim": 256},
        ]

    def _run_autotune_test(self, op_object, inputs, expected, test_name):
        """Shared test infrastructure for autotuning tests."""

        @torch.compile
        def test_model(*args):
            return op_object(*args)

        torch._dynamo.reset()
        autotune_backends = "TRITON" if self.device == "cuda" else "ATEN"

        with config.patch(
            max_autotune=True,
            max_autotune_gemm_backends=autotune_backends,
            fx_graph_cache=False,
            benchmark_kernel=True,
        ):
            compiled_result = test_model(*inputs)

        self.assertEqual(
            compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
        )
        torch.testing.assert_close(
            compiled_result,
            expected,
            rtol=2e-1,
            atol=5e-1,
            msg=f"{test_name} numerical mismatch",
        )

    def _assert_implementations_equivalent(self, decompositions, inputs, op_name):
        """Utility to assert that all implementations produce equivalent results."""
        implementations = [(func.__name__, func) for func in decompositions]
        results = {}
        for name, impl in implementations:
            result = impl(*inputs)
            results[name] = result

            # Basic sanity checks
            self.assertTrue(
                torch.isfinite(result).all(),
                f"{op_name} {name} produced non-finite values",
            )

        # Verify numerical equivalence
        reference_name, reference_result = next(iter(results.items()))
        for name, result in results.items():
            if name != reference_name:
                rtol = 1e-1 if "Approximated" in name else 1e-2
                atol = 1e-1 if "Approximated" in name else 1e-2
                torch.testing.assert_close(
                    result,
                    reference_result,
                    rtol=rtol,
                    atol=atol,
                    msg=f"{op_name} {name} differs from {reference_name}",
                )

    def _create_rmsnorm_inputs(self, batch_size=8, seq_len=1024, hidden_dim=512):
        """Create test inputs for RMSNorm operations."""
        input_tensor = torch.randn(
            batch_size,
            seq_len,
            hidden_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        weight = torch.randn(
            hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
        )
        return input_tensor, weight

    def _create_mlp_inputs(
        self,
        batch_size=2,
        seq_len=32,
        hidden_dim=512,
        intermediate_dim=1024,
        output_dim=256,
    ):
        """Create test inputs for MLP operations."""
        input_tensor = torch.randn(
            batch_size,
            seq_len,
            hidden_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        gate_weight = torch.randn(
            hidden_dim,
            intermediate_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        up_weight = torch.randn(
            hidden_dim,
            intermediate_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        down_weight = torch.randn(
            intermediate_dim,
            output_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        return input_tensor, gate_weight, up_weight, down_weight

    @skipIfXpu
    def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
        """Test RMSNorm autotuning decomposition variants compared to the fallback default with dynamic shapes."""
        test_op_name = f"test_lib::rmsnorm_{id(self)}"

        def rmsnorm_decomposition1(
            x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            """Variance-based approach: compute variance then rsqrt."""
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            rstd = torch.rsqrt(variance + eps)
            return x * rstd * weight

        def rmsnorm_decomposition2(
            x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            """vLLM-style RMSNorm implementation - variance computation first approach."""
            x_var = x  # In vLLM, this could be sliced for variance_size_override

            variance = x_var.pow(2).mean(dim=-1, keepdim=True)

            x = x * torch.rsqrt(variance + eps)

            if weight is not None:
                x = x * weight
            return x

        def rmsnorm_decomposition3(
            x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            """vLLM-style RMSNorm with extended variance computation pattern."""
            x_squared = x.pow(2)
            variance = x_squared.mean(dim=-1, keepdim=True)

            rstd = torch.rsqrt(variance + eps)
            normalized = x * rstd

            # Apply weight scaling
            if weight is not None:
                normalized = normalized * weight
            return normalized

        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_rmsnorm_op(
            input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            return torch.nn.functional.rms_norm(
                input_tensor, input_tensor.shape[-1:], weight, eps=eps
            )

        @test_rmsnorm_op.register_fake
        def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
            return torch.empty_like(input_tensor)

        lib_name, op_name = test_op_name.split("::")
        op_object = getattr(getattr(torch.ops, lib_name), op_name)

        decompositions = [
            rmsnorm_decomposition1,
            rmsnorm_decomposition2,
            rmsnorm_decomposition3,
        ]

        register_custom_op_autotuning(
            op_object.default,
            configs=[CustomOpConfig(decomp) for decomp in decompositions],
            name="test_rmsnorm_autotuned",
            input_gen_fns={
                "x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
                "weight": lambda weight: torch.ones_like(weight, device=self.device),
            },
        )

        # Test multiple shapes to verify dynamic shape handling
        test_shapes = [(2, 16, 128), (8, 32, 256)]

        for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
            input_tensor = torch.randn(
                batch_size,
                seq_len,
                hidden_dim,
                device=self.device,
                dtype=self.dtype,
                requires_grad=False,
            )
            weight = torch.randn(
                hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
            )

            # Test numerical equivalence for all decompositions
            self._assert_implementations_equivalent(
                decompositions, (input_tensor, weight), f"RMSNorm_{i}"
            )

            # Test autotuning
            expected = rmsnorm_decomposition1(input_tensor, weight)
            self._run_autotune_test(
                op_object, (input_tensor, weight), expected, f"RMSNorm_{i}"
            )

    @skipIfXpu
    def test_mlp_custom_op_autotune(self):
        """Test MLP autotuning with a method parameter controlling different decomposition variants."""
        test_op_name = f"test_lib::mlp_{id(self)}"

        def mlp_variants(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int = 0,
        ) -> torch.Tensor:
            """MLP implementation with different computational approaches controlled by method parameter."""

            if method == 0:
                # Separate matmuls: standard implementation with torch.matmul
                gate_proj = torch.matmul(input_tensor, gate_weight)
                up_proj = torch.matmul(input_tensor, up_weight)
                gated = torch.relu(gate_proj) * up_proj
                return torch.matmul(gated, down_weight)

            elif method == 1:
                # Batched approach: uses torch.mm with reshaped tensors
                batch_shape = input_tensor.shape[:-1]
                hidden_dim = input_tensor.shape[-1]
                output_dim = down_weight.shape[-1]

                input_2d = input_tensor.view(-1, hidden_dim)

                gate_proj = torch.mm(input_2d, gate_weight)
                up_proj = torch.mm(input_2d, up_weight)

                gated = torch.relu(gate_proj) * up_proj
                output_2d = torch.mm(gated, down_weight)

                return output_2d.view(*batch_shape, output_dim)

            elif method == 2:
                # Fused weights approach: concatenate then split weights
                # Concatenate gate and up weights for one matrix multiply
                fused_weight = torch.cat([gate_weight, up_weight], dim=1)
                fused_proj = torch.matmul(input_tensor, fused_weight)

                intermediate_dim = gate_weight.shape[1]
                gate_proj, up_proj = fused_proj.split(
                    [intermediate_dim, intermediate_dim], dim=-1
                )

                gated = torch.relu(gate_proj) * up_proj

                return torch.matmul(gated, down_weight)

        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_mlp_op(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
        ) -> torch.Tensor:
            return mlp_variants(
                input_tensor, gate_weight, up_weight, down_weight, method=0
            )

        @test_mlp_op.register_fake
        def _(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int = 0,
        ):
            return torch.empty(
                input_tensor.shape[:-1] + (down_weight.shape[-1],),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            )

        lib_name, op_name = test_op_name.split("::")
        op_object = getattr(getattr(torch.ops, lib_name), op_name)

        # Use explicit configs with method parameter as tuning knob
        register_custom_op_autotuning(
            op_object.default,
            configs=[
                CustomOpConfig(mlp_variants, method=1),  # Batched approach
                CustomOpConfig(mlp_variants, method=2),  # Fused weights
            ],
            name="test_mlp_autotuned",
            input_gen_fns={
                "input_tensor": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,
                "gate_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
                "up_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
                "down_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
            },
        )

        # Create test inputs using the original helper method
        input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()

        # Test that all method variants produce numerically equivalent results
        expected = mlp_variants(
            input_tensor, gate_weight, up_weight, down_weight, method=0
        )

        for method in [1, 2]:
            result = mlp_variants(
                input_tensor, gate_weight, up_weight, down_weight, method=method
            )
            torch.testing.assert_close(
                result,
                expected,
                rtol=1e-5,
                atol=1e-5,
                msg=f"Method {method} not equivalent to method 0",
            )

        # Test autotuning - all should be mathematically equivalent
        self._run_autotune_test(
            op_object,
            (input_tensor, gate_weight, up_weight, down_weight),
            expected,
            "MLP",
        )

    def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
        """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
        # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
        k = ((k + 255) // 256) * 256  # Round up to nearest multiple of 256
        a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False)
        b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False)
        return a, b

    @skipIfXpu
    def test_decompose_k_custom_op_autotune(self):
        """Test decompose_k autotuning with parameter tuning for k_splits values."""
        test_op_name = f"test_lib::decompose_k_{id(self)}"

        def decompose_k_implementation(
            a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
        ) -> torch.Tensor:
            """Matrix multiply with k-way decomposition - parameter-tuned implementation."""
            m = a.shape[0]
            n = b.shape[1]
            k = a.shape[1]

            k_parts = k // k_splits
            B = k_splits

            a_reshaped = torch.permute(
                a.reshape(m, B, k_parts), (1, 0, 2)
            )  # [B, m, k_parts]
            b_reshaped = b.reshape(B, k_parts, n)  # [B, k_parts, n]

            result = torch.bmm(a_reshaped, b_reshaped)  # [B, m, n]

            return torch.sum(result, dim=0)  # [m, n]

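        # Why the k-way split reproduces the full matmul (added note, not part of
        # the original test): with k_parts = k // k_splits,
        #   (A @ B) = sum_j A[:, j*k_parts:(j+1)*k_parts] @ B[j*k_parts:(j+1)*k_parts, :]
        # a_reshaped[j] and b_reshaped[j] are exactly those slices, so the bmm
        # computes the k_splits partial products and the sum over dim 0 recovers
        # the full [m, n] result.
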
        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_decompose_k_op(
            a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
        ) -> torch.Tensor:
            return decompose_k_implementation(a, b, k_splits)

        @test_decompose_k_op.register_fake
        def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
            return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)

        lib_name, op_name = test_op_name.split("::")
        op_object = getattr(getattr(torch.ops, lib_name), op_name)

        # Use parameter tuning to test different k_splits values
        register_custom_op_autotuning(
            op_object.default,
            configs=[
                CustomOpConfig(decompose_k_implementation, k_splits=2),
                CustomOpConfig(decompose_k_implementation, k_splits=32),
                CustomOpConfig(decompose_k_implementation, k_splits=64),
                CustomOpConfig(decompose_k_implementation, k_splits=128),
                CustomOpConfig(decompose_k_implementation, k_splits=256),
            ],
            name="test_decompose_k_autotuned",
            input_gen_fns={
                "a": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,  # Matrix A
                "b": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,  # Matrix B
            },
        )

        a, b = self._create_decompose_k_inputs()
        expected = a @ b
        self._run_autotune_test(op_object, (a, b), expected, "DecomposeK")

    @skipIfXpu
    def test_multi_parameter_tuning(self):
        """Test autotuning with multiple parameters using scale_mode and chunk_size."""
        op_name = f"test_lib::multi_param_{id(self)}"

        def multi_param_scaling(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ) -> torch.Tensor:
            """Different scaling approaches controlled by scale_mode parameter."""
            if scale_mode == 1:
                # Simple broadcasting
                return x * factor
            elif scale_mode == 2:
                # Process in chunks
                batch_size, seq_len = x.shape[:2]
                chunks = []
                for start in range(0, seq_len, chunk_size):
                    end = min(start + chunk_size, seq_len)
                    chunk = x[:, start:end]
                    chunks.append(chunk * factor)
                return torch.cat(chunks, dim=1)
            elif scale_mode == 3:
                # Using einsum for scaling
                return torch.einsum("...i,i->...i", x, factor)

        @torch.library.custom_op(op_name, mutates_args=())
        def multi_param_op(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ) -> torch.Tensor:
            return multi_param_scaling(x, factor, scale_mode, chunk_size)

        @multi_param_op.register_fake
        def _(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ):
            return torch.empty_like(x)

        lib_name, op_suffix = op_name.split("::")
        op_object = getattr(getattr(torch.ops, lib_name), op_suffix)

        # Use explicit configs with scale_mode and chunk_size parameters as tuning knobs
        register_custom_op_autotuning(
            op_object.default,
            configs=[
                CustomOpConfig(multi_param_scaling, scale_mode=1),  # Broadcast
                CustomOpConfig(
                    multi_param_scaling, scale_mode=2, chunk_size=16
                ),  # Chunked 16
                CustomOpConfig(
                    multi_param_scaling, scale_mode=2, chunk_size=32
                ),  # Chunked 32
                CustomOpConfig(multi_param_scaling, scale_mode=3),  # Einsum
            ],
            name="multi_param_autotuned",
            input_gen_fns={
                "x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
                "factor": lambda t: torch.ones(
                    t.shape[-1], device=self.device, dtype=t.dtype
                ),
            },
        )

        # Create test inputs
        test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype)
        test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0

        # Verify numerical equivalence across all approaches
        expected_result = test_x * test_factor

        # Test each scale_mode variant
        configs = [
            (1, 16),  # broadcast, chunk_size ignored
            (2, 16),  # chunked with size 16
            (2, 32),  # chunked with size 32
            (3, 16),  # einsum, chunk_size ignored
        ]

        for i, (scale_mode, chunk_size) in enumerate(configs):
            result = multi_param_scaling(
                test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size
            )
            torch.testing.assert_close(
                result,
                expected_result,
                rtol=1e-5,
                atol=1e-5,
                msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected",
            )

        # Test autotuning
        self._run_autotune_test(
            op_object, (test_x, test_factor), expected_result, "MultiParam"
        )


if __name__ == "__main__":
    run_tests()

torch/_inductor/codegen/subgraph.py
@ -1,6 +1,6 @@
import itertools
import logging
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union

import torch
import torch._inductor.config as config
@ -8,6 +8,7 @@ from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
from torch._inductor.ir import (
    Buffer,
    FixedLayout,
    get_free_symbols,
    get_symbolic_inputs,
    gm_original_output_strides,
@ -120,7 +121,12 @@ class SubgraphChoiceCaller(ir.ChoiceCaller):
        bm_func([*sym_inputs, *args])
        if config.profile_bandwidth_with_do_bench_using_profiling:
            return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
        return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))

        # Use appropriate benchmarker based on layout device type
        if self.layout.device.type == "cpu":
            return benchmarker.benchmark_cpu(lambda: bm_func([*sym_inputs, *args]))
        else:
            return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))

    def hash_key(self) -> str:
        return "-".join(
@ -207,3 +213,152 @@ class SubgraphTemplate(KernelTemplate):
            description=description,
            make_fx_graph=make_fx_graph,
        )

    def generate_custom_op_choices(
        self,
        name: str,
        decompositions: list[Callable[..., Any]],
        input_nodes: list[Buffer],
        kwargs: Optional[dict[str, Any]] = None,
        default_impl: Optional[Callable[..., Any]] = None,
    ) -> list[SubgraphChoiceCaller]:
        """
        Generate multiple SubgraphChoiceCaller instances for custom op autotuning.

        This method extends SubgraphTemplate to support custom op decompositions,
        allowing multiple implementations to compete in autotuning.

        Args:
            name: Base name for the choices
            decompositions: List of decomposition functions to compare
            input_nodes: Input nodes for the operation
            kwargs: Additional arguments for decomposition functions
            default_impl: Default implementation for layout inference

        Returns:
            List of SubgraphChoiceCaller instances for autotuning
        """
        if not decompositions:
            return []

        kwargs = kwargs or {}

        # Infer layouts and ensure stride consistency for fair autotuning comparison
        layouts = [
            self._infer_custom_op_layout(input_nodes, [decomp], kwargs, default_impl)
            for decomp in decompositions
        ]

        self._validate_stride_consistency(name, decompositions, layouts)

        # Assert single output layout - assumes custom ops have one output tensor
        assert len(layouts) > 0, f"No layouts inferred for custom op '{name}'"
        assert all(
            layout.device == layouts[0].device
            and layout.dtype == layouts[0].dtype
            and layout.size == layouts[0].size
            for layout in layouts
        ), f"All decompositions for '{name}' must produce equivalent output layouts"

        layout = layouts[0]  # All layouts have equivalent stride/shape/dtype now

        choices = []
        for decomp in decompositions:
            # Create make_fx_graph function for this decomposition
            def make_fx_graph(*args: Any, decomp: Callable[..., Any] = decomp) -> Any:
                import functools

                from torch.fx.experimental.proxy_tensor import make_fx

                # Ensure kwargs is not None for unpacking
                decomp_kwargs = kwargs if kwargs is not None else {}
                return make_fx(functools.partial(decomp, **decomp_kwargs))(*args)

            choice = self.generate(
                name=f"{name}_{decomp.__name__}",
                input_nodes=input_nodes,
                layout=layout,
                make_fx_graph=make_fx_graph,
                description=f"CustomOp {decomp.__name__}",
            )
            choices.append(choice)

        return choices

    def _validate_stride_consistency(
        self,
        op_name: str,
        decompositions: list[Callable[..., Any]],
        layouts: list[Layout],
    ) -> None:
        """Ensure all decompositions produce compatible strides for fair autotuning."""
        if not layouts:
            return

        strides = [layout.stride for layout in layouts]
        reference = strides[0]
        for i, stride in enumerate(strides[1:]):
            if stride != reference:
                raise AssertionError(
                    f"Stride mismatch in custom op '{op_name}' autotuning: "
                    f"'{decompositions[i + 1].__name__}' produces stride {stride}, "
                    f"but '{decompositions[0].__name__}' produces {reference}. "
                    f"All decompositions must have identical output strides."
                )

    def _infer_custom_op_layout(
        self,
        input_nodes: list[Buffer],
        decompositions: list[Callable[..., Any]],
        kwargs: dict[str, Any],
        default_impl: Optional[Callable[..., Any]] = None,
    ) -> Layout:
        """Infer output layout for custom ops using the default implementation when available.

        Note that the Subgraph assumes custom ops return exactly one tensor so far.
        TODO: Add support for multiple output custom ops.
        """
        import functools

        from torch._inductor.virtualized import V

        # Assert kwargs contain only non-tensor arguments for functools.partial
        for key, value in kwargs.items():
            assert not isinstance(value, (torch.Tensor, Buffer)), (
                f"kwargs['{key}'] contains tensor {type(value)}. "
                f"Tensor arguments should be in input_nodes, not kwargs. "
                f"Only scalar/non-tensor parameters should be in kwargs."
            )

        # Use default_impl if available, otherwise use first decomposition
        impl = default_impl if default_impl is not None else decompositions[0]

        with V.fake_mode:
            example_inputs = []
            for inp in input_nodes:
                raw_shape = inp.get_size()
                concrete_shape = V.graph.sizevars.size_hints(
                    raw_shape, fallback=config.unbacked_symint_fallback
                )
                fake_tensor = torch.empty(
                    concrete_shape, dtype=inp.get_dtype(), device=inp.get_device()
                )
                example_inputs.append(fake_tensor)

            fn = functools.partial(
                impl, **kwargs
            )  # kwargs must be non-tensor for partial
            output = fn(*example_inputs)

        # Assert single output
        assert isinstance(output, torch.Tensor), (
            f"Expected single tensor output, got {type(output)}. "
            f"Multi-output custom ops not yet supported in autotuning."
        )

        return FixedLayout(
            device=output.device,
            dtype=output.dtype,
            size=output.shape,
            stride=output.stride(),
        )

397
torch/_inductor/kernel/custom_op.py
Normal file
@ -0,0 +1,397 @@
# Owner(s): ["module: inductor"]

import functools
from typing import Any, Callable, Optional, Union

import torch
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox
from torch._inductor.lowering import lowerings, validate_ir
from torch._inductor.select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
)
from torch._inductor.virtualized import V

class CustomOpConfig:
    """Config for custom op autotuning - similar to triton.Config.

    Specifies decomposition function with parameter values.
    Each config creates exactly one variant (no Cartesian product).

    Args:
        decomposition: Function to autotune
        **params: Parameters passed to the function

    Examples:
        CustomOpConfig(attention_impl, head_dim=32, method='chunked')
        CustomOpConfig(fallback_impl)
    """

    def __init__(self, decomposition: Callable[..., Any], **params: Any):
        if not callable(decomposition):
            raise TypeError(
                f"decomposition must be callable, got {type(decomposition)}"
            )

        self.decomposition = decomposition
        self.params = params

        # Generate descriptive name
        if self.params:
            param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(self.params.items()))
            self.name = f"{decomposition.__name__}_{param_suffix}"
        else:
            self.name = decomposition.__name__

    def create_variant(self) -> Callable[..., Any]:
        """Create callable with parameters pre-applied using functools.partial."""
        if self.params:
            variant = functools.partial(self.decomposition, **self.params)
            variant.__name__ = self.name  # type: ignore[attr-defined]
            return variant

        return self.decomposition

    def __repr__(self) -> str:
        if self.params:
            params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
            return f"CustomOpConfig({self.decomposition.__name__}, {params_str})"
        return f"CustomOpConfig({self.decomposition.__name__})"


__all__ = [
    "autotune_custom_op",
    "register_custom_op_autotuning",
    "CustomOpConfig",
]

def _extract_tensor_inputs(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
    """Extract tensor inputs from mixed args/kwargs.

    Separates tensors (for autotuning input_nodes) from non-tensor parameters.
    Non-tensor kwargs are later functools.partial'd into decomposition functions.

    Args:
        args: Positional arguments (mix of tensors and scalars)
        kwargs: Keyword arguments (mix of tensors and scalars)

    Returns:
        Tuple of (tensor_inputs_list, non_tensor_kwargs)
    """
    tensor_inputs = []
    non_tensor_kwargs = {}

    # Process args and kwargs: separate tensor inputs and non-tensor args
    for i, arg in enumerate(args):
        if isinstance(arg, (TensorBox, Buffer)):
            tensor_inputs.append(arg)
        else:
            # Add non-tensor positional args to kwargs with generated names
            non_tensor_kwargs[f"arg_{i}"] = arg

    for key, value in kwargs.items():
        if isinstance(value, (TensorBox, Buffer)):
            tensor_inputs.append(value)
        else:
            non_tensor_kwargs[key] = value

    return tensor_inputs, non_tensor_kwargs

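# For example (illustrative note, not part of the original module): a lowering
# call with args=(tensorbox_a, tensorbox_b, 4) and kwargs={"eps": 1e-6} yields
# ([tensorbox_a, tensorbox_b], {"arg_2": 4, "eps": 1e-6}); the tensors become the
# autotuning input_nodes while the scalars are later bound onto each
# decomposition via functools.partial.
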

def _create_user_input_gen_fns(
    inputs: list[Any],
    arg_names: list[str],
    user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]],
) -> dict[int, Callable[[Any], torch.Tensor]]:
    """Convert user input generators from name-based to index-based format.

    Inductor autotune's input_gen_fns expects index of arg_names as key.

    Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes.
    """
    from torch._inductor import config

    name_to_index = {name: i for i, name in enumerate(arg_names)}
    index_based_fns = {}

    for name, gen_fn in user_input_gen_fns.items():
        if name in name_to_index:
            index_based_fns[name_to_index[name]] = gen_fn
        else:
            print(f"Warning: Unknown argument name '{name}' in input_gen_fns")

    def create_internal_input_gen_fn(
        user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str
    ) -> Callable[[Any], torch.Tensor]:
        """Create internal input generator that converts IR buffer to user's fake tensor."""

        def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor:
            raw_shape = ir_buffer.get_size()
            concrete_shape = V.graph.sizevars.size_hints(
                raw_shape, fallback=config.unbacked_symint_fallback
            )

            fake_tensor = torch.empty(
                concrete_shape, dtype=ir_buffer.get_dtype(), device="meta"
            )
            return user_function(fake_tensor)

        return internal_input_gen_fn

    return {
        i: create_internal_input_gen_fn(
            user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}"
        )
        for i, user_gen_fn in index_based_fns.items()
        if i < len(inputs)
    }

def _create_fallback_choice(
    name: str,
    default_impl: Callable[..., Any],
    fake_output: torch.Tensor,
    kwargs: dict[str, Any],
) -> ExternKernelChoice:
    """Create fallback choice for default implementation."""

    def fallback_wrapper(*args: Any) -> Any:
        return default_impl(*args, **kwargs)

    return ExternKernelChoice(
        kernel=fallback_wrapper,
        name=f"{name}_fallback_default",
        has_out_variant=False,
        op_overload=default_impl,
        use_fallback_kernel=True,
    )

def _create_parameter_variants(
    decompositions: list[Callable[..., Any]],
    tuning_knob: dict[str, list[Any]],
) -> list[Any]:  # Returns partial objects which are callable
    """Create parameter variants for decompositions using tuning knob.

    Args:
        decompositions: Base implementation functions
        tuning_knob: Parameter tuning dict with parameter names and value lists

    Returns:
        List of variant functions with all parameter combinations
    """
    # Validate parameter values
    for param_name, param_values in tuning_knob.items():
        if not param_values or not isinstance(param_values, (list, tuple)):
            raise TypeError(
                f"Parameter values for '{param_name}' must be a list or tuple, got {type(param_values)}"
            )

    # Generate all combinations of parameter values using Cartesian product
    import itertools

    param_names = list(tuning_knob.keys())
    param_values_lists = list(tuning_knob.values())
    param_combinations = list(itertools.product(*param_values_lists))

    # Create variants for each decomposition with each parameter combination
    variants = []
    for decomp_fn in decompositions:
        for param_combo in param_combinations:
            # Create kwargs dict for this combination
            param_kwargs = dict(zip(param_names, param_combo))

            # Create partial function with all parameters
            variant = functools.partial(decomp_fn, **param_kwargs)
            param_suffix = "_".join(
                f"{name}_{value}" for name, value in param_kwargs.items()
            )
            variant.__name__ = f"{decomp_fn.__name__}_{param_suffix}"  # type: ignore[attr-defined]
            variants.append(variant)

    return variants

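# For example (illustrative note, not part of the original module): with a single
# decomposition `my_matmul` and tuning_knob={"k_splits": [2, 4], "block": [16, 32]},
# this helper returns four functools.partial variants named
# my_matmul_k_splits_2_block_16, my_matmul_k_splits_2_block_32,
# my_matmul_k_splits_4_block_16, and my_matmul_k_splits_4_block_32, one per
# Cartesian combination of the knob values.
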

def autotune_custom_op(
    name: str,
    decompositions: list[Callable[..., Any]],
    inputs: list[Any],
    kwargs: Optional[dict[str, Any]] = None,
    default_impl: Optional[Callable[..., Any]] = None,
    user_input_gen_fns: Optional[
        dict[str, Callable[[torch.Tensor], torch.Tensor]]
    ] = None,
) -> Union[TensorBox, Any]:
    """Autotune custom operations by comparing multiple decomposition implementations.

    Currently supports SINGLE OUTPUT custom ops only.
    TODO: Add support for multiple output custom ops (tuple/list returns).

    This function generates multiple implementation choices for a custom operation and
    uses Inductor's autotuning system to select the best performing variant at runtime.

    Args:
        name: Unique identifier for the autotuning operation
        decompositions: List of alternative implementation functions to benchmark
        inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects)
        kwargs: Non-tensor parameters to pass to decomposition functions
        default_impl: Original custom op implementation used as fallback
        user_input_gen_fns: Optional custom input generators for benchmarking.
            Maps argument names to functions that take fake tensors
            and return real tensors for performance measurement.

    Returns:
        IR node representing the optimized operation result

    Raises:
        TypeError: If decompositions is not a list/tuple
        RuntimeError: If no inputs or no valid choices generated
    """
    if kwargs is None:
        kwargs = {}

    if not isinstance(decompositions, (list, tuple)):
        raise TypeError(
            f"decompositions must be a list or tuple of callables, got {type(decompositions)}"
        )

    if not inputs:
        raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning")

    template = SubgraphTemplate(name=name)
    choices = template.generate_custom_op_choices(
        name=name,
        decompositions=list(decompositions),
        input_nodes=list(inputs),
        kwargs=kwargs,
    )

    # Add default implementation as fallback
    if default_impl and hasattr(default_impl, "_op"):
        fallback_name = f"{name}_fallback_default"
        from torch._inductor.select_algorithm import extern_kernels

        # Skip if extern_kernel already registered to avoid duplicate registration error
        if not hasattr(extern_kernels, fallback_name):
            with V.fake_mode:
                fake_inputs = [ir_node_to_tensor(inp) for inp in inputs]
                fake_output = default_impl(*fake_inputs, **kwargs)

            fallback_choice = _create_fallback_choice(
                name, default_impl, fake_output, kwargs
            )
            fallback_choice.maybe_append_choice(
                choices=choices,
                input_nodes=list(inputs),
                layout=FixedLayout(
                    device=fake_output.device,
                    dtype=fake_output.dtype,
                    size=fake_output.shape,
                    stride=fake_output.stride(),
                ),
            )

    if not choices:
        raise RuntimeError(f"No valid choices generated for {name}")

    # Convert user input generation functions to internal format
    input_gen_fns = {}
    if user_input_gen_fns:
        import inspect

        arg_names = (
            list(inspect.signature(decompositions[0]).parameters.keys())
            if decompositions
            else []
        )
        input_gen_fns = _create_user_input_gen_fns(
            inputs, arg_names, user_input_gen_fns
        )

    return autotune_select_algorithm(
        name=name,
        choices=choices,
        input_nodes=list(inputs),
        layout=choices[0].layout,
        input_gen_fns=input_gen_fns,
    )

def register_custom_op_autotuning(
    custom_op: torch._ops.OpOverload,
    configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
    name: Optional[str] = None,
    input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None,
) -> None:
    """Register custom op for autotuning with explicit configs.

    Uses config-based API where each config specifies a decomposition function
    with its parameter values.

    Args:
        custom_op: Custom operation to register
        configs: List of CustomOpConfig objects or callable functions
        name: Operation name (default: "{op_name}_autotuned")
        input_gen_fns: Custom input generators for benchmarking

    Examples:
        register_custom_op_autotuning(
            torch.ops.mylib.attention.default,
            configs=[
                CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
                CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
                CustomOpConfig(fallback_impl),  # No params
            ],
            input_gen_fns={
                "query": lambda fake: torch.randn_like(fake, device='cuda'),
                "key": lambda fake: torch.randn_like(fake, device='cuda'),
                "value": lambda fake: torch.randn_like(fake, device='cuda'),
            }
        )
    """
    if not isinstance(configs, (list, tuple)):
        raise TypeError(f"configs must be a list or tuple, got {type(configs)}")

    if not configs:
        raise ValueError("At least one config must be provided")

    # Convert configs to decomposition functions
    final_decompositions = []
    for config in configs:
        if isinstance(config, CustomOpConfig):
            # CustomOpConfig object
            final_decompositions.append(config.create_variant())
        elif callable(config):
            # Direct callable function
            final_decompositions.append(config)
        else:
            raise TypeError(
                f"Each config must be a CustomOpConfig object or callable function, "
                f"got {type(config)}"
            )

    if name is None:
        name = f"{custom_op._name}_autotuned"

    @functools.wraps(custom_op)
    def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
        """Inductor lowering function that replaces custom op calls with autotuned versions."""
        # Extract tensor inputs and non-tensor parameters
        tensor_inputs, non_tensor_kwargs = _extract_tensor_inputs(args, kwargs)

        result = autotune_custom_op(
            name=name,
            decompositions=final_decompositions,
            inputs=tensor_inputs,
            kwargs=non_tensor_kwargs,
            default_impl=custom_op,
            user_input_gen_fns=input_gen_fns,
        )

        validate_ir(result)
        return result

    lowerings[custom_op] = autotuning_lowering

torch/_inductor/select_algorithm.py
@ -2879,7 +2879,6 @@ class AlgorithmSelectorCache(PersistentCache):
            )

        timings = do_autotuning(choices, precompile_fn)

        # if timings is empty, we really have no choice but to return a semi-random
        # choice. returning the first `ExternKernelCaller` is probably the safest bet
        # in this case, since it will generally be the ATen kernel. if there are no
@ -3483,6 +3482,7 @@ class AlgorithmSelectorCache(PersistentCache):
        dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes])
        if config.autotune_num_choices_displayed == 0:
            return

        # when autotune_num_choices_displayed is None, [:None] means all
        n = config.autotune_num_choices_displayed
        top_k = sorted(timings, key=timings.__getitem__)[:n]