diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 2cc0a28f822a..469afbf02e52 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -166,11 +166,13 @@ class InductorChoices: # adjust the kernel inputs to the template-specific heuristic, if needed # default here is to just return the kernel_inputs as is inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name) + extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name) # Create KernelTemplateChoice generator using the moved function overrides = kwarg_overrides or {} return make_ktc_generator( template=template, cs=cs, + extra_kwargs=extra_kwargs, overrides=overrides, layout=kernel_inputs.output_layout(), inputs=inputs_val, diff --git a/torch/_inductor/kernel_template_choice.py b/torch/_inductor/kernel_template_choice.py index c03783959ccb..8f90157c6c1a 100644 --- a/torch/_inductor/kernel_template_choice.py +++ b/torch/_inductor/kernel_template_choice.py @@ -27,11 +27,13 @@ class KernelTemplateChoice: self, template: Union[KernelTemplate, ExternKernelChoice], params: KernelTemplateParams, + extra_kwargs: dict[str, Any], layout: Layout, inputs: KernelInputs, ): self.template = template self.params = params + self.extra_kwargs = extra_kwargs self.layout = layout self.inputs = inputs self.annotations: dict[str, Any] = {"ktc": self} @@ -53,6 +55,7 @@ class KernelTemplateChoice: kwargs = self.params.to_kwargs() self._choice = self.template.choice_or_none( **kwargs, + **self.extra_kwargs, layout=self.layout, input_nodes=self.inputs.nodes(), ) @@ -64,6 +67,7 @@ class KernelTemplateChoice: def make_ktc_generator( template: Union[KernelTemplate, ExternKernelChoice], cs: Generator[KernelTemplateParams, None, None], + extra_kwargs: dict[str, Any], overrides: dict[str, Any], layout: Layout, inputs: KernelInputs, @@ -86,10 +90,10 @@ def make_ktc_generator( base_kwargs = params.to_kwargs() final_kwargs = {**base_kwargs, **overrides} final_params = DictKernelTemplateParams(final_kwargs) - yield KernelTemplateChoice( template=template, params=final_params, + extra_kwargs=extra_kwargs, layout=layout, inputs=inputs, ) diff --git a/torch/_inductor/template_heuristics/base.py b/torch/_inductor/template_heuristics/base.py index d11ca64fd1aa..0343270f3a11 100644 --- a/torch/_inductor/template_heuristics/base.py +++ b/torch/_inductor/template_heuristics/base.py @@ -39,14 +39,10 @@ class TemplateConfigHeuristics: if not self.should_run(kernel_inputs): return - # Get extra kwargs once - extra_kwargs = self.get_extra_kwargs(kernel_inputs, op_name) - # Generate configs and fuse with extra_kwargs for config_dict in self._get_template_configs_impl(kernel_inputs, op_name): # Fuse extra_kwargs into config - fused_kwargs = {**config_dict, **extra_kwargs} - yield DictKernelTemplateParams(fused_kwargs) + yield DictKernelTemplateParams(config_dict) def _get_template_configs_impl( self,