[inductor] V.choices.get_mm_configs override point (#161349)

# why

- enable us to override the default configs, or fall back to them
  through subclassing InductorChoices

# what

- override (private) function
- default implementation takes the kernel template choice (ktc)
  generator for every template and just executes the generator
- future overrides can decide to replace those generators, or filter
  out choices

- the 2nd, expensive step (maybe_append_choices, choice_or_none) is
  handled outside this function, in the main V.choices.get_mm_configs;
  this means that any override benefits from not generating expensive
  templates that aren't going to be used

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```

Differential Revision: [D81520570](https://our.internmc.facebook.com/intern/diff/D81520570)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161349
Approved by: https://github.com/eellison
ghstack dependencies: #161347, #161348
This commit is contained in:
Ruben Rodriguez Buchillon
2025-09-08 17:11:11 -07:00
committed by PyTorch MergeBot
parent d3c4cf838e
commit 24a4dae85b

View File

@ -37,7 +37,8 @@ if TYPE_CHECKING:
from .codegen.common import KernelTemplate
from .codegen.simd_kernel_features import SIMDKernelFeatures
from .codegen.triton import TritonKernel
from .ir import ChoiceCaller
from .ir import ChoiceCaller, Layout
from .kernel_template_choice import KernelTemplateChoice
from .select_algorithm import ExternKernelChoice
@ -104,6 +105,82 @@ class InductorChoices:
flex_heuristics = self.get_config_heuristics(device_type)
return flex_heuristics.get_flex_decode_configs(head_dim, dtype)
def _finalize_mm_configs(
    self,
    template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
    kernel_inputs: KernelInputs,
    layout: Any,
    templates: list[Union[KernelTemplate, ExternKernelChoice]],
    op_name: str,
    kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
) -> list[KernelTemplateChoice]:
    """
    Override point for adjusting the kernel template choices of an mm-style op.

    Subclasses can replace, filter or extend the choices here. The choices arrive
    as generators, so nothing has been materialized yet and an override can drop
    or swap them cheaply, before they are turned into ChoiceCaller objects (the
    expensive template codegen step).

    The default implementation only looks at ``template_choices``: it drains every
    generator, in the dict's iteration order, into a single flat list. The other
    arguments are passed along so that an override has everything it needs to
    rebuild the choices of any template from scratch.

    Args:
        template_choices: Dictionary mapping template UIDs to generators of KernelTemplateChoice objects
        kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
        layout: Output layout
        templates: List of template objects (KernelTemplate or ExternKernelChoice) in use
        op_name: Operation name (e.g., "bmm", "baddbmm", "addmm")
        kwarg_overrides: Optional dict of kwargs to override for each template heuristic

    Returns:
        Flattened list of KernelTemplateChoice objects across all templates
    """
    # Iterating a generator exhausts it; callers must not expect to re-use them.
    return [
        ktc
        for ktc_gen in template_choices.values()
        for ktc in ktc_gen
    ]
def get_ktc(
    self,
    kernel_inputs: KernelInputs,
    layout: Layout,
    template: Union[KernelTemplate, ExternKernelChoice],
    op_name: str,
    kwarg_overrides: Optional[dict[str, Any]] = None,
) -> Generator[KernelTemplateChoice, None, None]:
    """
    Utility to get the KernelTemplateChoice generator for a specific input.

    This is a per template/op call, whereas get_mm_configs is an op wide call
    (all templates). Consider when overriding/using at which level you need to
    make decisions.

    Args:
        kernel_inputs: Inputs of the op; its ``device_type`` must not be None
        layout: Output layout, forwarded to the heuristic and to the generator
        template: Template (KernelTemplate or ExternKernelChoice) to build choices for;
            its ``uid`` selects the heuristic
        op_name: Operation name, used for the heuristic lookup
        kwarg_overrides: Optional kwargs overrides for this one template;
            None is treated as "no overrides"

    Returns:
        The generator produced by make_ktc_generator for this template

    Raises:
        AssertionError: if ``kernel_inputs.device_type`` is None
    """
    device_type = kernel_inputs.device_type
    # The message names this method (not get_mm_configs) because get_ktc can be
    # called on its own, and the failure should point at the actual call site.
    assert device_type is not None, "get_ktc requires a valid device type"

    # The heuristic is looked up per (template uid, device type, op name)
    heuristic = get_template_heuristic(template.uid, device_type, op_name)

    cs = heuristic.get_template_configs(
        kernel_inputs,
        layout,
        op_name,
    )
    extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name)

    # Let the template-specific heuristic adjust the kernel inputs, if needed
    inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)

    return make_ktc_generator(
        template=template,
        cs=cs,
        overrides=kwarg_overrides or {},
        extra_kwargs=extra_kwargs,
        layout=layout,
        inputs=inputs_val,
    )
def get_mm_configs(
self,
kernel_inputs: KernelInputs,
@ -131,54 +208,31 @@ class InductorChoices:
if len(input_tensors) < 2:
raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
# Extract device_type from kernel_inputs
device_type = kernel_inputs.device_type
assert device_type is not None, "get_mm_configs requires a valid device type"
# First pass: Create dict of template.uid to generator of KernelTemplateChoice objects
template_choices = {}
for template in templates:
# Extract template_name from the template object
template_name = template.uid
# Get the appropriate template-specific heuristic
heuristic = get_template_heuristic(template_name, device_type, op_name)
cs = heuristic.get_template_configs(
template_choices[template.uid] = self.get_ktc(
kernel_inputs,
layout,
template,
op_name,
)
extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name)
# Extract layout and input_nodes from extra_kwargs to pass them explicitly
layout_val = layout
# adjust the kernel inputs to the template-specific heuristic, if needed
# default here is to just return the kernel_inputs as is
inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
# Get overrides for this specific template
overrides = kwarg_overrides.get(template.uid, {})
# Create KernelTemplateChoice generator using the moved function
choice_gen = make_ktc_generator(
template=template,
cs=cs,
overrides=overrides,
extra_kwargs=extra_kwargs,
layout=layout_val,
inputs=inputs_val,
kwarg_overrides.get(template.uid, {}),
)
template_choices[template.uid] = choice_gen
# Second pass: Iterate through templates in original order and collect choices
# Second pass: Adjust the template choices
adjusted_choices = self._finalize_mm_configs(
template_choices,
kernel_inputs,
layout,
templates,
op_name,
kwarg_overrides,
)
choices = []
for template in templates:
choice_gen = template_choices[template.uid]
for ktc in choice_gen:
if ktc.choice is not None:
choices.append(ktc.choice)
# Third pass: Get adjusted choices and collect non-None ChoiceCaller objects
for ktc in adjusted_choices:
if ktc.choice is not None:
choices.append(ktc.choice)
return choices