Compare commits

...

3 Commits

Author SHA1 Message Date
fa9d5c2dd7 Update on "conv: refactor for lookup table support"
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc. into template_heuristics
- add conv-specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

Differential Revision: [D86474839](https://our.internmc.facebook.com/intern/diff/D86474839)

[ghstack-poisoned]
2025-11-10 17:28:12 -08:00
f048cb1f3c Update on "conv: refactor for lookup table support"
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc. into template_heuristics
- add conv-specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-11-06 16:29:43 -08:00
c277e07f77 conv: refactor for lookup table support
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc. into template_heuristics
- add conv-specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```

[ghstack-poisoned]
2025-11-05 18:57:57 -08:00
5 changed files with 610 additions and 88 deletions

View File: test/inductor/test_lookup_table.py

@@ -2,14 +2,18 @@
import re
import unittest
from functools import partial
from typing import Any, Optional, Union
from typing import Any, Optional
from unittest.mock import patch
import torch
import torch.nn as nn
from torch._inductor import config as inductor_config
from torch._inductor.choices import InductorChoices
from torch._inductor.kernel_inputs import MMKernelInputs
from torch._inductor.kernel_inputs import (
ConvKernelInputs,
MMKernelInputs,
SerializableValue,
)
from torch._inductor.lookup_table.choices import LookupTableChoices
from torch._inductor.select_algorithm import (
add_preprocessing_fn,
@@ -54,7 +58,7 @@ class MockMMKernelInputs(MMKernelInputs):
def __init__(
self,
tensors: list[torch.Tensor],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
mat1_idx: int = -2,
mat2_idx: int = -1,
):
@@ -80,6 +84,37 @@ class MockMMKernelInputs(MMKernelInputs):
return self.tensors[0].device.type
class MockConvKernelInputs(ConvKernelInputs):
"""Mock ConvKernelInputs that subclasses the real class and uses real tensors"""
def __init__(
self,
tensors: list[torch.Tensor],
scalars: Optional[dict[str, SerializableValue]] = None,
x_idx: int = 0,
weight_idx: int = 1,
bias_idx: Optional[int] = None,
):
"""Initialize with real tensors, creating mock nodes for the base class"""
mock_nodes = [MockTensorNode(t) for t in tensors]
super().__init__(
mock_nodes, scalars, x_idx=x_idx, weight_idx=weight_idx, bias_idx=bias_idx
)
self.tensors = tensors # Keep reference to original tensors
def shapes_hinted(self) -> tuple[tuple[int, ...], ...]:
"""Delegate to symbolic since real tensors already have int shapes"""
return self.shapes_symbolic()
def strides_hinted(self) -> tuple[tuple[int, ...], ...]:
"""Delegate to symbolic since real tensors already have int strides"""
return self.strides_symbolic() # pyre-ignore
@property
def device_type(self) -> Optional[str]:
return self.tensors[0].device.type
class BaseLookupTableTest(TestCase):
"""Base class for lookup table tests with common setup and utilities"""
@@ -103,7 +138,7 @@ class BaseLookupTableTest(TestCase):
shapes: Optional[list[tuple[int, ...]]] = None,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.float32,
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
) -> MockMMKernelInputs:
"""Create MockMMKernelInputs with real tensors"""
if shapes is None:
@@ -1055,6 +1090,119 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):
with patch.object(inductor_config.lookup_table, "check_src_hash", True):
self.run_model("mm", tensors)
@fresh_cache()
def test_conv2d_lookup_table_entry_e2e(self):
"""Test end-to-end conv2d with lookup table entry - verifies config is picked up and produces valid results"""
import torch._inductor.kernel.conv
# Create input tensors with specific shapes for conv2d
# Input: [batch=2, in_channels=3, height=32, width=32]
# Weight: [out_channels=64, in_channels=3, kernel_h=3, kernel_w=3]
# Make them channels-last to match what conv lowering uses
x = torch.randn(2, 3, 32, 32, device=self.device, dtype=torch.float16).to(
memory_format=torch.channels_last
)
weight = torch.randn(64, 3, 3, 3, device=self.device, dtype=torch.float16).to(
memory_format=torch.channels_last
)
# Define conv parameters - use these SAME values everywhere
stride = (1, 1)
padding = (1, 1)
dilation = (1, 1)
groups = 1
# Create MockConvKernelInputs using the SAME tensors and SAME scalar values
mock_scalars = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": False,
"output_padding": (0, 0),
"groups": groups,
}
mock_kernel_inputs = MockConvKernelInputs([x, weight], mock_scalars)
# Create lookup key for "convolution" operation
choices_handler = LookupTableChoices()
lookup_key = choices_handler.make_lookup_key(mock_kernel_inputs, "convolution")
# Get the exact template UID from conv2d_template
template_uid = torch._inductor.kernel.conv.conv2d_template.uid
# Create a precisely configured conv2d config
# IMPORTANT: Only include per-config tunable parameters!
# Static parameters (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are
# automatically generated by get_extra_kwargs() and should NOT be in the lookup table
conv2d_config = {
"template_id": template_uid,
# Per-config tunable parameters only (what you'd tune via autotuning)
"BLOCK_M": 64,
"BLOCK_N": 64,
"BLOCK_K": 32,
"num_stages": 2,
"num_warps": 4,
}
# Setup lookup table
inductor_config.lookup_table.table = {lookup_key: [conv2d_config]}
def validate_conv_choice(choices):
assert len(choices) == 1, (
f"Expected 1 choice from lookup table, got {len(choices)}"
)
assert isinstance(choices[0], TritonTemplateCaller), (
f"Expected TritonTemplateCaller, got {type(choices[0])}"
)
assert "convolution2d" in choices[0].name, (
f"Expected 'convolution2d' in name, got {choices[0].name}"
)
return choices
add_preprocessing_fn(validate_conv_choice)
# Create and compile the model using the SAME weight tensor
class SimpleConv2d(nn.Module):
def __init__(self, weight):
super().__init__()
self.register_buffer("weight", weight)
def forward(self, x):
return torch.conv2d(
x,
self.weight,
bias=None,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
model = SimpleConv2d(weight).to(self.device)
with inductor_config.patch({"max_autotune": True, "max_autotune_gemm": True}):
compiled_model = torch.compile(model)
result = compiled_model(x) # Use the SAME x tensor
# Output shape: [batch=2, out_channels=64, out_h=32, out_w=32]
# (same spatial dims due to padding=1, stride=1, kernel=3)
expected_shape = (2, 64, 32, 32)
self.assertEqual(
result.shape,
expected_shape,
f"Expected shape {expected_shape}, got {result.shape}",
)
self.assertFalse(
torch.isnan(result).any().item(),
"Output contains NaN values",
)
self.assertFalse(
torch.isinf(result).any().item(),
"Output contains Inf values",
)
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu
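The e2e test above pins exactly one conv choice through the lookup table. A condensed sketch of the entry it installs (names and values taken from the test; only per-config tunables belong in the entry, since static kwargs such as KERNEL_H and STRIDE_H come from get_extra_kwargs()):

```
# Sketch of the lookup-table entry the e2e test installs (values illustrative).
lookup_key = LookupTableChoices().make_lookup_key(mock_kernel_inputs, "convolution")
inductor_config.lookup_table.table = {
    lookup_key: [
        {
            "template_id": conv2d_template.uid,  # Triton conv2d template
            "BLOCK_M": 64,
            "BLOCK_N": 64,
            "BLOCK_K": 32,
            "num_stages": 2,
            "num_warps": 4,
        }
    ]
}
```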

View File: torch/_inductor/kernel/conv.py

@@ -8,6 +8,7 @@ import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
from .. import config, ir
from ..kernel_inputs import ConvKernelInputs
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
@@ -16,7 +17,9 @@ from ..lowering import (
)
from ..select_algorithm import (
autotune_select_algorithm,
ChoiceCaller,
ExternKernelChoice,
KernelTemplate,
SymbolicGridFn,
TritonTemplate,
)
@@ -76,7 +79,7 @@ LOOP_BODY_2D = """
& (idx_x_h < IN_H)[:, None]
& (idx_x_w >= 0)[:, None]
& (idx_x_w < IN_W)[:, None]
& (idx_x_c < GROUP_IN_C)[None, :]
)
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
@@ -542,34 +545,40 @@ def convolution(
x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment]
weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment]
ordered_kwargs_for_cpp_kernel = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
if bias is None:
args = [x, weight]
kwargs["bias"] = None # type: ignore[typeddict-unknown-key]
ordered_kwargs_for_cpp_kernel.insert(0, "bias")
else:
args = [x, weight, bias]
# Create ConvKernelInputs for unified template configuration
# Only include bias in input_nodes when it's not None
# - For Triton templates: bias is always None here (peeled off earlier), so input_nodes = [x, weight]
# - For ATEN: input_nodes = [x, weight] when bias is None, [x, weight, bias] when bias is present
if bias is not None:
bias.realize()
bias.freeze_layout()
V.graph.sizevars.guard_int_seq(bias.get_size())
input_nodes = [x, weight, bias]
bias_idx = 2
else:
input_nodes = [x, weight]
bias_idx = None
kernel_inputs = ConvKernelInputs(
input_nodes,
scalars={
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": transposed,
"output_padding": output_padding,
"groups": groups,
},
x_idx=0,
weight_idx=1,
bias_idx=bias_idx,
)
# Build list of templates to try
templates: list[ExternKernelChoice | KernelTemplate] = []
choices = []
if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
choices = [
aten_convolution.bind(
args,
layout,
ordered_kwargs_for_cpp_kernel,
**kwargs,
)
]
templates.append(aten_convolution)
if (
torch._inductor.utils._use_conv_autotune_backend("TRITON")
@@ -587,60 +596,23 @@ def convolution(
and is_zeros(padding)
and groups == 1
):
choices.append(aten_conv1x1_via_mm.bind(args, layout))
templates.append(aten_conv1x1_via_mm)
conv_configs = V.choices.get_conv_configs(device_type)
# Add appropriate template based on ndim
if ndim == 2:
templates.append(conv2d_template)
elif ndim == 3:
templates.append(conv3d_template)
dtype_size = x.get_dtype().itemsize
for cfg in conv_configs(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
dtype_size=dtype_size,
):
if ndim == 2:
conv2d_template.maybe_append_choice(
choices,
input_nodes=(x, weight),
layout=layout,
KERNEL_H=kernel_shape[0],
KERNEL_W=kernel_shape[1],
STRIDE_H=stride[0],
STRIDE_W=stride[1],
PADDING_H=padding[0],
PADDING_W=padding[1],
GROUPS=groups,
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
)
elif ndim == 3:
conv3d_template.maybe_append_choice(
choices,
input_nodes=(x, weight),
layout=layout,
KERNEL_D=kernel_shape[0],
KERNEL_H=kernel_shape[1],
KERNEL_W=kernel_shape[2],
STRIDE_D=stride[0],
STRIDE_H=stride[1],
STRIDE_W=stride[2],
PADDING_D=padding[0],
PADDING_H=padding[1],
PADDING_W=padding[2],
GROUPS=groups,
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
)
# Initialize choices list and extend with template configs
choices: list[ChoiceCaller] = []
choices.extend(
V.choices.get_template_configs(
kernel_inputs,
templates,
"convolution",
)
)
if use_ck_conv_template(layout):
CKGroupedConvFwdTemplate.add_ck_conv_choices(
choices,
@@ -652,7 +624,9 @@ def convolution(
groups=groups,
n_spatial_dimensions=ndim,
)
return autotune_select_algorithm("convolution", choices, args, layout)
return autotune_select_algorithm(
"convolution", choices, kernel_inputs.nodes(), layout
)
@register_lowering(aten._convolution)
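Net effect of the change above: the lowering no longer binds ATEN kwargs or loops over conv configs itself; it assembles a template list and delegates kwargs generation to the registered heuristics. A simplified sketch of the new control flow (names from the diff; branching condensed):

```
# Simplified sketch of the refactored convolution lowering.
templates: list[ExternKernelChoice | KernelTemplate] = []
if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
    templates.append(aten_convolution)
if torch._inductor.utils._use_conv_autotune_backend("TRITON") and ndim == 2:
    templates.append(conv2d_template)  # conv3d_template for ndim == 3
# Per-(template, device) heuristics expand each template into concrete choices.
choices: list[ChoiceCaller] = list(
    V.choices.get_template_configs(kernel_inputs, templates, "convolution")
)
return autotune_select_algorithm(
    "convolution", choices, kernel_inputs.nodes(), layout
)
```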

View File: torch/_inductor/kernel_inputs.py

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
@@ -12,10 +13,12 @@ from .ir import FixedLayout, FlexibleLayout, Layout
if TYPE_CHECKING:
from collections.abc import Sequence
import sympy
# Type aliases for serializable scalar values
Serializable = Union[int, float, bool]
SerializableValue = Union[Serializable, Sequence[Serializable]]
class KernelInputs(ABC):
"""
@@ -27,7 +30,7 @@ class KernelInputs(ABC):
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
):
"""
@@ -183,7 +186,7 @@ class KernelInputs(ABC):
The output dtype
"""
def get_scalar(self, name: str) -> Union[float, int]:
def get_scalar(self, name: str) -> SerializableValue:
"""
Get the scalar value for a given name.
@@ -191,7 +194,7 @@
name: Name of the scalar to get
Returns:
The scalar value
The scalar value (can be int, float, bool, or tuple of these types)
"""
assert name in self._scalars, f"Scalar {name} not found, but required"
return self._scalars[name]
@@ -216,7 +219,7 @@ class MMKernelInputs(KernelInputs):
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
mat1_idx: int = -2,
mat2_idx: int = -1,
@@ -336,3 +339,113 @@ class MMKernelInputs(KernelInputs):
assert k == k_check, f"K dimensions don't match: {k} vs {k_check}"
return (m, n, k)
class ConvKernelInputs(KernelInputs):
"""
Specialized KernelInputs for convolution operations.
Stores input tensor, weight tensor, and optional bias, along with conv parameters.
"""
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
x_idx: int = 0,
weight_idx: int = 1,
bias_idx: Optional[int] = None,
):
"""
Initialize with convolution input nodes.
Args:
input_nodes: List containing [x, weight] or [x, weight, bias]
scalars: Dict with conv params (stride, padding, dilation, groups, transposed, output_padding)
out_dtype: Optional output dtype
x_idx: Index of input tensor (default: 0)
weight_idx: Index of weight tensor (default: 1)
bias_idx: Index of bias tensor if present (default: None)
"""
super().__init__(input_nodes, scalars, out_dtype)
assert len(input_nodes) >= 2, "Expected at least 2 input nodes (x, weight)"
self._x_idx = x_idx
self._weight_idx = weight_idx
self._bias_idx = bias_idx
# Validate that required scalars are present
required_scalars = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
for key in required_scalars:
assert key in self._scalars, f"Conv requires scalar '{key}'"
def out_dtype(self) -> torch.dtype:
"""
Get the output dtype, whether passed in or inferred from the nodes
Returns:
The output dtype
"""
if self._out_dtype is not None:
return self._out_dtype
return self._input_nodes[self._x_idx].get_dtype()
def output_layout(self, flexible: bool = True) -> Layout:
"""
Handle output layout generation for convolution.
Args:
flexible: If True, return FlexibleLayout, otherwise FixedLayout
Returns:
Layout for the convolution output
"""
from torch._inductor.kernel.conv import conv_layout
x = self._input_nodes[self._x_idx]
weight = self._input_nodes[self._weight_idx]
bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
# Use existing conv_layout function
# We know the types here because conv requires these specific scalar types
layout = conv_layout(
x,
weight,
bias,
self._scalars["stride"], # type: ignore[arg-type]
self._scalars["padding"], # type: ignore[arg-type]
self._scalars["dilation"], # type: ignore[arg-type]
self._scalars["transposed"], # type: ignore[arg-type]
self._scalars["output_padding"], # type: ignore[arg-type]
self._scalars["groups"], # type: ignore[arg-type]
)
# TODO: Handle flexible vs fixed based on config if needed
return layout
def get_x_weight_bias(self) -> tuple[Any, Any, Optional[Any]]:
"""
Get x, weight, and optional bias nodes.
Returns:
Tuple of (x, weight, bias) where bias may be None
"""
bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
return self._input_nodes[self._x_idx], self._input_nodes[self._weight_idx], bias
def spatial_dims(self) -> tuple[Any, ...]:
"""
Get spatial dimensions from input tensor (H, W for 2D, D, H, W for 3D).
Returns:
Tuple of spatial dimension sizes
"""
x_shape = self._input_nodes[self._x_idx].get_size()
return x_shape[2:] # Skip batch and channel dims

View File: torch/_inductor/template_heuristics/__init__.py

@@ -1,6 +1,6 @@
# NOTE: add new template heuristics here, so they get imported and registered
# TODO: write a simple glob if there are many heuristics to auto import them in the right order
from . import aten, base, contiguous_mm, decompose_k, registry, triton
from . import aten, base, contiguous_mm, conv, decompose_k, registry, triton
# expose the entry function
from .registry import get_template_heuristic

View File: torch/_inductor/template_heuristics/conv.py (new file)

@@ -0,0 +1,287 @@
from __future__ import annotations
from typing import Any, cast, TYPE_CHECKING
import torch
from ..kernel.conv import aten_convolution, conv2d_template, conv3d_template
from ..kernel_inputs import ConvKernelInputs
from ..utils import is_ones, sympy_product
from ..virtualized import V
from .base import TemplateConfigHeuristics
from .registry import register_template_heuristic
from .triton import (
CPUConfigHeuristic,
CUDAConfigHeuristic,
MTIAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
if TYPE_CHECKING:
from collections.abc import Generator
from ..kernel_inputs import KernelInputs
class ConvTemplateConfigMixin(TemplateConfigHeuristics):
"""
Mixin for conv templates that converts config lists to template kwargs.
Similar to MMTemplateConfigMixin but for convolutions.
This handles generating both the static template kwargs (KERNEL_H, STRIDE_H, etc.)
and the per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
"""
# Type hint for methods from BaseConfigHeuristic
get_conv_configs: Any
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> dict[str, Any]:
"""
Return template kwargs that don't change per-config.
These are derived from kernel_inputs and must include all template parameters.
Args:
kernel_inputs: ConvKernelInputs containing input tensors and conv params
op_name: Operation name (e.g., "convolution")
Returns:
Dict of static template kwargs (KERNEL_H, STRIDE_H, GROUPS, etc.)
"""
assert isinstance(kernel_inputs, ConvKernelInputs), (
f"ConvTemplateConfigMixin requires ConvKernelInputs, got {type(kernel_inputs)}"
)
x, weight, bias = kernel_inputs.get_x_weight_bias()
# Extract kernel shape from weight: [out_chan, in_chan, *kernel_shape]
weight_size = V.graph.sizevars.guard_int_seq(weight.get_size())
kernel_shape = weight_size[2:] # Skip out_chan, in_chan
ndim = len(kernel_shape)
# Extract scalars
stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride"))
padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding"))
groups = cast(int, kernel_inputs.get_scalar("groups"))
# Check if we should unroll (only for 1x1 kernels)
unroll = is_ones(kernel_shape)
# Build kwargs dict based on ndim
kwargs: dict[str, Any] = {
"GROUPS": groups,
"UNROLL": unroll,
"ALLOW_TF32": torch.backends.cudnn.allow_tf32,
}
if ndim == 2:
kwargs.update(
{
"KERNEL_H": kernel_shape[0],
"KERNEL_W": kernel_shape[1],
"STRIDE_H": stride[0],
"STRIDE_W": stride[1],
"PADDING_H": padding[0],
"PADDING_W": padding[1],
}
)
elif ndim == 3:
kwargs.update(
{
"KERNEL_D": kernel_shape[0],
"KERNEL_H": kernel_shape[1],
"KERNEL_W": kernel_shape[2],
"STRIDE_D": stride[0],
"STRIDE_H": stride[1],
"STRIDE_W": stride[2],
"PADDING_D": padding[0],
"PADDING_H": padding[1],
"PADDING_W": padding[2],
}
)
return kwargs
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Yield per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
Args:
kernel_inputs: ConvKernelInputs containing input tensors
op_name: Operation name
Yields:
Dict of per-config kwargs for each configuration to try
"""
assert isinstance(kernel_inputs, ConvKernelInputs), (
"ConvTemplateConfigMixin requires ConvKernelInputs"
)
x, weight, bias = kernel_inputs.get_x_weight_bias()
# Calculate dimensions for heuristics
weight_size = weight.get_size()
out_chan = weight_size[0]
in_chan = weight_size[1]
# Batch * spatial dimensions product
x_size = x.get_size()
batch_spatial_product = sympy_product([x_size[0], *x_size[2:]])
# Get conv config generator from self (which is a BaseConfigHeuristic subclass)
conv_configs_generator = self.get_conv_configs()
dtype_size = x.get_dtype().itemsize
# Generate configs (reusing mm preprocess_mm_configs machinery)
for c in conv_configs_generator(
batch_spatial_product,
out_chan,
in_chan,
dtype_size=dtype_size,
op_name="conv",
):
# Yield per-config kwargs
yield {
"BLOCK_M": c.kwargs.get("BLOCK_M"),
"BLOCK_N": c.kwargs.get("BLOCK_N"),
"BLOCK_K": c.kwargs.get("BLOCK_K"),
"num_stages": c.num_stages,
"num_warps": c.num_warps,
}
# ATEN convolution heuristic (no per-config tuning)
@register_template_heuristic(aten_convolution.uid, None)
class ATenConvConfigHeuristic(TemplateConfigHeuristics):
"""
Pseudo heuristic for ATen convolution.
ATen doesn't have configs to tune - it's a single choice.
"""
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
# ATen doesn't have per-config kwargs to tune
yield dict()
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> dict[str, Any]:
"""
ATen gets stride, padding, etc. as ordered kwargs for the C++ kernel.
"""
assert isinstance(kernel_inputs, ConvKernelInputs)
# Extract scalar values from kernel_inputs
stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride"))
padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding"))
dilation = cast(tuple[int, ...], kernel_inputs.get_scalar("dilation"))
transposed = cast(bool, kernel_inputs.get_scalar("transposed"))
output_padding = cast(
tuple[int, ...], kernel_inputs.get_scalar("output_padding")
)
groups = cast(int, kernel_inputs.get_scalar("groups"))
# Check if bias is None to match old behavior
# When bias is None: input_nodes = [x, weight], add 'bias' to kwargs and ordered list
# When bias is present: input_nodes = [x, weight, bias], don't add 'bias' to kwargs
x, weight, bias = kernel_inputs.get_x_weight_bias()
kwargs: dict[str, Any] = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": transposed,
"output_padding": output_padding,
"groups": groups,
}
if bias is None:
# When bias is None, torch.convolution expects it as a kwarg
kwargs["bias"] = None
kwargs["ordered_kwargs_for_cpp_kernel"] = [
"bias",
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
else:
# When bias is present, it's passed as a positional arg (3rd in input_nodes)
kwargs["ordered_kwargs_for_cpp_kernel"] = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
return kwargs
# CUDA Conv2D/Conv3D heuristics
@register_template_heuristic(
conv2d_template.uid,
"cuda",
register=torch.version.hip is None,
)
@register_template_heuristic(
conv3d_template.uid,
"cuda",
register=torch.version.hip is None,
)
class CUDAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CUDAConfigHeuristic):
"""Conv template heuristic for CUDA."""
# ROCm Conv2D/Conv3D heuristics
@register_template_heuristic(
conv2d_template.uid,
"cuda",
register=torch.version.hip is not None,
)
@register_template_heuristic(
conv3d_template.uid,
"cuda",
register=torch.version.hip is not None,
)
class ROCmConvTemplateConfigHeuristic(ConvTemplateConfigMixin, ROCmConfigHeuristic):
"""Conv template heuristic for ROCm."""
# CPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "cpu")
@register_template_heuristic(conv3d_template.uid, "cpu")
class CPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CPUConfigHeuristic):
"""Conv template heuristic for CPU."""
# XPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "xpu")
@register_template_heuristic(conv3d_template.uid, "xpu")
class XPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, XPUConfigHeuristic):
"""Conv template heuristic for XPU."""
# MTIA Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "mtia")
@register_template_heuristic(conv3d_template.uid, "mtia")
class MTIAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, MTIAConfigHeuristic):
"""Conv template heuristic for MTIA."""