[Inductor] External callable registration API for Matmul tuning candidates (#130774)
Fixes [#130769](https://github.com/pytorch/pytorch/issues/130769)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130774
Approved by: https://github.com/jansel
Co-authored-by: Jason Ansel <jansel@meta.com>
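In short, this adds an `external_matmul` Inductor config entry through which callers can append their own matmul implementations to the max-autotune candidate pool. A minimal usage sketch, assuming only what this diff shows (the `external_matmul` key and the out-parameter signature); `my_matmul` and `f` are illustrative names, not part of the PR:

import torch

# A candidate must write a @ b into the preallocated `out` tensor and return None.
def my_matmul(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    torch.mm(a, b, out=out)

def f(x, w):
    return torch.matmul(x, w)

# With max_autotune enabled, Inductor benchmarks the registered callable
# against its built-in choices and keeps whichever wins the tuning run.
compiled = torch.compile(
    f, options={"max_autotune": True, "external_matmul": [my_matmul]}
)
result = compiled(torch.randn(128, 128), torch.randn(128, 128))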
test/inductor/test_external_callables.py (new file, 94 lines)
@@ -0,0 +1,94 @@
# Owner(s): ["module: inductor"]
import unittest

import torch
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import TEST_CUDA


class MatMulModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.matrix = torch.nn.Parameter(torch.eye(128, 128) * 2, requires_grad=True)

    def forward(self, x):
        return torch.matmul(x, self.matrix)


# torch.add performs better than torch.mm here, so it gets chosen during tuning
def matmul_cpu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    torch.add(a, b, out=out)


def matmul_dup(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    torch.add(a, b, out=out)


def matmul_cuda(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
    torch.add(a, b, out=out)


class TestInductorExternalCallable(TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._saved_config = config.save_config()

    def tearDown(self):
        super().tearDown()
        config.load_config(self._saved_config)

    def test_matmul_cpu(self):
        # 2I + 2I == (2I)(2I)
        x = torch.eye(128, 128) * 2
        opt_fn = torch.compile(
            MatMulModule(),
            options={"max_autotune": True, "external_matmul": [matmul_cpu]},
        )
        opt_fn_golden = torch.compile(MatMulModule(), options={"max_autotune": True})
        torch.testing.assert_close(
            opt_fn(x),
            opt_fn_golden(x),
            msg=f"torch.compile(..., external_matmul = {matmul_cpu}) failed",
        )

    def test_matmul_dup(self):
        # 2I + 2I == (2I)(2I)
        x = torch.eye(128, 128) * 2
        # This should only register the first external call
        opt_fn = torch.compile(
            MatMulModule(),
            options={"max_autotune": True, "external_matmul": [matmul_dup, matmul_dup]},
        )
        opt_fn_golden = torch.compile(MatMulModule(), options={"max_autotune": True})
        torch.testing.assert_close(
            opt_fn(x),
            opt_fn_golden(x),
            msg=f"torch.compile(..., external_matmul = {matmul_dup}) failed",
        )

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(
        torch.cuda.is_available() and torch.cuda.get_device_capability() < (7, 0),
        "Triton does not support device capability < 7.0",
    )
    def test_matmul_cuda(self):
        device = torch.device("cuda")
        x = (torch.eye(128, 128) * 2).to(device=device)
        opt_fn = torch.compile(
            MatMulModule().to(device),
            options={"max_autotune": True, "external_matmul": [matmul_cuda]},
        )
        opt_fn_golden = torch.compile(
            MatMulModule().to(device), options={"max_autotune": True}
        )
        torch.testing.assert_close(
            opt_fn(x),
            opt_fn_golden(x),
            msg=f"torch.compile(..., external_matmul = {matmul_cuda}) failed",
        )


if __name__ == "__main__":
    run_tests()
@@ -1267,6 +1267,9 @@ _cache_config_ignore_prefix = [
    "compile_threads",
]

# External callables for matmul tuning candidates
external_matmul: List[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
@@ -122,6 +122,13 @@ mm_template = TritonTemplate(
""",
)


# Prevent duplicate registration of extern functions
@functools.lru_cache(None)
def lazy_register_extern_choice(fn):
    return ExternKernelChoice(fn)


aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")

aten_addmm = ExternKernelChoice(
@@ -245,6 +252,9 @@ def tuned_mm(mat1, mat2, *, layout=None):
        log.warning("No choices for GEMM, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()

    for k in inductor_config.external_matmul:
        choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout))

    try:
        return autotune_select_algorithm(name, choices, [mat1, mat2], layout)
    except NoValidChoicesError:
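A note on the dedup guard above: `functools.lru_cache(None)` memoizes `lazy_register_extern_choice` on the callable's identity, which is why passing the same function twice (as `test_matmul_dup` does) yields a single `ExternKernelChoice`. A standalone sketch of that mechanism; `register` and `f` are illustrative stand-ins:

import functools

@functools.lru_cache(None)
def register(fn):
    # Stand-in for ExternKernelChoice(fn); built once per distinct callable.
    return ("choice", fn)

def f(a, b, out):
    pass

assert register(f) is register(f)  # second call returns the cached object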