pytorch/test/inductor/test_subgraph_choice.py
Ruben Rodriguez Buchillon 688acf0b83 [inductor][mm] restructure decompose k (#161026)
# why

- make it easier to integrate into the lookup table later

# what

- the current version generates templates on the fly and uses them
  to generate a single choice
- the lookup table and performance model work best when there is a
  stable set of templates (with predictable names) that are then
  parametrized
- this change introduces a single DecomposeK template with a stable
  name, where the k split is the only parametrization we do
  (sketched below)
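
For illustration, here is a minimal sketch of the intended usage, mirroring the test file below. The candidate `k_split` values are made-up examples (not Inductor's actual heuristic), and `mat1`, `mat2`, and `layout` are assumed to come from `mm_args` as in the test:

```
# Sketch only: one stable template, parametrized solely by the k split.
template = torch._inductor.kernel.mm.DecomposeKSugraphTemplate()

choices = []
for k_split in (64, 128, 256):  # illustrative candidate splits
    template.maybe_append_choice(
        choices,
        k_split=k_split,
        input_nodes=(mat1, mat2),
        layout=layout,
    )
```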

# testing

```
python3 -bb -m pytest test/inductor/test_max_autotune.py::TestMaxAutotune::test_max_autotune_decompose_k_dynamic_False_bfloat16_sizes1 -v
```
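
The full test file below should be runnable the same way, e.g.:

```
python3 -bb -m pytest test/inductor/test_subgraph_choice.py -v
```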

Differential Revision: [D80670913](https://our.internmc.facebook.com/intern/diff/D80670913)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161026
Approved by: https://github.com/PaulZhang12, https://github.com/jansel
2025-08-28 20:14:41 +00:00


# Owner(s): ["module: inductor"]
import unittest
from unittest import mock
from unittest.mock import MagicMock

import torch
from torch._inductor.ir import Buffer, FixedLayout, FlexibleLayout
from torch._inductor.lowering import register_lowering
from torch._inductor.select_algorithm import autotune_select_algorithm
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU

def decomposeK(a, b, kPartitions):
    # Eager reference for decompose-k: split the K dimension into
    # B = k // kPartitions batches, batch-matmul them, and reduce in fp32.
    m = a.shape[0]
    n = b.shape[1]
    k = a.shape[1]

    B = k // kPartitions
    a_reshaped = torch.permute(a.reshape(m, B, kPartitions), (1, 0, 2))
    b_reshaped = b.reshape(B, kPartitions, n)
    result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32)
    result_fp32 = result.to(torch.float32)
    reduced_buf = torch.sum(result_fp32, 0)
    return reduced_buf.to(a.dtype)
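
# Illustration of the shapes involved: with the (32, 4096) @ (4096, 32) matmul
# used in the test below and kPartitions=256, B = 4096 // 256 = 16, so the mm
# becomes a bmm over 16 pairs of (32, 256) and (256, 32) slices, accumulated
# in fp32 and then summed over the batch dimension.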

class TestSubgraphChoice(TestCase):
    def setUp(self):
        super().setUp()

    def _create_buffer(self, name, shape, dtype):
        return Buffer(
            name=name,
            layout=FixedLayout(torch.device(f"{GPU_TYPE}:0"), dtype=dtype, size=shape),
        )

    @skipIfXpu
    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
    def test_subgraph_decompose_k(self):
        from torch._inductor.kernel.mm import aten_mm
        from torch._inductor.kernel.mm_common import mm_args

        mat1_shape, mat2_shape = (32, 4096), (4096, 32)

        @torch.library.custom_op("mylib::matmul_decompose", mutates_args={})
        def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a @ b

        @matmul_decompose.register_fake
        def _(a, b):
            return a @ b

        @register_lowering(torch.ops.mylib.matmul_decompose)
        def _(a, b):
            _, _, _, layout, mat1, mat2 = mm_args(a, b)

            choices = [aten_mm.bind((mat1, mat2), layout)]

            kPartitions = 256
            decompose_k_subgraph_template = (
                torch._inductor.kernel.mm.DecomposeKSugraphTemplate()
            )

            decompose_k_subgraph_template.maybe_append_choice(
                choices,
                k_split=kPartitions,
                input_nodes=(mat1, mat2),
                layout=layout,
            )

            # Test benchmarking against aten
            autotune_select_algorithm("test_subgraph_choice", choices, [a, b], layout)

            # Only return the decomposeK choice for codegen
            choices = [choices[1]]
            return autotune_select_algorithm(
                "test_subgraph_choice", choices, [a, b], layout
            )

        a_in = torch.randn(
            mat1_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
        )
        b_in = torch.randn(
            mat2_shape, dtype=torch.float16, device=torch.device(f"{GPU_TYPE}:0")
        )

        def func(mat1, mat2):
            return torch.ops.mylib.matmul_decompose(mat1, mat2)

        compiled_func = torch.compile(func, mode="max-autotune", dynamic=False)
        res = compiled_func(a_in, b_in)

        # Check that the compiled result matches eager torch.mm
        torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1)
    @skipIfXpu
    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
    def test_subgraph_freeze_layout(self):
        from torch._inductor.kernel.mm_common import mm_args

        M, N, K = (4, 128, 14240)
        a_in = torch.randn(
            (M, K), dtype=torch.bfloat16, device=torch.device(f"{GPU_TYPE}:0")
        )
        b_in = torch.randn(
            (K, N), dtype=torch.bfloat16, device=torch.device(f"{GPU_TYPE}:0")
        )

        @torch.library.custom_op("mylib::matmul_decompose_padding", mutates_args={})
        def matmul_decompose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a @ b

        @matmul_decompose.register_fake
        def _(a, b):
            return a @ b

        @register_lowering(torch.ops.mylib.matmul_decompose_padding)
        def _(a, b):
            _, _, _, layout, mat1, mat2 = mm_args(a, b)
            mat1_layout = mat1.layout
            assert isinstance(mat1_layout, FlexibleLayout)
            mat1_stride = mat1_layout.stride

            choices = []

            kPartitions = 2
            decompose_k_subgraph_template = (
                torch._inductor.kernel.mm.DecomposeKSugraphTemplate()
            )

            decompose_k_subgraph_template.maybe_append_choice(
                choices,
                k_split=kPartitions,
                input_nodes=(mat1, mat2),
                layout=layout,
            )
            choice = choices[0]

            assert isinstance(mat1.layout, FixedLayout)
            # Creating the subgraph choice should have frozen the layout.
            # We ensure padding, so the stride should differ.
            assert mat1.layout.stride != mat1_stride
            for example_stride, layout_stride in zip(
                choice.example_inputs[0].stride(), mat1.layout.stride
            ):
                # Example inputs should have the same stride as the frozen layout
                assert example_stride == layout_stride

            return autotune_select_algorithm(
                "test_subgraph_choice", choices, [a, b], layout
            )

        def func(mat1, mat2):
            return torch.ops.mylib.matmul_decompose_padding((mat1 + 1.0), mat2)

        # Mock the current node so padding is not disliked; freezing the
        # FlexibleLayout then pads mat1's strides, which the lowering asserts.
        with mock.patch("torch._inductor.ir.V.get_current_node") as get_node_mock:
            node_mock = MagicMock()
            node_mock.meta = {"dislike_padding": False}
            get_node_mock.return_value = node_mock

            compiled_func = torch.compile(func, mode="max-autotune", dynamic=False)
            compiled_func(a_in, b_in)

if __name__ == "__main__":
    # Set env to make it work in CI.
    if HAS_GPU and HAS_CPU:
        run_tests()