# pytorch/test/inductor/test_native_matmul.py
# Owner(s): ["module: inductor"]
from typing import Any, Callable

import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_triton_code
from torch.testing import FileCheck
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

aten = torch.ops.aten
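

# With `triton.native_matmul` enabled, Inductor is expected to emit the matmul
# directly inside its own generated Triton kernels as a `tl.dot(...)` reduction,
# rather than dispatching to a separate pre-built mm kernel; the FileCheck
# assertions below verify exactly that shape of generated code.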
@inductor_config.patch({"triton.native_matmul": True})
class TestTritonDotReduction(TestCase):
    def _check_equal(
        self,
        f: Callable,
        example_inputs: tuple[Any, ...],
    ):
        """Check that f produces the same result eager and compiled."""
        compiled = torch.compile(f)
        actual = compiled(*example_inputs)
        expect = f(*example_inputs)
        self.assertTrue(same(expect, actual))

    def _check_code(
        self,
        f: Callable,
        example_inputs: tuple[Any, ...],
        kernel_count: int,
        dot_count: int,
    ):
        """Check that the generated Triton code contains the expected number
        of kernels and tl.dot calls."""
        compiled = torch.compile(f)
        code = run_and_get_triton_code(compiled, *example_inputs)
        FileCheck().check_regex(r"triton.*mm.*\.run\(").run(code)
        FileCheck().check_count(
            "@triton.jit",
            kernel_count,
        ).check_count(
            "tl.dot",
            dot_count,
        ).run(code)
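
    # Illustrative only: the generated code these checks match looks roughly
    # like
    #
    #     @triton.jit
    #     def triton_red_fused_mm_0(...):  # hypothetical kernel name
    #         ...
    #         acc = tl.dot(a, b, acc)
    #
    # followed by a `triton_red_fused_mm_0.run(...)` call at the call site.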

    def test_matmul(self):
        def f(x, y):
            z = x @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)

    def test_mm_1d_expand(self):
        # x is 1-D; it is broadcast along the second dim to (M, K) before the mm.
        def f(x, y, M, K):
            z = x[:, None].expand(M, K) @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M,), (1,), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y, M, K))
        self._check_code(f, (x, y, M, K), 1, 1)

    def test_mm_2_expand(self):
        # x is a single element (stride 0); it is broadcast along both dims
        # to (M, K) before the mm.
        def f(x, y, M, K):
            z = x[:, None].expand(M, K) @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((1,), (0,), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y, M, K))
        self._check_code(f, (x, y, M, K), 1, 1)

    def test_matmul_fp16(self):
        # Mixed input dtypes: y is cast to fp16 inside the graph before the mm.
        def f(x, y):
            z = x @ y.to(x.dtype)
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), dtype=torch.float16, device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), dtype=torch.float32, device=GPU_TYPE)
        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)

    def test_reduction_mask_zeroout(self):
        # Non-power-of-2 sizes force masked loads; out-of-bounds lanes must be
        # zeroed out so they do not pollute the tl.dot accumulation.
        def f(x, y):
            return (x + 1) @ (y - 2)

        M, K, N = 62, 62, 62
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)

    def test_3mm_add(self):
        # Three matmuls feeding one add; all should fuse into a single kernel
        # with three tl.dot calls. Shapes follow usage: z @ w needs z to be
        # (M, K) and w to be (K, N).
        def f(x, y, z, w, r, t):
            return x @ y + z @ w + r @ t

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        z = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        w = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        r = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        t = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y, z, w, r, t))
        self._check_code(f, (x, y, z, w, r, t), 1, 3)

    def test_mm_complex(self):
        # Gather (x[z]) + mm + pointwise epilogue, all in one kernel.
        def f(x, y, z, w):
            return x[z] @ y + w + 3

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        z = torch.randint(M, (M, K), dtype=torch.long, device=GPU_TYPE)
        w = rand_strided((M, N), (N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y, z, w))
        self._check_code(f, (x, y, z, w), 1, 1)

    def test_batchmatmul(self):
        def f(x, y):
            z = torch.bmm(x, y)
            return z

        B, M, K, N = 256, 128, 128, 128
        x = rand_strided((B, M, K), (M * K, K, 1), device=GPU_TYPE)
        y = rand_strided((B, K, N), (K * N, N, 1), device=GPU_TYPE)
        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)
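

# Default all tensor creation in this file to the GPU.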
if HAS_GPU:
    torch.set_default_device(GPU_TYPE)


if __name__ == "__main__":
    # TODO: support native matmul on xpu
    if HAS_GPU and GPU_TYPE != "xpu":
        run_tests()