# Owner(s): ["module: inductor"]

from typing import Callable

import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_triton_code
from torch.testing import FileCheck
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

aten = torch.ops.aten
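

# `triton.native_matmul` asks Inductor to generate the matmul itself via
# tl.dot inside its Triton kernels instead of dispatching to a separate mm
# kernel; the FileCheck assertions below verify exactly that.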
@inductor_config.patch({"triton.native_matmul": True})
class TestTritonDotReduction(TestCase):
    def _check_equal(
        self,
        f: Callable,
        example_inputs: tuple[torch.Tensor, ...],
    ):
        compiled = torch.compile(f)
        actual = compiled(*example_inputs)
        expect = f(*example_inputs)
        self.assertTrue(same(expect, actual))

    def _check_code(
        self,
        f: Callable,
        example_inputs: tuple[torch.Tensor, ...],
        kernel_count: int,
        dot_count: int,
    ):
        f = torch.compile(f)
        code = run_and_get_triton_code(f, *example_inputs)
        # A matmul kernel should be launched directly from the generated
        # wrapper code.
        FileCheck().check_regex(r"triton.*mm.*\.run\(").run(code)

        # Exactly `kernel_count` generated Triton kernels, containing
        # `dot_count` tl.dot calls in total.
        FileCheck().check_count(
            "@triton.jit",
            kernel_count,
        ).check_count(
            "tl.dot",
            dot_count,
        ).run(code)
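
    # Each test below verifies numerics first (_check_equal), then the
    # structure of the generated Triton source (_check_code).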
    def test_matmul(self):
        def f(x, y):
            z = x @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)

        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)
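
    # The next two tests feed expanded (stride-0) views into the matmul, so
    # the generated kernel must handle broadcasted loads.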
    def test_mm_1d_expand(self):
        def f(x, y, M, K):
            z = x[:, None].expand(M, K) @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M,), (1,), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)

        self._check_equal(f, (x, y, M, K))
        self._check_code(f, (x, y, M, K), 1, 1)

    def test_mm_2_expand(self):
        def f(x, y, M, K):
            z = x[:, None].expand(M, K) @ y
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((1,), (0,), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)

        self._check_equal(f, (x, y, M, K))
        self._check_code(f, (x, y, M, K), 1, 1)
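
    # Mixed precision: y is fp32 and is cast to x's dtype (fp16) inside f.
    # Expecting a single kernel asserts that the cast fuses into the matmul
    # rather than being split out.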
    def test_matmul_fp16(self):
        def f(x, y):
            z = x @ y.to(x.dtype)
            return z

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), dtype=torch.float16, device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), dtype=torch.float32, device=GPU_TYPE)

        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)
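
    # 62 is not a multiple of any power-of-two block size, so the reduction
    # loop has a masked tail; per the test name, masked-out elements must be
    # zeroed before tl.dot accumulates them.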
    def test_reduction_mask_zeroout(self):
        def f(x, y):
            return (x + 1) @ (y - 2)

        M, K, N = 62, 62, 62
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)

        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)
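
    # Three independent matmuls summed together: expecting one kernel with
    # three tl.dot calls asserts that all three reductions and the adds fuse
    # into a single Triton kernel.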
    def test_3mm_add(self):
        def f(x, y, z, w, r, t):
            return x @ y + z @ w + r @ t

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        w = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        z = rand_strided((K, N), (N, 1), device=GPU_TYPE)
        r = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        t = rand_strided((K, N), (N, 1), device=GPU_TYPE)

        self._check_equal(f, (x, y, z, w, r, t))
        self._check_code(f, (x, y, z, w, r, t), 1, 3)
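
    # A gather (x[z]) feeding a matmul with a pointwise epilogue; expecting
    # one kernel with one tl.dot asserts the whole pattern still fuses.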
    def test_mm_complex(self):
        def f(x, y, z, w):
            return x[z] @ y + w + 3

        M, K, N = 128, 128, 128
        x = rand_strided((M, K), (K, 1), device=GPU_TYPE)
        y = rand_strided((K, N), (N, 1), device=GPU_TYPE)

        z = torch.randint(M, (M, K), dtype=torch.long, device=GPU_TYPE)
        w = rand_strided((M, N), (N, 1), device=GPU_TYPE)

        self._check_equal(f, (x, y, z, w))
        self._check_code(f, (x, y, z, w), 1, 1)
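
    # Batched case: bmm over contiguous 3-D inputs should likewise lower to
    # a single kernel with one tl.dot.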
    def test_batchmatmul(self):
        def f(x, y):
            z = torch.bmm(x, y)
            return z

        B, M, K, N = 256, 128, 128, 128
        x = rand_strided((B, M, K), (M * K, K, 1), device=GPU_TYPE)
        y = rand_strided((B, K, N), (K * N, N, 1), device=GPU_TYPE)

        self._check_equal(f, (x, y))
        self._check_code(f, (x, y), 1, 1)


if HAS_GPU:
    torch.set_default_device(GPU_TYPE)


if __name__ == "__main__":
    # TODO: support native matmul on xpu
    if HAS_GPU and GPU_TYPE != "xpu":
        run_tests()