mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[inductor][choices] move extra kwargs out of get_template_configs (#163209)
# why - extra kwargs are input/op dependent and not config dependent. We don't plan to serialize/deserialize them, and so they need to be fed in later beore making the KTC, rather than when getting the config values directly # what - move extra_kwargs into the KTC and get_ktc interface directly # testing ``` python3 -bb -m pytest test/inductor/test_max_autotune.py -v -k "_addmm" ``` Differential Revision: [D82871310](https://our.internmc.facebook.com/intern/diff/D82871310) Pull Request resolved: https://github.com/pytorch/pytorch/pull/163209 Approved by: https://github.com/nmacchioni ghstack dependencies: #163305
This commit is contained in:
committed by
PyTorch MergeBot
parent
df5d6d57c9
commit
0ee331b523
@ -166,11 +166,13 @@ class InductorChoices:
|
||||
# adjust the kernel inputs to the template-specific heuristic, if needed
|
||||
# default here is to just return the kernel_inputs as is
|
||||
inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
|
||||
extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name)
|
||||
# Create KernelTemplateChoice generator using the moved function
|
||||
overrides = kwarg_overrides or {}
|
||||
return make_ktc_generator(
|
||||
template=template,
|
||||
cs=cs,
|
||||
extra_kwargs=extra_kwargs,
|
||||
overrides=overrides,
|
||||
layout=kernel_inputs.output_layout(),
|
||||
inputs=inputs_val,
|
||||
|
@ -27,11 +27,13 @@ class KernelTemplateChoice:
|
||||
self,
|
||||
template: Union[KernelTemplate, ExternKernelChoice],
|
||||
params: KernelTemplateParams,
|
||||
extra_kwargs: dict[str, Any],
|
||||
layout: Layout,
|
||||
inputs: KernelInputs,
|
||||
):
|
||||
self.template = template
|
||||
self.params = params
|
||||
self.extra_kwargs = extra_kwargs
|
||||
self.layout = layout
|
||||
self.inputs = inputs
|
||||
self.annotations: dict[str, Any] = {"ktc": self}
|
||||
@ -53,6 +55,7 @@ class KernelTemplateChoice:
|
||||
kwargs = self.params.to_kwargs()
|
||||
self._choice = self.template.choice_or_none(
|
||||
**kwargs,
|
||||
**self.extra_kwargs,
|
||||
layout=self.layout,
|
||||
input_nodes=self.inputs.nodes(),
|
||||
)
|
||||
@ -64,6 +67,7 @@ class KernelTemplateChoice:
|
||||
def make_ktc_generator(
|
||||
template: Union[KernelTemplate, ExternKernelChoice],
|
||||
cs: Generator[KernelTemplateParams, None, None],
|
||||
extra_kwargs: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
layout: Layout,
|
||||
inputs: KernelInputs,
|
||||
@ -86,10 +90,10 @@ def make_ktc_generator(
|
||||
base_kwargs = params.to_kwargs()
|
||||
final_kwargs = {**base_kwargs, **overrides}
|
||||
final_params = DictKernelTemplateParams(final_kwargs)
|
||||
|
||||
yield KernelTemplateChoice(
|
||||
template=template,
|
||||
params=final_params,
|
||||
extra_kwargs=extra_kwargs,
|
||||
layout=layout,
|
||||
inputs=inputs,
|
||||
)
|
||||
|
@ -39,14 +39,10 @@ class TemplateConfigHeuristics:
|
||||
if not self.should_run(kernel_inputs):
|
||||
return
|
||||
|
||||
# Get extra kwargs once
|
||||
extra_kwargs = self.get_extra_kwargs(kernel_inputs, op_name)
|
||||
|
||||
# Generate configs and fuse with extra_kwargs
|
||||
for config_dict in self._get_template_configs_impl(kernel_inputs, op_name):
|
||||
# Fuse extra_kwargs into config
|
||||
fused_kwargs = {**config_dict, **extra_kwargs}
|
||||
yield DictKernelTemplateParams(fused_kwargs)
|
||||
yield DictKernelTemplateParams(config_dict)
|
||||
|
||||
def _get_template_configs_impl(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user