Files
pytorch/torch/_inductor/kernel_template_choice.py
Ruben Rodriguez Buchillon 0ee331b523 [inductor][choices] move extra kwargs out of get_template_configs (#163209)
# why

- extra kwargs are input/op dependent and not config dependent. We don't
  plan to serialize/deserialize them, so they need to be fed in later,
  just before making the KTC, rather than when the config values are
  fetched

# what

- move extra_kwargs into the KTC and get_ktc interface directly

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v -k "_addmm"
```

Differential Revision: [D82871310](https://our.internmc.facebook.com/intern/diff/D82871310)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163209
Approved by: https://github.com/nmacchioni
ghstack dependencies: #163305
2025-09-20 05:30:40 +00:00

100 lines
3.3 KiB
Python

from __future__ import annotations
from typing import Any, Optional, TYPE_CHECKING, Union
from .template_heuristics.params import DictKernelTemplateParams
if TYPE_CHECKING:
from collections.abc import Generator
from .codegen.common import KernelTemplate
from .ir import ChoiceCaller, Layout
from .kernel_inputs import KernelInputs
from .select_algorithm import ExternKernelChoice
from .template_heuristics.params import KernelTemplateParams
class KernelTemplateChoice:
    """
    Bundle of everything needed to build a ChoiceCaller from a template.

    Stores the template, its params, the op/input-dependent extra kwargs,
    the output layout and the kernel inputs. The ChoiceCaller itself is
    not built at construction time: it is produced the first time the
    ``choice`` property is read, and that result (possibly None) is cached.
    """

    def __init__(
        self,
        template: Union[KernelTemplate, ExternKernelChoice],
        params: KernelTemplateParams,
        extra_kwargs: dict[str, Any],
        layout: Layout,
        inputs: KernelInputs,
    ):
        self.template = template
        self.params = params
        self.extra_kwargs = extra_kwargs
        self.layout = layout
        self.inputs = inputs
        # Holds a back-reference to this KTC. The dict is attached to the
        # generated ChoiceCaller, so code holding only the caller can reach
        # the KTC through annotations["ktc"].
        self.annotations: dict[str, Any] = {"ktc": self}

    @property
    def choice(self) -> Optional[ChoiceCaller]:
        """
        The ChoiceCaller for this template choice, built on first access.

        On the first read, ``template.choice_or_none`` is called with:
        - the params' kwargs;
        - the extra kwargs;
        - the layout;
        - the input nodes.

        Whatever it returns (a ChoiceCaller or None) is stored on the
        instance. Later reads return the stored value without calling the
        template again. When a ChoiceCaller is produced, this KTC's
        ``annotations`` dict is assigned to it.

        Returns:
            ChoiceCaller if the template produced one, None otherwise
        """
        # Already evaluated (even if the result was None): reuse it.
        if hasattr(self, "_choice"):
            return self._choice

        template_kwargs = self.params.to_kwargs()
        # Separate ** unpackings (rather than one merged dict) mean a key
        # present in both sources raises TypeError instead of being
        # silently overwritten.
        built = self.template.choice_or_none(
            **template_kwargs,
            **self.extra_kwargs,
            layout=self.layout,
            input_nodes=self.inputs.nodes(),
        )
        self._choice = built
        if built is not None:
            built.annotations = self.annotations
        return built
def make_ktc_generator(
    template: Union[KernelTemplate, ExternKernelChoice],
    cs: Generator[KernelTemplateParams, None, None],
    extra_kwargs: dict[str, Any],
    overrides: dict[str, Any],
    layout: Layout,
    inputs: KernelInputs,
) -> Generator[KernelTemplateChoice, None, None]:
    """
    Create a generator of KernelTemplateChoice objects for a given template.

    For each params object drawn from ``cs``:
    - its kwargs are merged with ``overrides``, with overrides winning on
      key collisions;
    - the merged kwargs are wrapped in a DictKernelTemplateParams;
    - a KernelTemplateChoice is yielded for that params object.

    The generator is lazy: ``cs`` is consumed one item at a time as the
    result is iterated.

    Args:
        template: The template object (KernelTemplate or ExternKernelChoice)
        cs: Generator of KernelTemplateParams from template heuristic
        extra_kwargs: Input/op-dependent kwargs passed unchanged to every
            yielded KernelTemplateChoice. Unlike ``overrides``, they are
            not merged into the params. The same dict object is shared by
            all yielded choices.
        overrides: Override kwargs for the template; these take precedence
            over the kwargs produced by each params object
        layout: Layout value for the template
        inputs: KernelInputs for the op

    Yields:
        KernelTemplateChoice objects, one per item of ``cs``
    """
    for params in cs:
        # Apply overrides on top of the heuristic-provided kwargs.
        merged_kwargs = {**params.to_kwargs(), **overrides}
        yield KernelTemplateChoice(
            template=template,
            params=DictKernelTemplateParams(merged_kwargs),
            extra_kwargs=extra_kwargs,
            layout=layout,
            inputs=inputs,
        )