[Inductor] External callable registration API for Matmul tuning candidates (#130774)

Fixes [#130769](https://github.com/pytorch/pytorch/issues/130769)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130774
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@meta.com>
Author: Max Hu
Date: 2024-10-02 15:38:10 +00:00
Committed by: PyTorch MergeBot
Parent: af86a6fdba
Commit: a954a9ea75

3 changed files with 107 additions and 0 deletions
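Usage, as a minimal sketch distilled from the test file below (`my_matmul` and `MyModel` are illustrative names, not part of this PR): an external candidate is a callable `(a, b, out) -> None` that writes its result into the preallocated `out` tensor, and it is handed to `torch.compile` through the new `external_matmul` option, where it competes with the built-in choices under `max_autotune`.

```python
import torch

def my_matmul(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    # External tuning candidate: compute into the preallocated `out`
    # buffer and return None, matching the signature declared in the
    # config diff below.
    torch.mm(a, b, out=out)

class MyModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.mm(x, y)

opt = torch.compile(
    MyModel(),
    options={"max_autotune": True, "external_matmul": [my_matmul]},
)
opt(torch.randn(128, 128), torch.randn(128, 128))
```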

test/inductor/test_external_callable.py
@@ -0,0 +1,94 @@
# Owner(s): ["module: inductor"]
import unittest

import torch
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import TEST_CUDA


class MatMulModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.matrix = torch.nn.Parameter(torch.eye(128, 128) * 2, requires_grad=True)

    def forward(self, x):
        return torch.matmul(x, self.matrix)


# torch.add performs better than torch.mm and gets chosen during tuning
def matmul_cpu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    torch.add(a, b, out=out)


def matmul_dup(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    torch.add(a, b, out=out)


def matmul_cuda(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    torch.add(a, b, out=out)
class TestInductorExternalCallable(TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._saved_config = config.save_config()

    def tearDown(self):
        super().tearDown()
        config.load_config(self._saved_config)

    def test_matmul_cpu(self):
        # 2I + 2I == (2I)(2I)
        x = torch.eye(128, 128) * 2
        opt_fn = torch.compile(
            MatMulModule(),
            options={"max_autotune": True, "external_matmul": [matmul_cpu]},
        )
        opt_fn_golden = torch.compile(MatMulModule(), options={"max_autotune": True})
        torch.testing.assert_close(
            opt_fn(x),
            opt_fn_golden(x),
            msg=f"torch.compile(..., external_matmul = {matmul_cpu}) failed",
        )

    def test_matmul_dup(self):
        # 2I + 2I == (2I)(2I)
        x = torch.eye(128, 128) * 2
        # This should only register the first external call
        opt_fn = torch.compile(
            MatMulModule(),
            options={"max_autotune": True, "external_matmul": [matmul_dup, matmul_dup]},
        )
        opt_fn_golden = torch.compile(MatMulModule(), options={"max_autotune": True})
        torch.testing.assert_close(
            opt_fn(x),
            opt_fn_golden(x),
            msg=f"torch.compile(..., external_matmul = {matmul_dup}) failed",
        )

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(
        torch.cuda.is_available() and torch.cuda.get_device_capability() < (7, 0),
        "Triton does not support device capability < 7.0",
    )
    def test_matmul_cuda(self):
        device = torch.device("cuda")
        x = (torch.eye(128, 128) * 2).to(device=device)
        opt_fn = torch.compile(
            MatMulModule().to(device),
            options={"max_autotune": True, "external_matmul": [matmul_cuda]},
        )
        opt_fn_golden = torch.compile(
            MatMulModule().to(device), options={"max_autotune": True}
        )
        torch.testing.assert_close(
            opt_fn(x),
            opt_fn_golden(x),
            msg=f"torch.compile(..., external_matmul = {matmul_cuda}) failed",
        )


if __name__ == "__main__":
    run_tests()

torch/_inductor/config.py
@@ -1267,6 +1267,9 @@ _cache_config_ignore_prefix = [
"compile_threads",
]
# External callable for matmul tuning candidates
external_matmul: List[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
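Since `external_matmul` is an ordinary inductor config value, it should also be settable through `torch._inductor.config.patch` rather than via `torch.compile(options=...)`. A sketch of that route (an assumption; `my_mm` is a hypothetical name, and the tests above only exercise the `options=` path):

```python
import torch
from torch._inductor import config

def my_mm(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    # Conforms to the declared element type:
    # Callable[[Tensor, Tensor, Tensor], None]
    torch.mm(a, b, out=out)

# config.patch temporarily overrides inductor config values; the tests
# above achieve similar isolation with save_config/load_config.
with config.patch(max_autotune=True, external_matmul=[my_mm]):
    fn = torch.compile(lambda x, y: torch.mm(x, y))
    fn(torch.randn(64, 64), torch.randn(64, 64))
```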

torch/_inductor/kernel/mm.py
@@ -122,6 +122,13 @@ mm_template = TritonTemplate(
""",
)
# prevent duplication registration of extern functions
@functools.lru_cache(None)
def lazy_register_extern_choice(fn):
return ExternKernelChoice(fn)
aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
aten_addmm = ExternKernelChoice(
@@ -245,6 +252,9 @@ def tuned_mm(mat1, mat2, *, layout=None):
        log.warning("No choices for GEMM, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()

    for k in inductor_config.external_matmul:
        choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout))

    try:
        return autotune_select_algorithm(name, choices, [mat1, mat2], layout)
    except NoValidChoicesError:
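The `functools.lru_cache(None)` wrapper is what makes registration idempotent: calling `lazy_register_extern_choice` twice with the same function returns the one cached `ExternKernelChoice` instead of constructing (and registering) a second one, which is why `test_matmul_dup` can pass `[matmul_dup, matmul_dup]` safely. A standalone sketch of that caching behaviour, with a stand-in class for `ExternKernelChoice`:

```python
import functools

class FakeChoice:
    # Stand-in for ExternKernelChoice; in the real code, construction
    # registers the function, so it must happen once per callable.
    def __init__(self, fn):
        self.fn = fn

@functools.lru_cache(None)
def lazy_register(fn):
    return FakeChoice(fn)

def candidate(a, b, out):
    pass

# Both calls hit the cache entry keyed by `candidate`, returning the
# identical object rather than registering twice.
assert lazy_register(candidate) is lazy_register(candidate)
```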