mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor][mm] restructure decompose k (#161026)
# why - make it easier to integrate into lookup table later # what - current version generates templates on the fly and uses them to generate a single choice - lookup table and performance model work best when there is a stable set of templates (with predictable names) and those are then parametrized - this change makes it so that there is a single DecomposeK template with a stable name, and the k split is the only parametrization we do # testing ``` python3 -bb -m pytest test/inductor/test_max_autotune.py::TestMaxAutotune::test_max_autotune_decompose_k_dynamic_False_bfloat16_sizes1 -v ``` Differential Revision: [D80670913](https://our.internmc.facebook.com/intern/diff/D80670913) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161026 Approved by: https://github.com/PaulZhang12, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
f0a517e333
commit
688acf0b83
@ -1,18 +1,13 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import functools
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._inductor.codegen.subgraph import SubgraphTemplate
|
||||
from torch._inductor.decomposition import select_decomp_table
|
||||
from torch._inductor.ir import Buffer, FixedLayout, FlexibleLayout
|
||||
from torch._inductor.lowering import register_lowering
|
||||
from torch._inductor.select_algorithm import autotune_select_algorithm
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
|
||||
|
||||
@ -64,20 +59,14 @@ class TestSubgraphChoice(TestCase):
|
||||
choices = [aten_mm.bind((mat1, mat2), layout)]
|
||||
|
||||
kPartitions = 256
|
||||
with enable_python_dispatcher():
|
||||
decompositions = select_decomp_table()
|
||||
|
||||
decompose_k_subgraph_template = SubgraphTemplate(
|
||||
name="decompose_k_mm",
|
||||
make_fx_graph=make_fx(
|
||||
functools.partial(decomposeK, kPartitions=kPartitions),
|
||||
decompositions,
|
||||
tracing_mode="real",
|
||||
),
|
||||
)
|
||||
decompose_k_subgraph_template = (
|
||||
torch._inductor.kernel.mm.DecomposeKSugraphTemplate()
|
||||
)
|
||||
|
||||
decompose_k_subgraph_template.maybe_append_choice(
|
||||
choices,
|
||||
k_split=kPartitions,
|
||||
input_nodes=(mat1, mat2),
|
||||
layout=layout,
|
||||
)
|
||||
@ -139,19 +128,14 @@ class TestSubgraphChoice(TestCase):
|
||||
choices = []
|
||||
|
||||
kPartitions = 2
|
||||
with enable_python_dispatcher():
|
||||
decompositions = select_decomp_table()
|
||||
|
||||
decompose_k_subgraph_template = SubgraphTemplate(
|
||||
name="decompose_k_mm",
|
||||
make_fx_graph=make_fx(
|
||||
functools.partial(decomposeK, kPartitions=kPartitions),
|
||||
decompositions,
|
||||
),
|
||||
)
|
||||
decompose_k_subgraph_template = (
|
||||
torch._inductor.kernel.mm.DecomposeKSugraphTemplate()
|
||||
)
|
||||
|
||||
decompose_k_subgraph_template.maybe_append_choice(
|
||||
choices,
|
||||
k_split=kPartitions,
|
||||
input_nodes=(mat1, mat2),
|
||||
layout=layout,
|
||||
)
|
||||
|
@ -168,7 +168,6 @@ class SubgraphTemplate(KernelTemplate):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
make_fx_graph: Callable[..., Any],
|
||||
):
|
||||
"""
|
||||
Initialize a subgraph template.
|
||||
@ -177,13 +176,15 @@ class SubgraphTemplate(KernelTemplate):
|
||||
name: The name of this template
|
||||
graph: The FX graph
|
||||
"""
|
||||
self.name = f"{name}_{next(SubgraphTemplate.index_counter)}"
|
||||
self.make_fx_graph = make_fx_graph
|
||||
super().__init__(name=name)
|
||||
|
||||
def generate( # type: ignore[override]
|
||||
self,
|
||||
name: str,
|
||||
input_nodes: list[Buffer],
|
||||
layout: Layout,
|
||||
make_fx_graph: Callable[..., Any],
|
||||
description: str = "",
|
||||
**kwargs: Any,
|
||||
) -> SubgraphChoiceCaller:
|
||||
"""
|
||||
@ -200,9 +201,9 @@ class SubgraphTemplate(KernelTemplate):
|
||||
"""
|
||||
|
||||
return SubgraphChoiceCaller(
|
||||
name=self.name,
|
||||
name=f"{name}_{next(SubgraphTemplate.index_counter)}",
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
description="",
|
||||
make_fx_graph=self.make_fx_graph,
|
||||
description=description,
|
||||
make_fx_graph=make_fx_graph,
|
||||
)
|
||||
|
@ -24,8 +24,8 @@ from .. import config as inductor_config
|
||||
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
|
||||
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
|
||||
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
|
||||
from ..codegen.subgraph import SubgraphTemplate
|
||||
from ..ir import FlexibleLayout, is_triton
|
||||
from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate
|
||||
from ..ir import Buffer, FlexibleLayout, is_triton, Layout
|
||||
from ..kernel_inputs import MMKernelInputs
|
||||
from ..lowering import (
|
||||
add_layout_constraint,
|
||||
@ -658,6 +658,44 @@ def decomposeK(a, b, k_splits):
|
||||
return reduced_buf.to(a.dtype)
|
||||
|
||||
|
||||
class DecomposeKSugraphTemplate(SubgraphTemplate):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="decompose_k",
|
||||
)
|
||||
|
||||
def generate( # type: ignore[override]
|
||||
self,
|
||||
input_nodes: list[Buffer],
|
||||
layout: Layout,
|
||||
k_split: int,
|
||||
) -> SubgraphChoiceCaller:
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
from ..decomposition import select_decomp_table
|
||||
|
||||
name = f"decompose_k_mm_{k_split}_split"
|
||||
description = f"{k_split=}"
|
||||
|
||||
with enable_python_dispatcher():
|
||||
decompositions = select_decomp_table()
|
||||
fn = make_fx(
|
||||
functools.partial(decomposeK, k_splits=k_split),
|
||||
decompositions,
|
||||
)
|
||||
|
||||
return super().generate(
|
||||
name=name,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
make_fx_graph=fn,
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
decompose_k_subgraph_template = DecomposeKSugraphTemplate()
|
||||
|
||||
|
||||
@register_lowering(aten.mm, type_promotion_kind=None)
|
||||
def tuned_mm(mat1, mat2, *, layout=None):
|
||||
"""
|
||||
@ -739,10 +777,6 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
)
|
||||
)
|
||||
if use_decompose_k_choice(m, n, k) and not unbacked_symbols:
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
from ..decomposition import select_decomp_table
|
||||
|
||||
k_splits = get_k_splits(m, n, k)
|
||||
for k_split in k_splits:
|
||||
if not V.graph.sizevars.statically_known_true(
|
||||
@ -750,21 +784,11 @@ def tuned_mm(mat1, mat2, *, layout=None):
|
||||
):
|
||||
continue
|
||||
|
||||
with enable_python_dispatcher():
|
||||
decompositions = select_decomp_table()
|
||||
|
||||
decompose_k_subgraph_template = SubgraphTemplate(
|
||||
name=f"decompose_k_mm_{k_split}_split",
|
||||
make_fx_graph=make_fx(
|
||||
functools.partial(decomposeK, k_splits=k_split),
|
||||
decompositions,
|
||||
),
|
||||
)
|
||||
|
||||
decompose_k_subgraph_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=(mat1, mat2),
|
||||
input_nodes=kernel_inputs.nodes(),
|
||||
layout=layout,
|
||||
k_split=k_split,
|
||||
)
|
||||
|
||||
if (
|
||||
|
Reference in New Issue
Block a user