mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[inductor] move max-autotune logic inside V.choices.get_mm_configs (#161344)
# why

- heuristics providers now decide whether to add choices (and which ones) in the max-autotune case
- enables an eventual override point to gracefully fall back to the standard behavior

# what

- max-autotune is determined inside V.choices.get_mm_configs; because it's mm only right now, we can just do `config.max_autotune or config.max_autotune_gemm`
- a TODO indicates that this can change in the future when this expands to more templates

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```

Differential Revision: [D81520573](https://our.internmc.facebook.com/intern/diff/D81520573)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161344
Approved by: https://github.com/jansel
ghstack dependencies: #162075, #161340, #161341, #161342, #161343
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							a301dc3b60
						
					
				
				
					commit
					031d79cb51
				
			| @ -31,8 +31,11 @@ from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout | |||||||
| from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm | from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm | ||||||
| from torch._inductor.select_algorithm import ( | from torch._inductor.select_algorithm import ( | ||||||
|     add_feedback_saver, |     add_feedback_saver, | ||||||
|  |     add_preprocessing_fn, | ||||||
|     AlgorithmSelectorCache, |     AlgorithmSelectorCache, | ||||||
|     clear_feedback_savers, |     clear_feedback_savers, | ||||||
|  |     clear_preprocessing_fns, | ||||||
|  |     ExternKernelCaller, | ||||||
|     TritonTemplate, |     TritonTemplate, | ||||||
|     TritonTemplateCaller, |     TritonTemplateCaller, | ||||||
| ) | ) | ||||||
| @ -1900,6 +1903,76 @@ class TestMaxAutotune(TestCase): | |||||||
|             counters["inductor"]["select_algorithm_num_precompilation_exceptions"], 0 |             counters["inductor"]["select_algorithm_num_precompilation_exceptions"], 0 | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @parametrize("op", ("mm", "addmm", "bmm", "baddbmm", "mm_plus_mm")) | ||||||
|  |     @parametrize("max_autotune", (False, True)) | ||||||
|  |     @config.patch( | ||||||
|  |         {"test_configs.max_mm_configs": 4, "max_autotune_gemm_backends": "ATEN,TRITON"} | ||||||
|  |     ) | ||||||
|  |     def test_autotune_gemm_choice_validation(self, op, max_autotune): | ||||||
|  |         def generate_inputs_and_func(op_name): | ||||||
|  |             # Base config with just x and w | ||||||
|  |             base_inputs = [ | ||||||
|  |                 torch.randn(128, 256, device=GPU_TYPE), | ||||||
|  |                 torch.randn(256, 128, device=GPU_TYPE), | ||||||
|  |             ] | ||||||
|  |             func = torch.mm | ||||||
|  |             if op_name == "mm": | ||||||
|  |                 # default | ||||||
|  |                 pass | ||||||
|  |             elif op_name == "addmm": | ||||||
|  |                 # Add bias for addmm | ||||||
|  |                 base_inputs = [torch.randn(128, device=GPU_TYPE)] + base_inputs | ||||||
|  |                 func = torch.addmm | ||||||
|  |             elif op_name in ["bmm", "baddbmm"]: | ||||||
|  |                 # Override for batch dimensions | ||||||
|  |                 base_inputs[0] = torch.randn(4, 128, 256, device=GPU_TYPE) | ||||||
|  |                 base_inputs[1] = torch.randn(4, 256, 128, device=GPU_TYPE) | ||||||
|  |                 func = torch.bmm | ||||||
|  |                 if op_name == "baddbmm": | ||||||
|  |                     # Add batch bias | ||||||
|  |                     base_inputs = [ | ||||||
|  |                         torch.torch.randn(4, 128, 128, device=GPU_TYPE) | ||||||
|  |                     ] + base_inputs | ||||||
|  |                     func = torch.baddbmm | ||||||
|  |             elif op_name == "mm_plus_mm": | ||||||
|  |                 # Add second matrix pair | ||||||
|  |                 base_inputs += [ | ||||||
|  |                     torch.randn(128, 256, device=GPU_TYPE), | ||||||
|  |                     torch.randn(256, 128, device=GPU_TYPE), | ||||||
|  |                 ] | ||||||
|  |  | ||||||
|  |                 def mmpmm(x, w, x2, w2): | ||||||
|  |                     return torch.mm(x, w) + torch.mm(x2, w2) | ||||||
|  |  | ||||||
|  |                 func = mmpmm | ||||||
|  |             else: | ||||||
|  |                 raise ValueError(f"Unsupported op: {op_name}") | ||||||
|  |             return base_inputs, func | ||||||
|  |  | ||||||
|  |         choice_types_seen = set() | ||||||
|  |  | ||||||
|  |         def choice_validator(choices): | ||||||
|  |             for choice in choices: | ||||||
|  |                 choice_types_seen.add(type(choice)) | ||||||
|  |             return choices | ||||||
|  |  | ||||||
|  |         inputs, fn = generate_inputs_and_func(op) | ||||||
|  |  | ||||||
|  |         add_preprocessing_fn(choice_validator) | ||||||
|  |         try: | ||||||
|  |             with config.patch({"max_autotune": max_autotune}): | ||||||
|  |                 compiled_fn = torch.compile(fn, dynamic=False) | ||||||
|  |                 compiled_fn(*inputs) | ||||||
|  |  | ||||||
|  |                 if max_autotune: | ||||||
|  |                     self.assertIn(ExternKernelCaller, choice_types_seen) | ||||||
|  |                     self.assertIn(TritonTemplateCaller, choice_types_seen) | ||||||
|  |                 else: | ||||||
|  |                     self.assertIn(ExternKernelCaller, choice_types_seen) | ||||||
|  |                     self.assertNotIn(TritonTemplateCaller, choice_types_seen) | ||||||
|  |         finally: | ||||||
|  |             clear_preprocessing_fns() | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestMaxAutotunePrecompile(TestCase): | class TestMaxAutotunePrecompile(TestCase): | ||||||
|     def test_precompilation_threads(self): |     def test_precompilation_threads(self): | ||||||
|  | |||||||
| @ -207,7 +207,7 @@ def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None): | |||||||
|                 **extra_kwargs, |                 **extra_kwargs, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     if use_triton_template(layout): |     if use_triton_template(layout, check_max_autotune=False): | ||||||
|         # TODO: add out_dtype support for Triton Template |         # TODO: add out_dtype support for Triton Template | ||||||
|         assert out_dtype is None, "out_dtype is not supported for Triton" |         assert out_dtype is None, "out_dtype is not supported for Triton" | ||||||
|  |  | ||||||
| @ -288,7 +288,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): | |||||||
|                 **extra_kwargs, |                 **extra_kwargs, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     if use_triton_template(layout): |     if use_triton_template(layout, check_max_autotune=False): | ||||||
|         for kwargs, extra_kwargs in V.choices.get_mm_configs( |         for kwargs, extra_kwargs in V.choices.get_mm_configs( | ||||||
|             kernel_inputs, |             kernel_inputs, | ||||||
|             layout, |             layout, | ||||||
|  | |||||||
| @ -762,7 +762,7 @@ def tuned_mm(mat1, mat2, *, layout=None): | |||||||
|             ) |             ) | ||||||
|     static_shape, is_nonzero = _is_static_problem(layout) |     static_shape, is_nonzero = _is_static_problem(layout) | ||||||
|  |  | ||||||
|     if is_nonzero and use_triton_template(layout): |     if is_nonzero and use_triton_template(layout, check_max_autotune=False): | ||||||
|         # Get template params using the new unified function |         # Get template params using the new unified function | ||||||
|         for kwargs, extra_kwargs in V.choices.get_mm_configs( |         for kwargs, extra_kwargs in V.choices.get_mm_configs( | ||||||
|             kernel_inputs, layout, mm_template, "mm" |             kernel_inputs, layout, mm_template, "mm" | ||||||
| @ -941,7 +941,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None): | |||||||
|             choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True |             choices, layout, kernel_inputs.nodes(), fuseable=True, non_fuseable=True | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     if is_nonzero and use_triton_template(layout, enable_int32=True): |     if is_nonzero and use_triton_template( | ||||||
|  |         layout, enable_int32=True, check_max_autotune=False | ||||||
|  |     ): | ||||||
|         for kwargs, extra_kwargs in V.choices.get_mm_configs( |         for kwargs, extra_kwargs in V.choices.get_mm_configs( | ||||||
|             kernel_inputs, layout, mm_template, name |             kernel_inputs, layout, mm_template, name | ||||||
|         ): |         ): | ||||||
| @ -1035,7 +1037,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): | |||||||
|                 **extra_kwargs, |                 **extra_kwargs, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     if is_nonzero and use_triton_template(layout): |     if is_nonzero and use_triton_template(layout, check_max_autotune=False): | ||||||
|         # all the triton templates use the extra_kwargs |         # all the triton templates use the extra_kwargs | ||||||
|         # Get template params using the new unified function |         # Get template params using the new unified function | ||||||
|         for kwargs, extra_kwargs in V.choices.get_mm_configs( |         for kwargs, extra_kwargs in V.choices.get_mm_configs( | ||||||
| @ -1248,7 +1250,9 @@ def tuned_scaled_mm( | |||||||
|  |  | ||||||
|     _, is_nonzero = _is_static_problem(layout) |     _, is_nonzero = _is_static_problem(layout) | ||||||
|  |  | ||||||
|     if is_nonzero and use_triton_template(layout, enable_float8=True): |     if is_nonzero and use_triton_template( | ||||||
|  |         layout, enable_float8=True, check_max_autotune=False | ||||||
|  |     ): | ||||||
|         overriders = dict(USE_FAST_ACCUM=use_fast_accum) |         overriders = dict(USE_FAST_ACCUM=use_fast_accum) | ||||||
|         # TODO (paulzhan): There is no template that exists for bias and TMA |         # TODO (paulzhan): There is no template that exists for bias and TMA | ||||||
|         # Don't run tma template currently if bias exists |         # Don't run tma template currently if bias exists | ||||||
|  | |||||||
| @ -165,7 +165,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): | |||||||
|                 **extra_kwargs, |                 **extra_kwargs, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     if use_triton_template(layout1): |     if use_triton_template(layout1, check_max_autotune=False): | ||||||
|         # Get template params using the new unified function |         # Get template params using the new unified function | ||||||
|         for kwargs, extra_kwargs in V.choices.get_mm_configs( |         for kwargs, extra_kwargs in V.choices.get_mm_configs( | ||||||
|             kernel_inputs, layout1, mm_plus_mm_template, "mm_plus_mm" |             kernel_inputs, layout1, mm_plus_mm_template, "mm_plus_mm" | ||||||
|  | |||||||
| @ -8,6 +8,7 @@ from ..kernel.bmm import aten_baddbmm, aten_bmm, aten_bmm_dtype | |||||||
| from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm, aten_mm | from ..kernel.mm import aten__fp8_mm, aten__int_mm, aten_addmm, aten_bias_addmm, aten_mm | ||||||
| from ..kernel.mm_plus_mm import aten_mm_plus_mm | from ..kernel.mm_plus_mm import aten_mm_plus_mm | ||||||
| from .base import TemplateConfigHeuristics | from .base import TemplateConfigHeuristics | ||||||
|  | from .gemm import GemmMaxAutotuneTemplateConfigHeuristics | ||||||
|  |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
| @ -37,7 +38,7 @@ class ATenConfigHeuristics(TemplateConfigHeuristics): | |||||||
|     If you want to use this with an ATen choice that has kwargs, just subclass |     If you want to use this with an ATen choice that has kwargs, just subclass | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def get_template_configs( |     def _get_template_configs_impl( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
|         layout: Layout, |         layout: Layout, | ||||||
| @ -68,17 +69,15 @@ class ATenAddMMConfigHeuristics(ATenConfigHeuristics): | |||||||
|  |  | ||||||
|  |  | ||||||
| @register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm") | @register_template_heuristic(aten_bias_addmm.uid, None, op_name="addmm") | ||||||
| class ATenBiasAddMMConfigHeuristics(ATenAddMMConfigHeuristics): | class ATenBiasAddMMConfigHeuristics( | ||||||
|     def get_template_configs( |     ATenAddMMConfigHeuristics, GemmMaxAutotuneTemplateConfigHeuristics | ||||||
|  | ): | ||||||
|  |     def _get_template_configs_impl( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
|         layout: Layout, |         layout: Layout, | ||||||
|         op_name: str, |         op_name: str, | ||||||
|     ) -> Generator[dict[str, Any], None, None]: |     ) -> Generator[dict[str, Any], None, None]: | ||||||
|         if not (inductor_config.max_autotune or inductor_config.max_autotune_gemm): |  | ||||||
|             # NOTE: this preserves the original logic that if there is not max-autotune |  | ||||||
|             # then we skip bias_addmm |  | ||||||
|             return |  | ||||||
|         nodes = kernel_inputs.nodes() |         nodes = kernel_inputs.nodes() | ||||||
|         # for addmm, bias is the first input |         # for addmm, bias is the first input | ||||||
|         bias = nodes[0] |         bias = nodes[0] | ||||||
|  | |||||||
| @ -11,6 +11,20 @@ if TYPE_CHECKING: | |||||||
|  |  | ||||||
|  |  | ||||||
| class TemplateConfigHeuristics: | class TemplateConfigHeuristics: | ||||||
|  |     """Base class for generating sets of configs for an associated template.""" | ||||||
|  |  | ||||||
|  |     def should_run(self, inputs: KernelInputs, layout: Layout) -> bool: | ||||||
|  |         """ | ||||||
|  |         hookup to check whether the configs are right to run at all e.g. you can check | ||||||
|  |         max-autotune specific to your heuristic here or other things | ||||||
|  |         If this returns False, get_template_configs will yield no configs | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             inputs: KernelInputs | ||||||
|  |             layout: Layout | ||||||
|  |         """ | ||||||
|  |         return True | ||||||
|  |  | ||||||
|     def get_template_configs( |     def get_template_configs( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
| @ -19,10 +33,30 @@ class TemplateConfigHeuristics: | |||||||
|     ) -> Generator[dict[str, Any], None, None]: |     ) -> Generator[dict[str, Any], None, None]: | ||||||
|         """ |         """ | ||||||
|         Get template configs for the given inputs. |         Get template configs for the given inputs. | ||||||
|  |  | ||||||
|  |         Prefer to override the _get_template_configs_impl method | ||||||
|  |         to leverage things like should_run | ||||||
|  |         """ | ||||||
|  |         if not self.should_run(kernel_inputs, layout): | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         yield from self._get_template_configs_impl( | ||||||
|  |             kernel_inputs, | ||||||
|  |             layout, | ||||||
|  |             op_name, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def _get_template_configs_impl( | ||||||
|  |         self, | ||||||
|  |         kernel_inputs: KernelInputs, | ||||||
|  |         layout: Layout, | ||||||
|  |         op_name: str, | ||||||
|  |     ) -> Generator[dict[str, Any], None, None]: | ||||||
|  |         """ | ||||||
|  |         Get template configs for the given inputs. | ||||||
|         This is the main entry point for template-specific logic. |         This is the main entry point for template-specific logic. | ||||||
|         """ |         """ | ||||||
|         # NOTE: not an abstract class, because that clashed below for the mixin |         # base implementation yields no entries | ||||||
|         # functionality. Can be adjusted, but not a high priority |  | ||||||
|         yield from [] |         yield from [] | ||||||
|  |  | ||||||
|     def get_extra_kwargs( |     def get_extra_kwargs( | ||||||
|  | |||||||
| @ -12,6 +12,7 @@ from ..kernel.mm import ( | |||||||
| from ..kernel_inputs import KernelInputs, MMKernelInputs | from ..kernel_inputs import KernelInputs, MMKernelInputs | ||||||
| from ..utils import use_contiguous | from ..utils import use_contiguous | ||||||
| from .base import TemplateConfigHeuristics | from .base import TemplateConfigHeuristics | ||||||
|  | from .gemm import GemmMaxAutotuneTemplateConfigHeuristics | ||||||
| from .registry import register_template_heuristic | from .registry import register_template_heuristic | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -41,8 +42,8 @@ class EmptyContiguousMMConfigHeuristics(TemplateConfigHeuristics): | |||||||
|     register=torch.version.hip is not None, |     register=torch.version.hip is not None, | ||||||
|     op_name="addmm", |     op_name="addmm", | ||||||
| ) | ) | ||||||
| class ContiguousMMHeuristics(TemplateConfigHeuristics): | class ContiguousMMHeuristics(GemmMaxAutotuneTemplateConfigHeuristics): | ||||||
|     def get_template_configs( |     def _get_template_configs_impl( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
|         layout: Layout, |         layout: Layout, | ||||||
| @ -54,7 +55,6 @@ class ContiguousMMHeuristics(TemplateConfigHeuristics): | |||||||
|         assert isinstance(kernel_inputs, MMKernelInputs), ( |         assert isinstance(kernel_inputs, MMKernelInputs), ( | ||||||
|             f"{self.__class__.__name__} requires MMKernelInputs" |             f"{self.__class__.__name__} requires MMKernelInputs" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # Check for unbacked symbols - if found, yield nothing |         # Check for unbacked symbols - if found, yield nothing | ||||||
|         unbacked_symbols = any( |         unbacked_symbols = any( | ||||||
|             len(get_free_symbols(itr, unbacked_only=True)) > 0 |             len(get_free_symbols(itr, unbacked_only=True)) > 0 | ||||||
|  | |||||||
| @ -12,6 +12,7 @@ from ..kernel_inputs import KernelInputs, MMKernelInputs | |||||||
| from ..utils import get_k_splits | from ..utils import get_k_splits | ||||||
| from ..virtualized import V | from ..virtualized import V | ||||||
| from .base import TemplateConfigHeuristics | from .base import TemplateConfigHeuristics | ||||||
|  | from .gemm import GemmMaxAutotuneTemplateConfigHeuristics | ||||||
| from .registry import register_template_heuristic | from .registry import register_template_heuristic | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -38,8 +39,8 @@ class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics): | |||||||
| # TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia) | # TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia) | ||||||
| # by either adding specific register_template_heuristic tags, or setting the | # by either adding specific register_template_heuristic tags, or setting the | ||||||
| # device to None (enabled on all devices) | # device to None (enabled on all devices) | ||||||
| class DecomposeKConfigHeuristics(TemplateConfigHeuristics): | class DecomposeKConfigHeuristics(GemmMaxAutotuneTemplateConfigHeuristics): | ||||||
|     def get_template_configs( |     def _get_template_configs_impl( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
|         layout: Layout, |         layout: Layout, | ||||||
|  | |||||||
							
								
								
									
										19
									
								
								torch/_inductor/template_heuristics/gemm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								torch/_inductor/template_heuristics/gemm.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | |||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | from typing import TYPE_CHECKING | ||||||
|  |  | ||||||
|  | from .. import config as inductor_config | ||||||
|  | from .base import TemplateConfigHeuristics | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from ..ir import Layout | ||||||
|  |     from ..kernel_inputs import KernelInputs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class GemmMaxAutotuneTemplateConfigHeuristics(TemplateConfigHeuristics): | ||||||
|  |     def should_run(self, inputs: KernelInputs, layout: Layout) -> bool: | ||||||
|  |         """ | ||||||
|  |         simple base override for GEMM family templates that run only in max-autotune | ||||||
|  |         """ | ||||||
|  |         return inductor_config.max_autotune or inductor_config.max_autotune_gemm | ||||||
| @ -32,7 +32,7 @@ from ..utils import ( | |||||||
|     using_b200, |     using_b200, | ||||||
| ) | ) | ||||||
| from ..virtualized import V | from ..virtualized import V | ||||||
| from .base import TemplateConfigHeuristics | from .gemm import GemmMaxAutotuneTemplateConfigHeuristics | ||||||
| from .registry import register_template_heuristic | from .registry import register_template_heuristic | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -1417,7 +1417,7 @@ class MTIAConfigHeuristic(BaseConfigHeuristic): | |||||||
|  |  | ||||||
|  |  | ||||||
| # Template-specific mixin classes | # Template-specific mixin classes | ||||||
| class MMTemplateConfigMixin(TemplateConfigHeuristics): | class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics): | ||||||
|     """ |     """ | ||||||
|     Mixin class that converts config lists to template kwargs. |     Mixin class that converts config lists to template kwargs. | ||||||
|     This handles the logic that was previously in choices.get_mm_configs. |     This handles the logic that was previously in choices.get_mm_configs. | ||||||
| @ -1448,7 +1448,7 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics): | |||||||
|         else: |         else: | ||||||
|             return self.get_mm_configs() |             return self.get_mm_configs() | ||||||
|  |  | ||||||
|     def get_template_configs( |     def _get_template_configs_impl( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
|         layout: Any, |         layout: Any, | ||||||
| @ -1466,6 +1466,7 @@ class MMTemplateConfigMixin(TemplateConfigHeuristics): | |||||||
|             raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}") |             raise ValueError(f"Need at least 2 input tensors, got {len(input_nodes)}") | ||||||
|         if not self._valid(kernel_inputs): |         if not self._valid(kernel_inputs): | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         # Extract M, N, K from kernel_inputs |         # Extract M, N, K from kernel_inputs | ||||||
|         m, n, k = kernel_inputs.mnk_symbolic() |         m, n, k = kernel_inputs.mnk_symbolic() | ||||||
|  |  | ||||||
| @ -1593,7 +1594,7 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin): | |||||||
|     This inherits from MMTemplateConfigMixin and overrides config generation. |     This inherits from MMTemplateConfigMixin and overrides config generation. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def get_template_configs( |     def _get_template_configs_impl( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
|         layout: Any, |         layout: Any, | ||||||
| @ -1614,8 +1615,10 @@ class TMATemplateConfigMixin(TMAWorkspaceMixin, MMTemplateConfigMixin): | |||||||
|             "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(), |             "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(), | ||||||
|         } |         } | ||||||
|         # Get base template configs from superclass |         # Get base template configs from superclass | ||||||
|         for template_kwargs in super().get_template_configs( |         for template_kwargs in super()._get_template_configs_impl( | ||||||
|             kernel_inputs, layout, op_name |             kernel_inputs, | ||||||
|  |             layout, | ||||||
|  |             op_name, | ||||||
|         ): |         ): | ||||||
|             yield {**template_kwargs, **tma_opts} |             yield {**template_kwargs, **tma_opts} | ||||||
|  |  | ||||||
| @ -1661,7 +1664,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin): | |||||||
|             nodes, mat1_idx=kernel_inputs._mat1_idx, mat2_idx=kernel_inputs._mat2_idx |             nodes, mat1_idx=kernel_inputs._mat1_idx, mat2_idx=kernel_inputs._mat2_idx | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def get_template_configs( |     def _get_template_configs_impl( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
|         layout: Any, |         layout: Any, | ||||||
| @ -1713,7 +1716,7 @@ class BaseScaledMMConfigMixin(MMTemplateConfigMixin): | |||||||
|             return |             return | ||||||
|  |  | ||||||
|         # Get base template configs from superclass |         # Get base template configs from superclass | ||||||
|         for template_kwargs in super().get_template_configs( |         for template_kwargs in super()._get_template_configs_impl( | ||||||
|             kernel_inputs, layout, op_name |             kernel_inputs, layout, op_name | ||||||
|         ): |         ): | ||||||
|             # Add scaled MM-specific options (moved from mm_common.scaled_mm_options) |             # Add scaled MM-specific options (moved from mm_common.scaled_mm_options) | ||||||
| @ -1769,7 +1772,7 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin): | |||||||
|     This inherits from BaseScaledMMConfigMixin and adds TMA-specific options. |     This inherits from BaseScaledMMConfigMixin and adds TMA-specific options. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def get_template_configs( |     def _get_template_configs_impl( | ||||||
|         self, |         self, | ||||||
|         kernel_inputs: KernelInputs, |         kernel_inputs: KernelInputs, | ||||||
|         layout: Any, |         layout: Any, | ||||||
| @ -1779,8 +1782,10 @@ class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin): | |||||||
|         Generate scaled TMA template configs with both scaled MM and TMA-specific options. |         Generate scaled TMA template configs with both scaled MM and TMA-specific options. | ||||||
|         """ |         """ | ||||||
|         # Get base scaled MM template configs from superclass |         # Get base scaled MM template configs from superclass | ||||||
|         for template_kwargs in super().get_template_configs( |         for template_kwargs in super()._get_template_configs_impl( | ||||||
|             kernel_inputs, layout, op_name |             kernel_inputs, | ||||||
|  |             layout, | ||||||
|  |             op_name, | ||||||
|         ): |         ): | ||||||
|             # Add TMA-specific options for device TMA scaled MM |             # Add TMA-specific options for device TMA scaled MM | ||||||
|             template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE |             template_kwargs["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE | ||||||
|  | |||||||
| @ -1634,7 +1634,11 @@ def _use_conv_autotune_backend(backend: str) -> bool: | |||||||
|  |  | ||||||
|  |  | ||||||
| def use_triton_template( | def use_triton_template( | ||||||
|     layout: Layout, *, enable_int32: bool = False, enable_float8: bool = False |     layout: Layout, | ||||||
|  |     *, | ||||||
|  |     enable_int32: bool = False, | ||||||
|  |     enable_float8: bool = False, | ||||||
|  |     check_max_autotune: bool = True, | ||||||
| ) -> bool: | ) -> bool: | ||||||
|     from .codegen.common import BackendFeature, has_backend_feature |     from .codegen.common import BackendFeature, has_backend_feature | ||||||
|  |  | ||||||
| @ -1651,7 +1655,8 @@ def use_triton_template( | |||||||
|             ) |             ) | ||||||
|             or (layout.device.type == "cpu" and layout.dtype in layout_dtypes) |             or (layout.device.type == "cpu" and layout.dtype in layout_dtypes) | ||||||
|         ) |         ) | ||||||
|         and (config.max_autotune or config.max_autotune_gemm) |         # some callers handle max-autotune checking externally | ||||||
|  |         and (config.max_autotune or config.max_autotune_gemm or not check_max_autotune) | ||||||
|         and _use_autotune_backend("TRITON") |         and _use_autotune_backend("TRITON") | ||||||
|         and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) |         and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES) | ||||||
|     ) |     ) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user