mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] V.choices.get_mm_configs override point (#161349)
# why - enable us to override the default configs, or fall back to them through subclassing InductorChoices # what - override (private) function - default implementationt takes the kernel template choice (ktc) generator for every template and just executes the generator - future overrides can decide to replace those generators, or filter out choices - the 2nd expensive step (maybe_append_choices, choice_or_none) is handled outside this function, in the main V.choices.get_mm_configs this means that any overriding benefits from not generating expensive templates that aren't going to be used # testing ``` python3 -bb -m pytest test/inductor/test_max_autotune.py -v ``` Differential Revision: [D81520570](https://our.internmc.facebook.com/intern/diff/D81520570) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161349 Approved by: https://github.com/eellison ghstack dependencies: #161347, #161348
This commit is contained in:
committed by
PyTorch MergeBot
parent
d3c4cf838e
commit
24a4dae85b
@ -37,7 +37,8 @@ if TYPE_CHECKING:
|
||||
from .codegen.common import KernelTemplate
|
||||
from .codegen.simd_kernel_features import SIMDKernelFeatures
|
||||
from .codegen.triton import TritonKernel
|
||||
from .ir import ChoiceCaller
|
||||
from .ir import ChoiceCaller, Layout
|
||||
from .kernel_template_choice import KernelTemplateChoice
|
||||
from .select_algorithm import ExternKernelChoice
|
||||
|
||||
|
||||
@ -104,6 +105,82 @@ class InductorChoices:
|
||||
flex_heuristics = self.get_config_heuristics(device_type)
|
||||
return flex_heuristics.get_flex_decode_configs(head_dim, dtype)
|
||||
|
||||
def _finalize_mm_configs(
|
||||
self,
|
||||
template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
|
||||
kernel_inputs: KernelInputs,
|
||||
layout: Any,
|
||||
templates: list[Union[KernelTemplate, ExternKernelChoice]],
|
||||
op_name: str,
|
||||
kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
|
||||
) -> list[KernelTemplateChoice]:
|
||||
"""
|
||||
This method can be subclassed to perform any override/modification of the choices.
|
||||
The incoming parameters are cheap (generators), so you can do any overrides without
|
||||
incurring too much cost. Override this method to customize the kernel template choices
|
||||
before they are converted to ChoiceCaller objects, which is expensive on template codegen.
|
||||
|
||||
The full list of arguments are here to facilitate any overrides you may want to do,
|
||||
as they can be used to start from scratch for each template if so desired.
|
||||
|
||||
Args:
|
||||
template_choices: Dictionary mapping template UIDs to generators of KernelTemplateChoice objects
|
||||
kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
|
||||
layout: Output layout
|
||||
templates: List of template objects (KernelTemplate or ExternKernelChoice) in use
|
||||
op_name: Operation name (e.g., "bmm", "baddbmm", "addmm")
|
||||
kwarg_overrides: Optional dict of kwargs to override for each template heuristic
|
||||
|
||||
Returns:
|
||||
Flattened list of KernelTemplateChoice objects across all templates
|
||||
"""
|
||||
choices: list[KernelTemplateChoice] = []
|
||||
for choice_gen in template_choices.values():
|
||||
choices.extend(choice_gen)
|
||||
return choices
|
||||
|
||||
def get_ktc(
|
||||
self,
|
||||
kernel_inputs: KernelInputs,
|
||||
layout: Layout,
|
||||
template: Union[KernelTemplate, ExternKernelChoice],
|
||||
op_name: str,
|
||||
kwarg_overrides: Optional[dict[str, Any]] = None,
|
||||
) -> Generator[KernelTemplateChoice, None, None]:
|
||||
"""
|
||||
Utility to get the KernelTemplateChoice generator for a specific input.
|
||||
|
||||
This is a per template/op call, whereas get_mm_configs is an op wide call (all templates).
|
||||
Consider when overriding/using at which level you need to make decisions
|
||||
"""
|
||||
# Extract device_type from kernel_inputs
|
||||
device_type = kernel_inputs.device_type
|
||||
assert device_type is not None, "get_mm_configs requires a valid device type"
|
||||
# Extract template_name from the template object
|
||||
template_name = template.uid
|
||||
|
||||
# Get the appropriate template-specific heuristic
|
||||
heuristic = get_template_heuristic(template_name, device_type, op_name)
|
||||
cs = heuristic.get_template_configs(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
op_name,
|
||||
)
|
||||
extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name)
|
||||
# adjust the kernel inputs to the template-specific heuristic, if needed
|
||||
# default here is to just return the kernel_inputs as is
|
||||
inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
|
||||
# Create KernelTemplateChoice generator using the moved function
|
||||
overrides = kwarg_overrides or {}
|
||||
return make_ktc_generator(
|
||||
template=template,
|
||||
cs=cs,
|
||||
overrides=overrides,
|
||||
extra_kwargs=extra_kwargs,
|
||||
layout=layout,
|
||||
inputs=inputs_val,
|
||||
)
|
||||
|
||||
def get_mm_configs(
|
||||
self,
|
||||
kernel_inputs: KernelInputs,
|
||||
@ -131,54 +208,31 @@ class InductorChoices:
|
||||
if len(input_tensors) < 2:
|
||||
raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
|
||||
|
||||
# Extract device_type from kernel_inputs
|
||||
device_type = kernel_inputs.device_type
|
||||
assert device_type is not None, "get_mm_configs requires a valid device type"
|
||||
|
||||
# First pass: Create dict of template.uid to generator of KernelTemplateChoice objects
|
||||
template_choices = {}
|
||||
for template in templates:
|
||||
# Extract template_name from the template object
|
||||
template_name = template.uid
|
||||
|
||||
# Get the appropriate template-specific heuristic
|
||||
heuristic = get_template_heuristic(template_name, device_type, op_name)
|
||||
|
||||
cs = heuristic.get_template_configs(
|
||||
template_choices[template.uid] = self.get_ktc(
|
||||
kernel_inputs,
|
||||
layout,
|
||||
template,
|
||||
op_name,
|
||||
)
|
||||
extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name)
|
||||
|
||||
# Extract layout and input_nodes from extra_kwargs to pass them explicitly
|
||||
layout_val = layout
|
||||
# adjust the kernel inputs to the template-specific heuristic, if needed
|
||||
# default here is to just return the kernel_inputs as is
|
||||
inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
|
||||
|
||||
# Get overrides for this specific template
|
||||
overrides = kwarg_overrides.get(template.uid, {})
|
||||
|
||||
# Create KernelTemplateChoice generator using the moved function
|
||||
choice_gen = make_ktc_generator(
|
||||
template=template,
|
||||
cs=cs,
|
||||
overrides=overrides,
|
||||
extra_kwargs=extra_kwargs,
|
||||
layout=layout_val,
|
||||
inputs=inputs_val,
|
||||
kwarg_overrides.get(template.uid, {}),
|
||||
)
|
||||
|
||||
template_choices[template.uid] = choice_gen
|
||||
|
||||
# Second pass: Iterate through templates in original order and collect choices
|
||||
# Second pass: Adjust the template choices
|
||||
adjusted_choices = self._finalize_mm_configs(
|
||||
template_choices,
|
||||
kernel_inputs,
|
||||
layout,
|
||||
templates,
|
||||
op_name,
|
||||
kwarg_overrides,
|
||||
)
|
||||
choices = []
|
||||
for template in templates:
|
||||
choice_gen = template_choices[template.uid]
|
||||
for ktc in choice_gen:
|
||||
if ktc.choice is not None:
|
||||
choices.append(ktc.choice)
|
||||
# Third pass: Get adjusted choices and collect non-None ChoiceCaller objects
|
||||
for ktc in adjusted_choices:
|
||||
if ktc.choice is not None:
|
||||
choices.append(ktc.choice)
|
||||
|
||||
return choices
|
||||
|
||||
|
Reference in New Issue
Block a user