[inductor][choices] move extra kwargs out of get_template_configs (#163209)

# why

- extra kwargs are input/op dependent and not config dependent. We don't
  plan to serialize/deserialize them, and so they need to be fed in
  later beore making the KTC, rather than when getting the config values
  directly

# what

- move extra_kwargs into the KTC and get_ktc interface directly

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v -k "_addmm"
```

Differential Revision: [D82871310](https://our.internmc.facebook.com/intern/diff/D82871310)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163209
Approved by: https://github.com/nmacchioni
ghstack dependencies: #163305
This commit is contained in:
Ruben Rodriguez Buchillon
2025-09-19 17:24:43 -07:00
committed by PyTorch MergeBot
parent df5d6d57c9
commit 0ee331b523
3 changed files with 8 additions and 6 deletions

View File

@ -166,11 +166,13 @@ class InductorChoices:
# adjust the kernel inputs to the template-specific heuristic, if needed
# default here is to just return the kernel_inputs as is
inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name)
# Create KernelTemplateChoice generator using the moved function
overrides = kwarg_overrides or {}
return make_ktc_generator(
template=template,
cs=cs,
extra_kwargs=extra_kwargs,
overrides=overrides,
layout=kernel_inputs.output_layout(),
inputs=inputs_val,

View File

@ -27,11 +27,13 @@ class KernelTemplateChoice:
self,
template: Union[KernelTemplate, ExternKernelChoice],
params: KernelTemplateParams,
extra_kwargs: dict[str, Any],
layout: Layout,
inputs: KernelInputs,
):
self.template = template
self.params = params
self.extra_kwargs = extra_kwargs
self.layout = layout
self.inputs = inputs
self.annotations: dict[str, Any] = {"ktc": self}
@ -53,6 +55,7 @@ class KernelTemplateChoice:
kwargs = self.params.to_kwargs()
self._choice = self.template.choice_or_none(
**kwargs,
**self.extra_kwargs,
layout=self.layout,
input_nodes=self.inputs.nodes(),
)
@ -64,6 +67,7 @@ class KernelTemplateChoice:
def make_ktc_generator(
template: Union[KernelTemplate, ExternKernelChoice],
cs: Generator[KernelTemplateParams, None, None],
extra_kwargs: dict[str, Any],
overrides: dict[str, Any],
layout: Layout,
inputs: KernelInputs,
@ -86,10 +90,10 @@ def make_ktc_generator(
base_kwargs = params.to_kwargs()
final_kwargs = {**base_kwargs, **overrides}
final_params = DictKernelTemplateParams(final_kwargs)
yield KernelTemplateChoice(
template=template,
params=final_params,
extra_kwargs=extra_kwargs,
layout=layout,
inputs=inputs,
)

View File

@ -39,14 +39,10 @@ class TemplateConfigHeuristics:
if not self.should_run(kernel_inputs):
return
# Get extra kwargs once
extra_kwargs = self.get_extra_kwargs(kernel_inputs, op_name)
# Generate configs and fuse with extra_kwargs
for config_dict in self._get_template_configs_impl(kernel_inputs, op_name):
# Fuse extra_kwargs into config
fused_kwargs = {**config_dict, **extra_kwargs}
yield DictKernelTemplateParams(fused_kwargs)
yield DictKernelTemplateParams(config_dict)
def _get_template_configs_impl(
self,