Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-14 06:07:55 +08:00
conv: refactor for lookup table support
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc into template_heuristics
- add conv specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```

[ghstack-poisoned]
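For orientation, here is a minimal sketch of a conv lookup-table entry, distilled from the e2e test added below; the placeholder key stands in for the string that LookupTableChoices.make_lookup_key produces for the actual kernel inputs.

```
import torch
from torch._inductor import config as inductor_config
from torch._inductor.kernel.conv import conv2d_template

# One entry per lookup key; each entry lists candidate configs for that op/shape.
# Only per-config tunables go here; static kwargs (KERNEL_H, STRIDE_H, GROUPS,
# UNROLL, ALLOW_TF32) are produced by the conv heuristic's get_extra_kwargs().
conv2d_config = {
    "template_id": conv2d_template.uid,
    "BLOCK_M": 64,
    "BLOCK_N": 64,
    "BLOCK_K": 32,
    "num_stages": 2,
    "num_warps": 4,
}

# "<lookup key>" is a stand-in: the real key comes from
# LookupTableChoices.make_lookup_key(kernel_inputs, "convolution").
inductor_config.lookup_table.table = {"<lookup key>": [conv2d_config]}
```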
@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch._inductor import config as inductor_config
from torch._inductor.choices import InductorChoices
from torch._inductor.kernel_inputs import MMKernelInputs
from torch._inductor.kernel_inputs import ConvKernelInputs, MMKernelInputs
from torch._inductor.lookup_table.choices import LookupTableChoices
from torch._inductor.select_algorithm import (
    add_preprocessing_fn,
@@ -80,6 +80,39 @@ class MockMMKernelInputs(MMKernelInputs):
        return self.tensors[0].device.type


class MockConvKernelInputs(ConvKernelInputs):
    """Mock ConvKernelInputs that subclasses the real class and uses real tensors"""

    def __init__(
        self,
        tensors: list[torch.Tensor],
        scalars: Optional[
            dict[str, Union[float, int, tuple[Union[float, int], ...]]]
        ] = None,
        x_idx: int = 0,
        weight_idx: int = 1,
        bias_idx: Optional[int] = None,
    ):
        """Initialize with real tensors, creating mock nodes for the base class"""
        mock_nodes = [MockTensorNode(t) for t in tensors]
        super().__init__(
            mock_nodes, scalars, x_idx=x_idx, weight_idx=weight_idx, bias_idx=bias_idx
        )
        self.tensors = tensors  # Keep reference to original tensors

    def shapes_hinted(self) -> tuple[tuple[int, ...], ...]:
        """Delegate to symbolic since real tensors already have int shapes"""
        return self.shapes_symbolic()

    def strides_hinted(self) -> tuple[tuple[int, ...], ...]:
        """Delegate to symbolic since real tensors already have int strides"""
        return self.strides_symbolic()  # pyre-ignore

    @property
    def device_type(self) -> Optional[str]:
        return self.tensors[0].device.type


class BaseLookupTableTest(TestCase):
    """Base class for lookup table tests with common setup and utilities"""

@@ -117,6 +150,32 @@ class BaseLookupTableTest(TestCase):

        return MockMMKernelInputs(tensors, scalars)

    def create_mock_conv_kernel_inputs(
        self,
        x_shape: tuple[int, ...] = (1, 3, 32, 32),  # NCHW
        weight_shape: tuple[int, ...] = (64, 3, 3, 3),  # out_chan, in_chan, H, W
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.float16,
        stride: tuple[Union[float, int], ...] = (1, 1),
        padding: tuple[Union[float, int], ...] = (1, 1),
        dilation: tuple[Union[float, int], ...] = (1, 1),
        groups: Union[float, int] = 1,
    ) -> MockConvKernelInputs:
        """Create MockConvKernelInputs for conv with real tensors"""
        x = torch.randn(x_shape, device=device, dtype=dtype)
        weight = torch.randn(weight_shape, device=device, dtype=dtype)

        scalars = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "transposed": False,
            "output_padding": (0, 0),
            "groups": groups,
        }

        return MockConvKernelInputs([x, weight], scalars)

    def create_lookup_key(self, method, kernel_inputs):
        """Create a lookup key using LookupTableChoices"""
        choices = LookupTableChoices()
@@ -1055,6 +1114,127 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):

        with patch.object(inductor_config.lookup_table, "check_src_hash", True):
            self.run_model("mm", tensors)

    @fresh_cache()
    def test_conv2d_lookup_table_entry_e2e(self):
        """Test end-to-end conv2d with lookup table entry - verifies config is picked up and produces valid results"""
        import torch._inductor.kernel.conv

        # Create input tensors with specific shapes for conv2d
        # Input: [batch=2, in_channels=3, height=32, width=32]
        # Weight: [out_channels=64, in_channels=3, kernel_h=3, kernel_w=3]
        # Make them channels-last to match what conv lowering uses
        x = torch.randn(2, 3, 32, 32, device=self.device, dtype=torch.float16).to(
            memory_format=torch.channels_last
        )
        weight = torch.randn(64, 3, 3, 3, device=self.device, dtype=torch.float16).to(
            memory_format=torch.channels_last
        )

        # Define conv parameters - use these SAME values everywhere
        stride = (1, 1)
        padding = (1, 1)
        dilation = (1, 1)
        groups = 1

        # Create MockConvKernelInputs using the SAME tensors and SAME scalar values
        mock_scalars = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "transposed": False,
            "output_padding": (0, 0),
            "groups": groups,
        }
        mock_kernel_inputs = MockConvKernelInputs([x, weight], mock_scalars)

        # Create lookup key for "convolution" operation
        choices_handler = LookupTableChoices()
        lookup_key = choices_handler.make_lookup_key(mock_kernel_inputs, "convolution")

        # Get the exact template UID from conv2d_template
        template_uid = torch._inductor.kernel.conv.conv2d_template.uid

        # Create a precisely configured conv2d config
        # IMPORTANT: Only include per-config tunable parameters!
        # Static parameters (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are
        # automatically generated by get_extra_kwargs() and should NOT be in the lookup table
        conv2d_config = {
            "template_id": template_uid,
            # Per-config tunable parameters only (what you'd tune via autotuning)
            "BLOCK_M": 64,
            "BLOCK_N": 64,
            "BLOCK_K": 32,
            "num_stages": 2,
            "num_warps": 4,
        }

        # Setup lookup table
        inductor_config.lookup_table.table = {lookup_key: [conv2d_config]}

        # Validation function to ensure our config is selected
        def validate_conv_choice(choices):
            # Should have exactly 1 choice (our lookup table entry)
            assert len(choices) == 1, (
                f"Expected 1 choice from lookup table, got {len(choices)}"
            )
            # Should be a TritonTemplateCaller
            assert isinstance(choices[0], TritonTemplateCaller), (
                f"Expected TritonTemplateCaller, got {type(choices[0])}"
            )
            # Name should contain "convolution2d" (from conv2d_template.name)
            assert "convolution2d" in choices[0].name, (
                f"Expected 'convolution2d' in name, got {choices[0].name}"
            )
            return choices

        add_preprocessing_fn(validate_conv_choice)

        # Create and compile the model using the SAME weight tensor
        class SimpleConv2d(nn.Module):
            def __init__(self, weight):
                super().__init__()
                # Register weight as buffer to use exact weight tensor
                self.register_buffer("weight", weight)

            def forward(self, x):
                return torch.conv2d(
                    x,
                    self.weight,
                    bias=None,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

        model = SimpleConv2d(weight).to(self.device)

        with inductor_config.patch({"max_autotune": True, "max_autotune_gemm": True}):
            compiled_model = torch.compile(model)
            result = compiled_model(x)  # Use the SAME x tensor

        # Verify result shape is correct
        # Output shape: [batch=2, out_channels=64, out_h=32, out_w=32]
        # (same spatial dims due to padding=1, stride=1, kernel=3)
        expected_shape = (2, 64, 32, 32)
        self.assertEqual(
            result.shape,
            expected_shape,
            f"Expected shape {expected_shape}, got {result.shape}",
        )

        # Verify no NaNs in output
        self.assertFalse(
            torch.isnan(result).any().item(),
            "Output contains NaN values",
        )

        # Verify no Infs in output
        self.assertFalse(
            torch.isinf(result).any().item(),
            "Output contains Inf values",
        )


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

@@ -8,6 +8,7 @@ import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate

from .. import config, ir
from ..kernel_inputs import ConvKernelInputs
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
@@ -16,7 +17,9 @@ from ..lowering import (
)
from ..select_algorithm import (
    autotune_select_algorithm,
    ChoiceCaller,
    ExternKernelChoice,
    KernelTemplate,
    SymbolicGridFn,
    TritonTemplate,
)
@@ -542,34 +545,40 @@ def convolution(
    x = ir.ExternKernel.require_stride_order(x, req_stride_order)  # type: ignore[assignment]
    weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)  # type: ignore[assignment]

    ordered_kwargs_for_cpp_kernel = [
        "stride",
        "padding",
        "dilation",
        "transposed",
        "output_padding",
        "groups",
    ]
    if bias is None:
        args = [x, weight]
        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
    else:
        args = [x, weight, bias]
    # Create ConvKernelInputs for unified template configuration
    # Only include bias in input_nodes when it's not None
    # - For Triton templates: bias is always None here (peeled off earlier), so input_nodes = [x, weight]
    # - For ATEN: input_nodes = [x, weight] when bias is None, [x, weight, bias] when bias is present
    if bias is not None:
        bias.realize()
        bias.freeze_layout()
        V.graph.sizevars.guard_int_seq(bias.get_size())
        input_nodes = [x, weight, bias]
        bias_idx = 2
    else:
        input_nodes = [x, weight]
        bias_idx = None

    kernel_inputs = ConvKernelInputs(
        input_nodes,
        scalars={
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "transposed": transposed,
            "output_padding": output_padding,
            "groups": groups,
        },
        x_idx=0,
        weight_idx=1,
        bias_idx=bias_idx,
    )

    # Build list of templates to try
    templates: list[ExternKernelChoice | KernelTemplate] = []

    choices = []
    if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
        choices = [
            aten_convolution.bind(
                args,
                layout,
                ordered_kwargs_for_cpp_kernel,
                **kwargs,
            )
        ]
        templates.append(aten_convolution)

    if (
        torch._inductor.utils._use_conv_autotune_backend("TRITON")
@@ -581,66 +590,30 @@ def convolution(
        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
        and V.graph.sizevars.statically_known_equals(in_chan * groups, x.get_size()[1])  # type: ignore[arg-type]
    ):
        # 1x1 conv via mm
        if (
            is_ones(kernel_shape)
            and is_ones(stride)
            and is_zeros(padding)
            and groups == 1
        ):
            choices.append(aten_conv1x1_via_mm.bind(args, layout))
            templates.append(aten_conv1x1_via_mm)

        conv_configs = V.choices.get_conv_configs(device_type)
        # Add appropriate template based on ndim
        if ndim == 2:
            templates.append(conv2d_template)
        elif ndim == 3:
            templates.append(conv3d_template)

        dtype_size = x.get_dtype().itemsize
        for cfg in conv_configs(
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
            dtype_size=dtype_size,
        ):
            if ndim == 2:
                conv2d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_H=kernel_shape[0],
                    KERNEL_W=kernel_shape[1],
                    STRIDE_H=stride[0],
                    STRIDE_W=stride[1],
                    PADDING_H=padding[0],
                    PADDING_W=padding[1],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    # https://github.com/triton-lang/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )
            elif ndim == 3:
                conv3d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_D=kernel_shape[0],
                    KERNEL_H=kernel_shape[1],
                    KERNEL_W=kernel_shape[2],
                    STRIDE_D=stride[0],
                    STRIDE_H=stride[1],
                    STRIDE_W=stride[2],
                    PADDING_D=padding[0],
                    PADDING_H=padding[1],
                    PADDING_W=padding[2],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    # https://github.com/triton-lang/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )
    # Initialize choices list and extend with template configs
    choices: list[ChoiceCaller] = []
    choices.extend(
        V.choices.get_template_configs(
            kernel_inputs,
            templates,
            "convolution",
        )
    )
    if use_ck_conv_template(layout):
        CKGroupedConvFwdTemplate.add_ck_conv_choices(
            choices,
@@ -652,7 +625,9 @@ def convolution(
            groups=groups,
            n_spatial_dimensions=ndim,
        )
    return autotune_select_algorithm("convolution", choices, args, layout)
    return autotune_select_algorithm(
        "convolution", choices, kernel_inputs.nodes(), layout
    )


@register_lowering(aten._convolution)

@@ -27,7 +27,9 @@ class KernelInputs(ABC):
    def __init__(
        self,
        input_nodes: list[Any],
        scalars: Optional[dict[str, Union[float, int]]] = None,
        scalars: Optional[
            dict[str, Union[float, int, bool, tuple[Union[float, bool, int], ...]]]
        ] = None,
        out_dtype: Optional[torch.dtype] = None,
    ):
        """
@@ -183,7 +185,7 @@ class KernelInputs(ABC):
            The output dtype
        """

    def get_scalar(self, name: str) -> Union[float, int]:
    def get_scalar(self, name: str) -> Union[float, int, tuple[Union[float, int], ...]]:
        """
        Get the scalar value for a given name.

@@ -191,7 +193,7 @@ class KernelInputs(ABC):
            name: Name of the scalar to get

        Returns:
            The scalar value
            The scalar value (can be float, int, or tuple of float/int)
        """
        assert name in self._scalars, f"Scalar {name} not found, but required"
        return self._scalars[name]
@@ -216,7 +218,9 @@ class MMKernelInputs(KernelInputs):
    def __init__(
        self,
        input_nodes: list[Any],
        scalars: Optional[dict[str, Union[float, int]]] = None,
        scalars: Optional[
            dict[str, Union[float, int, bool, tuple[Union[float, bool, int], ...]]]
        ] = None,
        out_dtype: Optional[torch.dtype] = None,
        mat1_idx: int = -2,
        mat2_idx: int = -1,
@@ -336,3 +340,122 @@ class MMKernelInputs(KernelInputs):
        assert k == k_check, f"K dimensions don't match: {k} vs {k_check}"

        return (m, n, k)


class ConvKernelInputs(KernelInputs):
    """
    Specialized KernelInputs for convolution operations.
    Stores input tensor, weight tensor, and optional bias, along with conv parameters.
    """

    def __init__(
        self,
        input_nodes: list[Any],
        scalars: Optional[
            dict[str, Union[float, int, bool, tuple[Union[float, bool, int], ...]]]
        ] = None,
        out_dtype: Optional[torch.dtype] = None,
        x_idx: int = 0,
        weight_idx: int = 1,
        bias_idx: Optional[int] = None,
    ):
        """
        Initialize with convolution input nodes.

        Args:
            input_nodes: List containing [x, weight] or [x, weight, bias]
            scalars: Dict with conv params (stride, padding, dilation, groups, transposed, output_padding)
            out_dtype: Optional output dtype
            x_idx: Index of input tensor (default: 0)
            weight_idx: Index of weight tensor (default: 1)
            bias_idx: Index of bias tensor if present (default: None)
        """
        super().__init__(input_nodes, scalars, out_dtype)
        assert len(input_nodes) >= 2, "Expected at least 2 input nodes (x, weight)"

        self._x_idx = x_idx
        self._weight_idx = weight_idx
        self._bias_idx = bias_idx

        # Validate that required scalars are present
        required_scalars = [
            "stride",
            "padding",
            "dilation",
            "transposed",
            "output_padding",
            "groups",
        ]
        for key in required_scalars:
            assert key in self._scalars, f"Conv requires scalar '{key}'"

    def out_dtype(self) -> torch.dtype:
        """
        Get the output dtype, whether passed in or inferred from the nodes

        Returns:
            The output dtype
        """
        if self._out_dtype is not None:
            return self._out_dtype
        return self._input_nodes[self._x_idx].get_dtype()

    def output_layout(self, flexible: bool = True) -> Layout:
        """
        Handle output layout generation for convolution.

        Args:
            flexible: If True, return FlexibleLayout, otherwise FixedLayout

        Returns:
            Layout for the convolution output
        """
        from torch._inductor.kernel.conv import conv_layout

        x = self._input_nodes[self._x_idx]
        weight = self._input_nodes[self._weight_idx]
        bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None

        # Extract conv params from scalars
        stride: tuple[int] = self._scalars["stride"]  # type: ignore[assignment]
        padding: tuple[int] = self._scalars["padding"]  # type: ignore[assignment]
        dilation: tuple[int] = self._scalars["dilation"]  # type: ignore[assignment]
        transposed: bool = self._scalars["transposed"]  # type: ignore[assignment]
        output_padding: tuple[int] = self._scalars["output_padding"]  # type: ignore[assignment]
        groups: int = self._scalars["groups"]  # type: ignore[assignment]

        # Use existing conv_layout function
        layout = conv_layout(
            x,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        )

        # TODO: Handle flexible vs fixed based on config if needed
        return layout

    def get_x_weight_bias(self) -> tuple[Any, Any, Optional[Any]]:
        """
        Get x, weight, and optional bias nodes.

        Returns:
            Tuple of (x, weight, bias) where bias may be None
        """
        bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
        return self._input_nodes[self._x_idx], self._input_nodes[self._weight_idx], bias

    def spatial_dims(self) -> tuple[Any, ...]:
        """
        Get spatial dimensions from input tensor (H, W for 2D, D, H, W for 3D).

        Returns:
            Tuple of spatial dimension sizes
        """
        x_shape = self._input_nodes[self._x_idx].get_size()
        return x_shape[2:]  # Skip batch and channel dims

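A quick reference for the scalars contract that the ConvKernelInputs constructor enforces, with illustrative values for a plain (non-transposed) 3x3, stride-1 conv2d; the input nodes themselves are Inductor IR nodes, so they are not constructed here.

```
# The constructor asserts that every one of these keys is present in `scalars`.
conv_scalars = {
    "stride": (1, 1),
    "padding": (1, 1),
    "dilation": (1, 1),
    "transposed": False,
    "output_padding": (0, 0),
    "groups": 1,
}
# ConvKernelInputs(input_nodes, scalars=conv_scalars, x_idx=0, weight_idx=1, bias_idx=None)
# then exposes get_x_weight_bias(), get_scalar(...), output_layout(), and spatial_dims().
```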
@@ -235,6 +235,10 @@ class LookupTableChoices(InductorChoices):
            kernel_inputs, op_name
        )

        print(
            f"device_key: {device_key}, device_agnostic_key: {device_agnostic_key}, keys: {lookup_table.keys()}"
        )

        config_list = []

        for key_type, key in [
@@ -1,6 +1,6 @@
# NOTE: add new template heuristics here, so they get imported and registered
# TODO: write a simple glob if there are many heuristics to auto import them in the right order
from . import aten, base, contiguous_mm, decompose_k, registry, triton
from . import aten, base, contiguous_mm, conv, decompose_k, registry, triton

# expose the entry function
from .registry import get_template_heuristic
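The new conv module below plugs into this registry through register_template_heuristic. A minimal hypothetical registration, mirroring the decorators used in the conv heuristics; "my_template_uid" and the class are made up for illustration, real registrations key on a template's .uid and a device type.

```
from torch._inductor.template_heuristics.base import TemplateConfigHeuristics
from torch._inductor.template_heuristics.registry import register_template_heuristic


# Hypothetical: associate a heuristic class with (template uid, device type).
@register_template_heuristic("my_template_uid", "cuda")
class MyTemplateHeuristic(TemplateConfigHeuristics):
    pass


# get_template_heuristic (re-exported above) resolves registrations back to a class.
```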
torch/_inductor/template_heuristics/conv.py (new file, 285 lines)
@@ -0,0 +1,285 @@
from __future__ import annotations

from typing import Any, TYPE_CHECKING

import torch

from ..kernel.conv import aten_convolution, conv2d_template, conv3d_template
from ..kernel_inputs import ConvKernelInputs
from ..utils import is_ones, sympy_product
from ..virtualized import V
from .base import TemplateConfigHeuristics
from .registry import register_template_heuristic
from .triton import (
    CPUConfigHeuristic,
    CUDAConfigHeuristic,
    MTIAConfigHeuristic,
    ROCmConfigHeuristic,
    XPUConfigHeuristic,
)


if TYPE_CHECKING:
    from collections.abc import Generator

    from ..kernel_inputs import KernelInputs


class ConvTemplateConfigMixin(TemplateConfigHeuristics):
    """
    Mixin for conv templates that converts config lists to template kwargs.
    Similar to MMTemplateConfigMixin but for convolutions.

    This handles generating both the static template kwargs (KERNEL_H, STRIDE_H, etc.)
    and the per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
    """

    # Type hint for methods from BaseConfigHeuristic
    get_conv_configs: Any

    def get_extra_kwargs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> dict[str, Any]:
        """
        Return template kwargs that don't change per-config.
        These are derived from kernel_inputs and must include all template parameters.

        Args:
            kernel_inputs: ConvKernelInputs containing input tensors and conv params
            op_name: Operation name (e.g., "convolution")

        Returns:
            Dict of static template kwargs (KERNEL_H, STRIDE_H, GROUPS, etc.)
        """
        assert isinstance(kernel_inputs, ConvKernelInputs), (
            f"ConvTemplateConfigMixin requires ConvKernelInputs, got {type(kernel_inputs)}"
        )

        x, weight, bias = kernel_inputs.get_x_weight_bias()

        # Extract kernel shape from weight: [out_chan, in_chan, *kernel_shape]
        weight_size = V.graph.sizevars.guard_int_seq(weight.get_size())
        kernel_shape = weight_size[2:]  # Skip out_chan, in_chan
        ndim = len(kernel_shape)

        # Extract scalars
        stride = kernel_inputs.get_scalar("stride")
        padding = kernel_inputs.get_scalar("padding")
        groups = kernel_inputs.get_scalar("groups")

        # Check if we should unroll (only for 1x1 kernels)
        unroll = is_ones(kernel_shape)

        # Build kwargs dict based on ndim
        kwargs = {
            "GROUPS": groups,
            "UNROLL": unroll,
            "ALLOW_TF32": torch.backends.cudnn.allow_tf32,
        }

        if ndim == 2:
            kwargs.update(
                {
                    "KERNEL_H": kernel_shape[0],
                    "KERNEL_W": kernel_shape[1],
                    "STRIDE_H": stride[0],
                    "STRIDE_W": stride[1],
                    "PADDING_H": padding[0],
                    "PADDING_W": padding[1],
                }
            )
        elif ndim == 3:
            kwargs.update(
                {
                    "KERNEL_D": kernel_shape[0],
                    "KERNEL_H": kernel_shape[1],
                    "KERNEL_W": kernel_shape[2],
                    "STRIDE_D": stride[0],
                    "STRIDE_H": stride[1],
                    "STRIDE_W": stride[2],
                    "PADDING_D": padding[0],
                    "PADDING_H": padding[1],
                    "PADDING_W": padding[2],
                }
            )

        return kwargs

    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Yield per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).

        Args:
            kernel_inputs: ConvKernelInputs containing input tensors
            op_name: Operation name

        Yields:
            Dict of per-config kwargs for each configuration to try
        """
        assert isinstance(kernel_inputs, ConvKernelInputs), (
            "ConvTemplateConfigMixin requires ConvKernelInputs"
        )

        x, weight, bias = kernel_inputs.get_x_weight_bias()

        # Calculate dimensions for heuristics
        weight_size = weight.get_size()
        out_chan = weight_size[0]
        in_chan = weight_size[1]

        # Batch * spatial dimensions product
        x_size = x.get_size()
        batch_spatial_product = sympy_product([x_size[0], *x_size[2:]])

        # Get conv config generator from self (which is a BaseConfigHeuristic subclass)
        conv_configs_generator = self.get_conv_configs()

        dtype_size = x.get_dtype().itemsize

        # Generate configs (reusing mm preprocess_mm_configs machinery)
        for c in conv_configs_generator(
            batch_spatial_product,
            out_chan,
            in_chan,
            dtype_size=dtype_size,
            op_name="conv",
        ):
            # Yield per-config kwargs
            yield {
                "BLOCK_M": c.kwargs.get("BLOCK_M"),
                "BLOCK_N": c.kwargs.get("BLOCK_N"),
                "BLOCK_K": c.kwargs.get("BLOCK_K"),
                "num_stages": c.num_stages,
                "num_warps": c.num_warps,
            }


# ATEN convolution heuristic (no per-config tuning)
@register_template_heuristic(aten_convolution.uid, None)
class ATenConvConfigHeuristic(TemplateConfigHeuristics):
    """
    Pseudo heuristic for ATen convolution.
    ATen doesn't have configs to tune - it's a single choice.
    """

    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        # ATen doesn't have per-config kwargs to tune
        yield dict()

    def get_extra_kwargs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> dict[str, Any]:
        """
        ATen gets stride, padding, etc. as ordered kwargs for the C++ kernel.
        """
        assert isinstance(kernel_inputs, ConvKernelInputs)

        # Extract scalar values from kernel_inputs
        stride = kernel_inputs.get_scalar("stride")
        padding = kernel_inputs.get_scalar("padding")
        dilation = kernel_inputs.get_scalar("dilation")
        transposed = kernel_inputs.get_scalar("transposed")
        output_padding = kernel_inputs.get_scalar("output_padding")
        groups = kernel_inputs.get_scalar("groups")

        # Check if bias is None to match old behavior
        # When bias is None: input_nodes = [x, weight], add 'bias' to kwargs and ordered list
        # When bias is present: input_nodes = [x, weight, bias], don't add 'bias' to kwargs
        x, weight, bias = kernel_inputs.get_x_weight_bias()

        kwargs = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "transposed": transposed,
            "output_padding": output_padding,
            "groups": groups,
        }

        if bias is None:
            # When bias is None, torch.convolution expects it as a kwarg
            kwargs["bias"] = None
            kwargs["ordered_kwargs_for_cpp_kernel"] = [
                "bias",
                "stride",
                "padding",
                "dilation",
                "transposed",
                "output_padding",
                "groups",
            ]
        else:
            # When bias is present, it's passed as a positional arg (3rd in input_nodes)
            kwargs["ordered_kwargs_for_cpp_kernel"] = [
                "stride",
                "padding",
                "dilation",
                "transposed",
                "output_padding",
                "groups",
            ]

        return kwargs


# CUDA Conv2D/Conv3D heuristics
@register_template_heuristic(
    conv2d_template.uid,
    "cuda",
    register=torch.version.hip is None,
)
@register_template_heuristic(
    conv3d_template.uid,
    "cuda",
    register=torch.version.hip is None,
)
class CUDAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CUDAConfigHeuristic):
    """Conv template heuristic for CUDA."""


# ROCm Conv2D/Conv3D heuristics
@register_template_heuristic(
    conv2d_template.uid,
    "cuda",
    register=torch.version.hip is not None,
)
@register_template_heuristic(
    conv3d_template.uid,
    "cuda",
    register=torch.version.hip is not None,
)
class ROCmConvTemplateConfigHeuristic(ConvTemplateConfigMixin, ROCmConfigHeuristic):
    """Conv template heuristic for ROCm."""


# CPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "cpu")
@register_template_heuristic(conv3d_template.uid, "cpu")
class CPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CPUConfigHeuristic):
    """Conv template heuristic for CPU."""


# XPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "xpu")
@register_template_heuristic(conv3d_template.uid, "xpu")
class XPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, XPUConfigHeuristic):
    """Conv template heuristic for XPU."""


# MTIA Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "mtia")
@register_template_heuristic(conv3d_template.uid, "mtia")
class MTIAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, MTIAConfigHeuristic):
    """Conv template heuristic for MTIA."""
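Taken together, the two hooks on ConvTemplateConfigMixin split the template kwargs into a static part and a tuned part. An illustrative sketch with literal dicts standing in for their outputs on a 3x3, stride-1, padding-1 conv2d; the merge itself happens inside get_template_configs, not in user code.

```
# Illustrative only: literal dicts approximating what get_extra_kwargs() and
# _get_template_configs_impl() would produce for a 3x3, stride-1 conv2d.
static_kwargs = {
    "KERNEL_H": 3,
    "KERNEL_W": 3,
    "STRIDE_H": 1,
    "STRIDE_W": 1,
    "PADDING_H": 1,
    "PADDING_W": 1,
    "GROUPS": 1,
    "UNROLL": False,
    "ALLOW_TF32": True,
}
per_config_kwargs = {
    "BLOCK_M": 64,
    "BLOCK_N": 64,
    "BLOCK_K": 32,
    "num_stages": 2,
    "num_warps": 4,
}
# The template sees the union of the two; only per_config_kwargs vary across
# autotuning candidates, and they are exactly what a lookup-table entry pins down.
template_kwargs = {**static_kwargs, **per_config_kwargs}
```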