[inductor][mm] restructure decompose k (#161026)

# why

- make it easier to integrate into the lookup table later

# what

- the current version creates a template on the fly for each k split
  and uses it to generate a single choice
- the lookup table and performance model work best when there is a
  stable set of templates (with predictable names) that are then
  parametrized
- this change introduces a single DecomposeK template with a stable
  name, and makes the k split the only parametrization we do (see
  the sketch after this list)
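
A minimal sketch of the resulting usage pattern, not taken verbatim from this PR: it assumes `mat1`, `mat2` (the mm input IR buffers), `layout`, and the `choices` list are set up the way the updated `TestSubgraphChoice` test sets them up, and the k splits below are illustrative (`tuned_mm` gets them from `get_k_splits(m, n, k)`):

```python
from torch._inductor.kernel.mm import DecomposeKSugraphTemplate

# One stable template (named "decompose_k"); the k split is the only
# per-choice parametrization.
decompose_k_template = DecomposeKSugraphTemplate()

choices = []  # SubgraphChoiceCaller entries collected for autotuning
for k_split in (32, 64, 256):  # illustrative; tuned_mm uses get_k_splits(m, n, k)
    decompose_k_template.maybe_append_choice(
        choices,
        k_split=k_split,
        input_nodes=(mat1, mat2),  # mm input buffers, built as in the test
        layout=layout,             # output layout, built as in the test
    )
```

Each appended choice still gets a distinct `SubgraphChoiceCaller` name (`decompose_k_mm_<k>_split_<counter>`), but the template itself now keeps the stable name `decompose_k`.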

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py::TestMaxAutotune::test_max_autotune_decompose_k_dynamic_False_bfloat16_sizes1 -v
```

Differential Revision: [D80670913](https://our.internmc.facebook.com/intern/diff/D80670913)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161026
Approved by: https://github.com/PaulZhang12, https://github.com/jansel
Author: Ruben Rodriguez Buchillon
Date: 2025-08-27 18:44:20 -07:00
Committed by: PyTorch MergeBot
Commit: 688acf0b83 (parent f0a517e333)
3 changed files with 57 additions and 48 deletions


@@ -1,18 +1,13 @@
# Owner(s): ["module: inductor"]
import functools
import unittest
from unittest import mock
from unittest.mock import MagicMock
import torch
from torch._dispatch.python import enable_python_dispatcher
from torch._inductor.codegen.subgraph import SubgraphTemplate
from torch._inductor.decomposition import select_decomp_table
from torch._inductor.ir import Buffer, FixedLayout, FlexibleLayout
from torch._inductor.lowering import register_lowering
from torch._inductor.select_algorithm import autotune_select_algorithm
from torch._inductor.test_case import run_tests, TestCase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
@@ -64,20 +59,14 @@ class TestSubgraphChoice(TestCase):
choices = [aten_mm.bind((mat1, mat2), layout)]
kPartitions = 256
with enable_python_dispatcher():
decompositions = select_decomp_table()
decompose_k_subgraph_template = SubgraphTemplate(
name="decompose_k_mm",
make_fx_graph=make_fx(
functools.partial(decomposeK, kPartitions=kPartitions),
decompositions,
tracing_mode="real",
),
)
decompose_k_subgraph_template = (
torch._inductor.kernel.mm.DecomposeKSugraphTemplate()
)
decompose_k_subgraph_template.maybe_append_choice(
choices,
k_split=kPartitions,
input_nodes=(mat1, mat2),
layout=layout,
)
@@ -139,19 +128,14 @@ class TestSubgraphChoice(TestCase):
choices = []
kPartitions = 2
with enable_python_dispatcher():
decompositions = select_decomp_table()
decompose_k_subgraph_template = SubgraphTemplate(
name="decompose_k_mm",
make_fx_graph=make_fx(
functools.partial(decomposeK, kPartitions=kPartitions),
decompositions,
),
)
decompose_k_subgraph_template = (
torch._inductor.kernel.mm.DecomposeKSugraphTemplate()
)
decompose_k_subgraph_template.maybe_append_choice(
choices,
k_split=kPartitions,
input_nodes=(mat1, mat2),
layout=layout,
)


@@ -168,7 +168,6 @@ class SubgraphTemplate(KernelTemplate):
    def __init__(
        self,
        name: str,
        make_fx_graph: Callable[..., Any],
    ):
        """
        Initialize a subgraph template.
@@ -177,13 +176,15 @@ class SubgraphTemplate(KernelTemplate):
            name: The name of this template
            graph: The FX graph
        """
        self.name = f"{name}_{next(SubgraphTemplate.index_counter)}"
        self.make_fx_graph = make_fx_graph
        super().__init__(name=name)

    def generate(  # type: ignore[override]
        self,
        name: str,
        input_nodes: list[Buffer],
        layout: Layout,
        make_fx_graph: Callable[..., Any],
        description: str = "",
        **kwargs: Any,
    ) -> SubgraphChoiceCaller:
        """
@@ -200,9 +201,9 @@
        """
        return SubgraphChoiceCaller(
            name=self.name,
            name=f"{name}_{next(SubgraphTemplate.index_counter)}",
            input_nodes=input_nodes,
            layout=layout,
            description="",
            make_fx_graph=self.make_fx_graph,
            description=description,
            make_fx_graph=make_fx_graph,
        )


@@ -24,8 +24,8 @@ from .. import config as inductor_config
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.subgraph import SubgraphTemplate
from ..ir import FlexibleLayout, is_triton
from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate
from ..ir import Buffer, FlexibleLayout, is_triton, Layout
from ..kernel_inputs import MMKernelInputs
from ..lowering import (
    add_layout_constraint,
@@ -658,6 +658,44 @@ def decomposeK(a, b, k_splits):
    return reduced_buf.to(a.dtype)


class DecomposeKSugraphTemplate(SubgraphTemplate):
    def __init__(self):
        super().__init__(
            name="decompose_k",
        )

    def generate(  # type: ignore[override]
        self,
        input_nodes: list[Buffer],
        layout: Layout,
        k_split: int,
    ) -> SubgraphChoiceCaller:
        from torch._dispatch.python import enable_python_dispatcher

        from ..decomposition import select_decomp_table

        name = f"decompose_k_mm_{k_split}_split"
        description = f"{k_split=}"
        with enable_python_dispatcher():
            decompositions = select_decomp_table()
            fn = make_fx(
                functools.partial(decomposeK, k_splits=k_split),
                decompositions,
            )
        return super().generate(
            name=name,
            input_nodes=input_nodes,
            layout=layout,
            make_fx_graph=fn,
            description=description,
        )


decompose_k_subgraph_template = DecomposeKSugraphTemplate()
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
"""
@@ -739,10 +777,6 @@ def tuned_mm(mat1, mat2, *, layout=None):
)
)
if use_decompose_k_choice(m, n, k) and not unbacked_symbols:
from torch._dispatch.python import enable_python_dispatcher
from ..decomposition import select_decomp_table
k_splits = get_k_splits(m, n, k)
for k_split in k_splits:
if not V.graph.sizevars.statically_known_true(
@@ -750,21 +784,11 @@
):
continue
with enable_python_dispatcher():
decompositions = select_decomp_table()
decompose_k_subgraph_template = SubgraphTemplate(
name=f"decompose_k_mm_{k_split}_split",
make_fx_graph=make_fx(
functools.partial(decomposeK, k_splits=k_split),
decompositions,
),
)
decompose_k_subgraph_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
input_nodes=kernel_inputs.nodes(),
layout=layout,
k_split=k_split,
)
if (