conv: refactor for lookup table support

# why

enable configuring conv operations through the lookup table

# what

- move conv template kwargs and config generation into template_heuristics
- add conv specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```
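
# example

a minimal sketch of the kind of entry this enables, modeled on the e2e test
below; the key is produced by `LookupTableChoices.make_lookup_key` and the
values are illustrative:

```python
import torch._inductor.kernel.conv
from torch._inductor import config as inductor_config

# illustrative key; real keys come from
# LookupTableChoices().make_lookup_key(kernel_inputs, "convolution")
lookup_key = "<device/shape-specific key>"

inductor_config.lookup_table.table = {
    lookup_key: [
        {
            "template_id": torch._inductor.kernel.conv.conv2d_template.uid,
            # per-config tunables only; static kwargs (KERNEL_H, STRIDE_H,
            # GROUPS, UNROLL, ALLOW_TF32) are regenerated by get_extra_kwargs()
            "BLOCK_M": 64,
            "BLOCK_N": 64,
            "BLOCK_K": 32,
            "num_stages": 2,
            "num_warps": 4,
        }
    ]
}
```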

[ghstack-poisoned]
Author: Ruben Rodriguez Buchillon
Date: 2025-11-05 18:57:57 -08:00
parent 8e8cbb85ee
commit c277e07f77
6 changed files with 649 additions and 82 deletions

test/inductor/test_lookup_table.py

@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch._inductor import config as inductor_config
from torch._inductor.choices import InductorChoices
from torch._inductor.kernel_inputs import MMKernelInputs
from torch._inductor.kernel_inputs import ConvKernelInputs, MMKernelInputs
from torch._inductor.lookup_table.choices import LookupTableChoices
from torch._inductor.select_algorithm import (
add_preprocessing_fn,
@@ -80,6 +80,39 @@ class MockMMKernelInputs(MMKernelInputs):
return self.tensors[0].device.type
class MockConvKernelInputs(ConvKernelInputs):
"""Mock ConvKernelInputs that subclasses the real class and uses real tensors"""
def __init__(
self,
tensors: list[torch.Tensor],
scalars: Optional[
dict[str, Union[float, int, tuple[Union[float, int], ...]]]
] = None,
x_idx: int = 0,
weight_idx: int = 1,
bias_idx: Optional[int] = None,
):
"""Initialize with real tensors, creating mock nodes for the base class"""
mock_nodes = [MockTensorNode(t) for t in tensors]
super().__init__(
mock_nodes, scalars, x_idx=x_idx, weight_idx=weight_idx, bias_idx=bias_idx
)
self.tensors = tensors # Keep reference to original tensors
def shapes_hinted(self) -> tuple[tuple[int, ...], ...]:
"""Delegate to symbolic since real tensors already have int shapes"""
return self.shapes_symbolic()
def strides_hinted(self) -> tuple[tuple[int, ...], ...]:
"""Delegate to symbolic since real tensors already have int strides"""
return self.strides_symbolic() # pyre-ignore
@property
def device_type(self) -> Optional[str]:
return self.tensors[0].device.type
class BaseLookupTableTest(TestCase):
"""Base class for lookup table tests with common setup and utilities"""
@@ -117,6 +150,32 @@ class BaseLookupTableTest(TestCase):
return MockMMKernelInputs(tensors, scalars)
def create_mock_conv_kernel_inputs(
self,
x_shape: tuple[int, ...] = (1, 3, 32, 32), # NCHW
weight_shape: tuple[int, ...] = (64, 3, 3, 3), # out_chan, in_chan, H, W
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.float16,
stride: tuple[Union[float, int], ...] = (1, 1),
padding: tuple[Union[float, int], ...] = (1, 1),
dilation: tuple[Union[float, int], ...] = (1, 1),
groups: Union[float, int] = 1,
) -> MockConvKernelInputs:
"""Create MockConvKernelInputs for conv with real tensors"""
x = torch.randn(x_shape, device=device, dtype=dtype)
weight = torch.randn(weight_shape, device=device, dtype=dtype)
scalars = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": False,
"output_padding": (0, 0),
"groups": groups,
}
return MockConvKernelInputs([x, weight], scalars)
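    # Illustrative pairing with create_lookup_key below:
    #   kernel_inputs = self.create_mock_conv_kernel_inputs()
    #   key = self.create_lookup_key("convolution", kernel_inputs)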
def create_lookup_key(self, method, kernel_inputs):
"""Create a lookup key using LookupTableChoices"""
choices = LookupTableChoices()
@@ -1055,6 +1114,127 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):
with patch.object(inductor_config.lookup_table, "check_src_hash", True):
self.run_model("mm", tensors)
@fresh_cache()
def test_conv2d_lookup_table_entry_e2e(self):
"""Test end-to-end conv2d with lookup table entry - verifies config is picked up and produces valid results"""
import torch._inductor.kernel.conv
# Create input tensors with specific shapes for conv2d
# Input: [batch=2, in_channels=3, height=32, width=32]
# Weight: [out_channels=64, in_channels=3, kernel_h=3, kernel_w=3]
# Make them channels-last to match what conv lowering uses
x = torch.randn(2, 3, 32, 32, device=self.device, dtype=torch.float16).to(
memory_format=torch.channels_last
)
weight = torch.randn(64, 3, 3, 3, device=self.device, dtype=torch.float16).to(
memory_format=torch.channels_last
)
# Define conv parameters - use these SAME values everywhere
stride = (1, 1)
padding = (1, 1)
dilation = (1, 1)
groups = 1
# Create MockConvKernelInputs using the SAME tensors and SAME scalar values
mock_scalars = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": False,
"output_padding": (0, 0),
"groups": groups,
}
mock_kernel_inputs = MockConvKernelInputs([x, weight], mock_scalars)
# Create lookup key for "convolution" operation
choices_handler = LookupTableChoices()
lookup_key = choices_handler.make_lookup_key(mock_kernel_inputs, "convolution")
# Get the exact template UID from conv2d_template
template_uid = torch._inductor.kernel.conv.conv2d_template.uid
# Create a precisely configured conv2d config
# IMPORTANT: Only include per-config tunable parameters!
# Static parameters (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are
# automatically generated by get_extra_kwargs() and should NOT be in the lookup table
conv2d_config = {
"template_id": template_uid,
# Per-config tunable parameters only (what you'd tune via autotuning)
"BLOCK_M": 64,
"BLOCK_N": 64,
"BLOCK_K": 32,
"num_stages": 2,
"num_warps": 4,
}
# Setup lookup table
inductor_config.lookup_table.table = {lookup_key: [conv2d_config]}
# Validation function to ensure our config is selected
def validate_conv_choice(choices):
# Should have exactly 1 choice (our lookup table entry)
assert len(choices) == 1, (
f"Expected 1 choice from lookup table, got {len(choices)}"
)
# Should be a TritonTemplateCaller
assert isinstance(choices[0], TritonTemplateCaller), (
f"Expected TritonTemplateCaller, got {type(choices[0])}"
)
# Name should contain "convolution2d" (from conv2d_template.name)
assert "convolution2d" in choices[0].name, (
f"Expected 'convolution2d' in name, got {choices[0].name}"
)
return choices
add_preprocessing_fn(validate_conv_choice)
# Create and compile the model using the SAME weight tensor
class SimpleConv2d(nn.Module):
def __init__(self, weight):
super().__init__()
# Register weight as buffer to use exact weight tensor
self.register_buffer("weight", weight)
def forward(self, x):
return torch.conv2d(
x,
self.weight,
bias=None,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
model = SimpleConv2d(weight).to(self.device)
with inductor_config.patch({"max_autotune": True, "max_autotune_gemm": True}):
compiled_model = torch.compile(model)
result = compiled_model(x) # Use the SAME x tensor
# Verify result shape is correct
# Output shape: [batch=2, out_channels=64, out_h=32, out_w=32]
# (same spatial dims due to padding=1, stride=1, kernel=3)
expected_shape = (2, 64, 32, 32)
self.assertEqual(
result.shape,
expected_shape,
f"Expected shape {expected_shape}, got {result.shape}",
)
# Verify no NaNs in output
self.assertFalse(
torch.isnan(result).any().item(),
"Output contains NaN values",
)
# Verify no Infs in output
self.assertFalse(
torch.isinf(result).any().item(),
"Output contains Inf values",
)
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu

torch/_inductor/kernel/conv.py

@@ -8,6 +8,7 @@ import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
from .. import config, ir
from ..kernel_inputs import ConvKernelInputs
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
@@ -16,7 +17,9 @@ from ..lowering import (
)
from ..select_algorithm import (
autotune_select_algorithm,
ChoiceCaller,
ExternKernelChoice,
KernelTemplate,
SymbolicGridFn,
TritonTemplate,
)
@@ -542,34 +545,40 @@ def convolution(
x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment]
weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment]
ordered_kwargs_for_cpp_kernel = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
if bias is None:
args = [x, weight]
kwargs["bias"] = None # type: ignore[typeddict-unknown-key]
ordered_kwargs_for_cpp_kernel.insert(0, "bias")
else:
args = [x, weight, bias]
# Create ConvKernelInputs for unified template configuration
# Only include bias in input_nodes when it's not None
# - For Triton templates: bias is always None here (peeled off earlier), so input_nodes = [x, weight]
# - For ATEN: input_nodes = [x, weight] when bias is None, [x, weight, bias] when bias is present
if bias is not None:
bias.realize()
bias.freeze_layout()
V.graph.sizevars.guard_int_seq(bias.get_size())
input_nodes = [x, weight, bias]
bias_idx = 2
else:
input_nodes = [x, weight]
bias_idx = None
kernel_inputs = ConvKernelInputs(
input_nodes,
scalars={
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": transposed,
"output_padding": output_padding,
"groups": groups,
},
x_idx=0,
weight_idx=1,
bias_idx=bias_idx,
)
# Build list of templates to try
templates: list[ExternKernelChoice | KernelTemplate] = []
choices = []
if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
choices = [
aten_convolution.bind(
args,
layout,
ordered_kwargs_for_cpp_kernel,
**kwargs,
)
]
templates.append(aten_convolution)
if (
torch._inductor.utils._use_conv_autotune_backend("TRITON")
@@ -581,66 +590,30 @@ def convolution(
# there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
and V.graph.sizevars.statically_known_equals(in_chan * groups, x.get_size()[1]) # type: ignore[arg-type]
):
# 1x1 conv via mm
if (
is_ones(kernel_shape)
and is_ones(stride)
and is_zeros(padding)
and groups == 1
):
choices.append(aten_conv1x1_via_mm.bind(args, layout))
templates.append(aten_conv1x1_via_mm)
conv_configs = V.choices.get_conv_configs(device_type)
# Add appropriate template based on ndim
if ndim == 2:
templates.append(conv2d_template)
elif ndim == 3:
templates.append(conv3d_template)
dtype_size = x.get_dtype().itemsize
for cfg in conv_configs(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
dtype_size=dtype_size,
):
if ndim == 2:
conv2d_template.maybe_append_choice(
choices,
input_nodes=(x, weight),
layout=layout,
KERNEL_H=kernel_shape[0],
KERNEL_W=kernel_shape[1],
STRIDE_H=stride[0],
STRIDE_W=stride[1],
PADDING_H=padding[0],
PADDING_W=padding[1],
GROUPS=groups,
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
)
elif ndim == 3:
conv3d_template.maybe_append_choice(
choices,
input_nodes=(x, weight),
layout=layout,
KERNEL_D=kernel_shape[0],
KERNEL_H=kernel_shape[1],
KERNEL_W=kernel_shape[2],
STRIDE_D=stride[0],
STRIDE_H=stride[1],
STRIDE_W=stride[2],
PADDING_D=padding[0],
PADDING_H=padding[1],
PADDING_W=padding[2],
GROUPS=groups,
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
)
# Initialize choices list and extend with template configs
choices: list[ChoiceCaller] = []
choices.extend(
V.choices.get_template_configs(
kernel_inputs,
templates,
"convolution",
)
)
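    # NOTE: the static template kwargs (KERNEL_H, STRIDE_H, GROUPS, UNROLL,
    # ALLOW_TF32, ...) and the per-config kwargs (BLOCK_M/N/K, num_stages,
    # num_warps) that used to be built inline here are now produced by the
    # conv heuristics in template_heuristics/conv.py (see the new file below).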
if use_ck_conv_template(layout):
CKGroupedConvFwdTemplate.add_ck_conv_choices(
choices,
@@ -652,7 +625,9 @@ def convolution(
groups=groups,
n_spatial_dimensions=ndim,
)
return autotune_select_algorithm("convolution", choices, args, layout)
return autotune_select_algorithm(
"convolution", choices, kernel_inputs.nodes(), layout
)
@register_lowering(aten._convolution)

torch/_inductor/kernel_inputs.py

@@ -27,7 +27,9 @@ class KernelInputs(ABC):
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[
dict[str, Union[float, int, bool, tuple[Union[float, bool, int], ...]]]
] = None,
out_dtype: Optional[torch.dtype] = None,
):
"""
@@ -183,7 +185,7 @@ class KernelInputs(ABC):
The output dtype
"""
def get_scalar(self, name: str) -> Union[float, int]:
def get_scalar(self, name: str) -> Union[float, int, tuple[Union[float, int], ...]]:
"""
Get the scalar value for a given name.
@@ -191,7 +193,7 @@
name: Name of the scalar to get
Returns:
The scalar value
The scalar value (can be float, int, or tuple of float/int)
"""
assert name in self._scalars, f"Scalar {name} not found, but required"
return self._scalars[name]
@@ -216,7 +218,9 @@ class MMKernelInputs(KernelInputs):
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[
dict[str, Union[float, int, bool, tuple[Union[float, bool, int], ...]]]
] = None,
out_dtype: Optional[torch.dtype] = None,
mat1_idx: int = -2,
mat2_idx: int = -1,
@@ -336,3 +340,122 @@ class MMKernelInputs(KernelInputs):
assert k == k_check, f"K dimensions don't match: {k} vs {k_check}"
return (m, n, k)
class ConvKernelInputs(KernelInputs):
"""
Specialized KernelInputs for convolution operations.
Stores input tensor, weight tensor, and optional bias, along with conv parameters.
"""
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[
dict[str, Union[float, int, bool, tuple[Union[float, bool, int], ...]]]
] = None,
out_dtype: Optional[torch.dtype] = None,
x_idx: int = 0,
weight_idx: int = 1,
bias_idx: Optional[int] = None,
):
"""
Initialize with convolution input nodes.
Args:
input_nodes: List containing [x, weight] or [x, weight, bias]
scalars: Dict with conv params (stride, padding, dilation, groups, transposed, output_padding)
out_dtype: Optional output dtype
x_idx: Index of input tensor (default: 0)
weight_idx: Index of weight tensor (default: 1)
bias_idx: Index of bias tensor if present (default: None)
"""
super().__init__(input_nodes, scalars, out_dtype)
assert len(input_nodes) >= 2, "Expected at least 2 input nodes (x, weight)"
self._x_idx = x_idx
self._weight_idx = weight_idx
self._bias_idx = bias_idx
# Validate that required scalars are present
required_scalars = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
for key in required_scalars:
assert key in self._scalars, f"Conv requires scalar '{key}'"
def out_dtype(self) -> torch.dtype:
"""
Get the output dtype, whether passed in or inferred from the nodes
Returns:
The output dtype
"""
if self._out_dtype is not None:
return self._out_dtype
return self._input_nodes[self._x_idx].get_dtype()
def output_layout(self, flexible: bool = True) -> Layout:
"""
Handle output layout generation for convolution.
Args:
flexible: If True, return FlexibleLayout, otherwise FixedLayout
Returns:
Layout for the convolution output
"""
from torch._inductor.kernel.conv import conv_layout
x = self._input_nodes[self._x_idx]
weight = self._input_nodes[self._weight_idx]
bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
# Extract conv params from scalars
stride: tuple[int] = self._scalars["stride"] # type: ignore[assignment]
padding: tuple[int] = self._scalars["padding"] # type: ignore[assignment]
dilation: tuple[int] = self._scalars["dilation"] # type: ignore[assignment]
transposed: bool = self._scalars["transposed"] # type: ignore[assignment]
output_padding: tuple[int] = self._scalars["output_padding"] # type: ignore[assignment]
groups: int = self._scalars["groups"] # type: ignore[assignment]
# Use existing conv_layout function
layout = conv_layout(
x,
weight,
bias,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
)
# TODO: Handle flexible vs fixed based on config if needed
return layout
def get_x_weight_bias(self) -> tuple[Any, Any, Optional[Any]]:
"""
Get x, weight, and optional bias nodes.
Returns:
Tuple of (x, weight, bias) where bias may be None
"""
bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
return self._input_nodes[self._x_idx], self._input_nodes[self._weight_idx], bias
def spatial_dims(self) -> tuple[Any, ...]:
"""
Get spatial dimensions from input tensor (H, W for 2D, D, H, W for 3D).
Returns:
Tuple of spatial dimension sizes
"""
x_shape = self._input_nodes[self._x_idx].get_size()
return x_shape[2:] # Skip batch and channel dims
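    # Illustrative usage, condensed from the conv lowering in kernel/conv.py
    # (names mirror that call site):
    #
    #   kernel_inputs = ConvKernelInputs(
    #       [x, weight],  # or [x, weight, bias] with bias_idx=2
    #       scalars={
    #           "stride": (1, 1), "padding": (1, 1), "dilation": (1, 1),
    #           "transposed": False, "output_padding": (0, 0), "groups": 1,
    #       },
    #   )
    #   layout = kernel_inputs.output_layout()
    #   x_node, weight_node, bias_node = kernel_inputs.get_x_weight_bias()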

torch/_inductor/lookup_table/choices.py

@@ -235,6 +235,10 @@ class LookupTableChoices(InductorChoices):
kernel_inputs, op_name
)
print(
f"device_key: {device_key}, device_agnostic_key: {device_agnostic_key}, keys: {lookup_table.keys()}"
)
config_list = []
for key_type, key in [

torch/_inductor/template_heuristics/__init__.py

@@ -1,6 +1,6 @@
# NOTE: add new template heuristics here, so they get imported and registered
# TODO: write a simple glob if there are many heuristics to auto import them in the right order
from . import aten, base, contiguous_mm, decompose_k, registry, triton
from . import aten, base, contiguous_mm, conv, decompose_k, registry, triton
# expose the entry function
from .registry import get_template_heuristic

torch/_inductor/template_heuristics/conv.py

@@ -0,0 +1,285 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING
import torch
from ..kernel.conv import aten_convolution, conv2d_template, conv3d_template
from ..kernel_inputs import ConvKernelInputs
from ..utils import is_ones, sympy_product
from ..virtualized import V
from .base import TemplateConfigHeuristics
from .registry import register_template_heuristic
from .triton import (
CPUConfigHeuristic,
CUDAConfigHeuristic,
MTIAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
if TYPE_CHECKING:
from collections.abc import Generator
from ..kernel_inputs import KernelInputs
class ConvTemplateConfigMixin(TemplateConfigHeuristics):
"""
Mixin for conv templates that converts config lists to template kwargs.
Similar to MMTemplateConfigMixin but for convolutions.
This handles generating both the static template kwargs (KERNEL_H, STRIDE_H, etc.)
and the per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
"""
# Type hint for methods from BaseConfigHeuristic
get_conv_configs: Any
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> dict[str, Any]:
"""
Return template kwargs that don't change per-config.
These are derived from kernel_inputs and must include all template parameters.
Args:
kernel_inputs: ConvKernelInputs containing input tensors and conv params
op_name: Operation name (e.g., "convolution")
Returns:
Dict of static template kwargs (KERNEL_H, STRIDE_H, GROUPS, etc.)
"""
assert isinstance(kernel_inputs, ConvKernelInputs), (
f"ConvTemplateConfigMixin requires ConvKernelInputs, got {type(kernel_inputs)}"
)
x, weight, bias = kernel_inputs.get_x_weight_bias()
# Extract kernel shape from weight: [out_chan, in_chan, *kernel_shape]
weight_size = V.graph.sizevars.guard_int_seq(weight.get_size())
kernel_shape = weight_size[2:] # Skip out_chan, in_chan
ndim = len(kernel_shape)
# Extract scalars
stride = kernel_inputs.get_scalar("stride")
padding = kernel_inputs.get_scalar("padding")
groups = kernel_inputs.get_scalar("groups")
# Check if we should unroll (only for 1x1 kernels)
unroll = is_ones(kernel_shape)
# Build kwargs dict based on ndim
kwargs = {
"GROUPS": groups,
"UNROLL": unroll,
"ALLOW_TF32": torch.backends.cudnn.allow_tf32,
}
if ndim == 2:
kwargs.update(
{
"KERNEL_H": kernel_shape[0],
"KERNEL_W": kernel_shape[1],
"STRIDE_H": stride[0],
"STRIDE_W": stride[1],
"PADDING_H": padding[0],
"PADDING_W": padding[1],
}
)
elif ndim == 3:
kwargs.update(
{
"KERNEL_D": kernel_shape[0],
"KERNEL_H": kernel_shape[1],
"KERNEL_W": kernel_shape[2],
"STRIDE_D": stride[0],
"STRIDE_H": stride[1],
"STRIDE_W": stride[2],
"PADDING_D": padding[0],
"PADDING_H": padding[1],
"PADDING_W": padding[2],
}
)
return kwargs
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Yield per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
Args:
kernel_inputs: ConvKernelInputs containing input tensors
op_name: Operation name
Yields:
Dict of per-config kwargs for each configuration to try
"""
assert isinstance(kernel_inputs, ConvKernelInputs), (
"ConvTemplateConfigMixin requires ConvKernelInputs"
)
x, weight, bias = kernel_inputs.get_x_weight_bias()
# Calculate dimensions for heuristics
weight_size = weight.get_size()
out_chan = weight_size[0]
in_chan = weight_size[1]
# Batch * spatial dimensions product
x_size = x.get_size()
batch_spatial_product = sympy_product([x_size[0], *x_size[2:]])
# Get conv config generator from self (which is a BaseConfigHeuristic subclass)
conv_configs_generator = self.get_conv_configs()
dtype_size = x.get_dtype().itemsize
# Generate configs (reusing mm preprocess_mm_configs machinery)
for c in conv_configs_generator(
batch_spatial_product,
out_chan,
in_chan,
dtype_size=dtype_size,
op_name="conv",
):
# Yield per-config kwargs
yield {
"BLOCK_M": c.kwargs.get("BLOCK_M"),
"BLOCK_N": c.kwargs.get("BLOCK_N"),
"BLOCK_K": c.kwargs.get("BLOCK_K"),
"num_stages": c.num_stages,
"num_warps": c.num_warps,
}
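        # Worked example (illustrative): for a 3x3 conv with stride=(1, 1),
        # padding=(1, 1), groups=1, get_extra_kwargs() above returns
        #   {"GROUPS": 1, "UNROLL": False,
        #    "ALLOW_TF32": torch.backends.cudnn.allow_tf32,
        #    "KERNEL_H": 3, "KERNEL_W": 3, "STRIDE_H": 1, "STRIDE_W": 1,
        #    "PADDING_H": 1, "PADDING_W": 1}
        # and each dict yielded here adds only the tunables
        # (BLOCK_M/BLOCK_N/BLOCK_K, num_stages, num_warps).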
# ATEN convolution heuristic (no per-config tuning)
@register_template_heuristic(aten_convolution.uid, None)
class ATenConvConfigHeuristic(TemplateConfigHeuristics):
"""
Pseudo heuristic for ATen convolution.
ATen doesn't have configs to tune - it's a single choice.
"""
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
# ATen doesn't have per-config kwargs to tune
yield dict()
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> dict[str, Any]:
"""
ATen gets stride, padding, etc. as ordered kwargs for the C++ kernel.
"""
assert isinstance(kernel_inputs, ConvKernelInputs)
# Extract scalar values from kernel_inputs
stride = kernel_inputs.get_scalar("stride")
padding = kernel_inputs.get_scalar("padding")
dilation = kernel_inputs.get_scalar("dilation")
transposed = kernel_inputs.get_scalar("transposed")
output_padding = kernel_inputs.get_scalar("output_padding")
groups = kernel_inputs.get_scalar("groups")
# Check if bias is None to match old behavior
# When bias is None: input_nodes = [x, weight], add 'bias' to kwargs and ordered list
# When bias is present: input_nodes = [x, weight, bias], don't add 'bias' to kwargs
x, weight, bias = kernel_inputs.get_x_weight_bias()
kwargs = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": transposed,
"output_padding": output_padding,
"groups": groups,
}
if bias is None:
# When bias is None, torch.convolution expects it as a kwarg
kwargs["bias"] = None
kwargs["ordered_kwargs_for_cpp_kernel"] = [
"bias",
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
else:
# When bias is present, it's passed as a positional arg (3rd in input_nodes)
kwargs["ordered_kwargs_for_cpp_kernel"] = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
return kwargs
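        # NOTE: this mirrors the previous inline aten_convolution.bind(...)
        # call in kernel/conv.py, where a None bias was passed as a kwarg and
        # "bias" was prepended to ordered_kwargs_for_cpp_kernel.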
# CUDA Conv2D/Conv3D heuristics
@register_template_heuristic(
conv2d_template.uid,
"cuda",
register=torch.version.hip is None,
)
@register_template_heuristic(
conv3d_template.uid,
"cuda",
register=torch.version.hip is None,
)
class CUDAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CUDAConfigHeuristic):
"""Conv template heuristic for CUDA."""
# ROCm Conv2D/Conv3D heuristics
@register_template_heuristic(
conv2d_template.uid,
"cuda",
register=torch.version.hip is not None,
)
@register_template_heuristic(
conv3d_template.uid,
"cuda",
register=torch.version.hip is not None,
)
class ROCmConvTemplateConfigHeuristic(ConvTemplateConfigMixin, ROCmConfigHeuristic):
"""Conv template heuristic for ROCm."""
# CPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "cpu")
@register_template_heuristic(conv3d_template.uid, "cpu")
class CPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CPUConfigHeuristic):
"""Conv template heuristic for CPU."""
# XPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "xpu")
@register_template_heuristic(conv3d_template.uid, "xpu")
class XPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, XPUConfigHeuristic):
"""Conv template heuristic for XPU."""
# MTIA Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "mtia")
@register_template_heuristic(conv3d_template.uid, "mtia")
class MTIAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, MTIAConfigHeuristic):
"""Conv template heuristic for MTIA."""