Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add tests to bsr_dense_addmm_meta. Tune bsr_dense_addmm kernel for ViT shapes. (#132646)
As in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132646
Approved by: https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: b7bcfdaff2
Commit: 1471473b84
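For context, the new test exercises the meta-parameter lookup behind the bsr_dense_addmm Triton kernel. Below is a minimal sketch of that lookup, assuming a CUDA build of PyTorch with Triton installed; the shapes are illustrative and not taken from the commit:

    import torch
    from torch.sparse._triton_ops import bsr_dense_addmm_meta

    # Query kernel meta-parameters for C = beta * C + alpha * (A @ B), where A
    # is an M x K BSR tensor with 16 x 16 blocks and B is a dense K x N tensor.
    # If no tuned entry exists for this key, the lookup warns (once per key)
    # that non-optimal parameters are in use and falls back to heuristic
    # defaults.
    meta = bsr_dense_addmm_meta(256, 256, 768, 16, 16, 0.0, 1.0,
                                dtype=torch.float32, sparsity=0.5)
    print(meta)  # dict with GROUP_SIZE_ROW, SPLIT_N, num_stages, num_warps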
@@ -2,10 +2,12 @@
 
 import torch
 import random
+import io
 import itertools
 import unittest
 import functools
-from torch.testing import make_tensor
+from contextlib import redirect_stderr
+from torch.testing import make_tensor, FileCheck
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
 from torch.testing._internal.common_utils import \
     (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests,
@@ -4134,6 +4136,76 @@ class TestSparseCompressedTritonKernels(TestCase):
             result = operation(*args, **dict(meta=meta))
             self.assertEqual(result, expected)
 
+    @onlyCUDA
+    @skipIfRocm
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
+    def test_triton_bsr_dense_addmm_meta(self, device):
+        from torch.sparse._triton_ops import bsr_dense_addmm_meta
+        from torch.sparse._triton_ops_meta import update as update_bsr_dense_addmm_meta
+
+        dtype = torch.float32
+        Ms = Ks = 16
+        beta = 0.0
+        alpha = 1.0
+
+        def get_meta(M, K, N, sparsity=None):
+            return bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, dtype=dtype, sparsity=sparsity,
+                                        _version="test_triton_bsr_dense_addmm_meta")
+
+        def update_meta(M, K, N, value, sparsity=0.5):
+            key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
+            update_bsr_dense_addmm_meta("bsr_dense_addmm", torch.cuda.get_device_name(),
+                                        ("test_triton_bsr_dense_addmm_meta", dtype, sparsity),
+                                        key, value)
+
+        def get_meta_with_checks(M, K, N, warn_count=0, sparsity=None):
+            f = io.StringIO()
+            with redirect_stderr(f):
+                result = get_meta(M, K, N, sparsity=sparsity)
+            msg = f.getvalue()
+            FileCheck().check_count(
+                str=f"UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M={M} K={K} N={N}",
+                count=warn_count, exactly=True
+            ).run(msg)
+            return result
+
+        # Test warn_once when requesting non-existing tuned parameters multiple times
+        f = io.StringIO()
+        with redirect_stderr(f):
+            for i in range(5):
+                get_meta(16, 16, 16)
+            for i in range(5):
+                get_meta(16, 16, 32)
+
+        msg = f.getvalue()
+        FileCheck().check_count(
+            str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=16", count=1, exactly=True
+        ).run(msg)
+        FileCheck().check_count(
+            str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=32", count=1, exactly=True
+        ).run(msg)
+
+        # Test warn_once when tuned parameters are missing
+        default_meta = dict(GROUP_SIZE_ROW=4, SPLIT_N=2, num_stages=1, num_warps=4)
+        self.assertEqual(get_meta_with_checks(32, 32, 32, warn_count=1), default_meta)
+
+        # Test (no)warn_once when tuned parameters are available
+        update_meta(32, 32, 48, (2, 8, 5, 6))
+        expected_meta = dict(GROUP_SIZE_ROW=2, SPLIT_N=8, num_stages=5, num_warps=6)
+        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0), expected_meta)
+
+        # Test non-existing tuned parameters with non-default sparsity
+        # while for default sparsity 0.5 the parameters are available
+        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0, sparsity=0.6), expected_meta)
+
+        # Test non-existing tuned parameters while there exist
+        # parameters with a consistent N // SPLIT_N ratio:
+        self.assertEqual(get_meta_with_checks(32, 32, 72, warn_count=0),
+                         dict(GROUP_SIZE_ROW=2, SPLIT_N=12, num_stages=5, num_warps=6))
+        # ... or not:
+        self.assertEqual(get_meta_with_checks(32, 32, 64, warn_count=1),
+                         dict(GROUP_SIZE_ROW=4, SPLIT_N=4, num_stages=1, num_warps=4))
+
 
 # e.g., TestSparseCSRCPU and TestSparseCSRCUDA
 instantiate_device_type_tests(TestSparseCSR, globals())
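The new test lives alongside TestSparseCompressedTritonKernels (presumably test/test_sparse_csr.py, where this class is defined), so on a CUDA machine with Triton it should be runnable in the usual way, e.g. python test/test_sparse_csr.py -k test_triton_bsr_dense_addmm_meta; the device-type instantiation generates the per-device variants of the test classes.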