Add tests to bsr_dense_addmm_meta. Tune bsr_dense_addmm kernel for ViT shapes. (#132646)

As in the title.
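
For context, a minimal sketch (not part of this commit) of the meta lookup that the new test exercises. The positional layout M, K, N, Ms, Ks, beta, alpha mirrors the call in the test below; the concrete shapes here are illustrative only:

    import torch
    from torch.sparse._triton_ops import bsr_dense_addmm_meta

    # Query (possibly tuned) triton kernel parameters for computing
    # beta * input + alpha * (bsr @ dense) with an M x K BSR operand made of
    # Ms x Ks blocks and a K x N dense operand.
    meta = bsr_dense_addmm_meta(256, 256, 768, 16, 16, 0.0, 1.0,
                                dtype=torch.float32, sparsity=0.5)
    # The result is a dict of kernel parameters such as
    #   dict(GROUP_SIZE_ROW=..., SPLIT_N=..., num_stages=..., num_warps=...).
    # For shapes without a tuned entry, default parameters are returned and a
    # "bsr_dense_addmm uses non-optimal triton kernel parameters" warning is
    # emitted once per shape.
    print(meta)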

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132646
Approved by: https://github.com/cpuhrsch
Pearu Peterson
2024-08-05 10:57:50 +00:00
committed by PyTorch MergeBot
parent b7bcfdaff2
commit 1471473b84
3 changed files with 691 additions and 17 deletions


@@ -2,10 +2,12 @@
 import torch
 import random
+import io
 import itertools
 import unittest
 import functools
-from torch.testing import make_tensor
+from contextlib import redirect_stderr
+from torch.testing import make_tensor, FileCheck
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
 from torch.testing._internal.common_utils import \
     (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests,
@@ -4134,6 +4136,76 @@ class TestSparseCompressedTritonKernels(TestCase):
             result = operation(*args, **dict(meta=meta))
             self.assertEqual(result, expected)
 
+    @onlyCUDA
+    @skipIfRocm
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
+    def test_triton_bsr_dense_addmm_meta(self, device):
+        from torch.sparse._triton_ops import bsr_dense_addmm_meta
+        from torch.sparse._triton_ops_meta import update as update_bsr_dense_addmm_meta
+
+        dtype = torch.float32
+        Ms = Ks = 16
+        beta = 0.0
+        alpha = 1.0
+
+        def get_meta(M, K, N, sparsity=None):
+            return bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, dtype=dtype, sparsity=sparsity,
+                                        _version="test_triton_bsr_dense_addmm_meta")
+
+        def update_meta(M, K, N, value, sparsity=0.5):
+            key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
+            update_bsr_dense_addmm_meta("bsr_dense_addmm", torch.cuda.get_device_name(),
+                                        ("test_triton_bsr_dense_addmm_meta", dtype, sparsity),
+                                        key, value)
+
+        def get_meta_with_checks(M, K, N, warn_count=0, sparsity=None):
+            f = io.StringIO()
+            with redirect_stderr(f):
+                result = get_meta(M, K, N, sparsity=sparsity)
+            msg = f.getvalue()
+            FileCheck().check_count(
+                str=f"UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M={M} K={K} N={N}",
+                count=warn_count, exactly=True
+            ).run(msg)
+            return result
+
+        # Test warn_once when requesting non-existing tuned parameters multiple times
+        f = io.StringIO()
+        with redirect_stderr(f):
+            for i in range(5):
+                get_meta(16, 16, 16)
+            for i in range(5):
+                get_meta(16, 16, 32)
+        msg = f.getvalue()
+        FileCheck().check_count(
+            str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=16", count=1, exactly=True
+        ).run(msg)
+        FileCheck().check_count(
+            str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=32", count=1, exactly=True
+        ).run(msg)
+
+        # Test warn_once when tuned parameters are missing
+        default_meta = dict(GROUP_SIZE_ROW=4, SPLIT_N=2, num_stages=1, num_warps=4)
+        self.assertEqual(get_meta_with_checks(32, 32, 32, warn_count=1), default_meta)
+
+        # Test (no)warn_once when tuned parameters are available
+        update_meta(32, 32, 48, (2, 8, 5, 6))
+        expected_meta = dict(GROUP_SIZE_ROW=2, SPLIT_N=8, num_stages=5, num_warps=6)
+        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0), expected_meta)
+
+        # Test non-existing tuned parameters with non-default sparsity
+        # while for default sparsity 0.5 the parameters are available
+        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0, sparsity=0.6), expected_meta)
+
+        # Test non-existing tuned parameters while there exist
+        # parameters with a consistent N // SPLIT_N ratio:
+        self.assertEqual(get_meta_with_checks(32, 32, 72, warn_count=0),
+                         dict(GROUP_SIZE_ROW=2, SPLIT_N=12, num_stages=5, num_warps=6))
+
+        # ... or not:
+        self.assertEqual(get_meta_with_checks(32, 32, 64, warn_count=1),
+                         dict(GROUP_SIZE_ROW=4, SPLIT_N=4, num_stages=1, num_warps=4))
+
 
 # e.g., TestSparseCSRCPU and TestSparseCSRCUDA
 instantiate_device_type_tests(TestSparseCSR, globals())
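
Assuming the hunk above applies to test/test_sparse_csr.py (where TestSparseCompressedTritonKernels is defined), the new test can be run on a CUDA machine with:

    python test/test_sparse_csr.py -k test_triton_bsr_dense_addmm_meta

instantiate_device_type_tests generates per-device variants (e.g. test_triton_bsr_dense_addmm_meta_cuda), while the @onlyCUDA and @skipIfRocm decorators skip the test on all other backends.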