[ROCm][TunableOp] Unit test to verify that there is only one kernel launch per PyTorch API invocation. (#155077)

This TunableOp unit test covers a breakage that was fixed in this PR: https://github.com/pytorch/pytorch/pull/153764

After tuning is complete, verify that there is only one kernel launch for each PyTorch API invocation.
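The idea behind the check can be sketched standalone, outside the test harness (a minimal sketch, assuming a ROCm build with TunableOp available; it relies on the fact that the Tensile-generated GEMM kernel names start with 'Cijk', which is what the test keys on):

    import torch
    from torch.profiler import profile, ProfilerActivity

    # Enable TunableOp and keep tuning to a single iteration so the
    # warm-up call finishes quickly.
    torch.cuda.tunable.enable(True)
    torch.cuda.tunable.set_max_tuning_iterations(1)

    M = 10
    A = torch.rand(M, M, device='cuda')
    torch.mm(A, A)  # first call triggers tuning

    # After tuning, a single torch.mm call should launch exactly one
    # GEMM kernel (Tensile kernel names start with 'Cijk').
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        torch.mm(A, A)

    launches = sum(evt.count for evt in prof.key_averages() if 'Cijk' in evt.key)
    assert launches == 1, f"expected one kernel launch, saw {launches}"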

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155077
Approved by: https://github.com/jeffdaily
Author: Nichols A. Romero
Date: 2025-06-10 16:11:43 +00:00
Committed by: PyTorch MergeBot
Parent: 08d15d3ec1
Commit: e8d29c45e0
@@ -5915,6 +5915,56 @@ class TestLinalg(TestCase):
            delta = tuned_default_scaled_mm - ref_scaled_mm
            self.assertTrue(torch.all(delta == 0))

    @onlyCUDA
    @skipCUDAIfNotRocm
    @dtypes(torch.float)
    def test_call_count_tunableop(self, device, dtype):
        # Test that after tuning a GEMM in TunableOp, we only call the GEMM
        # kernel once per PyTorch API invocation.
        # We use the torch profiler to get the call counts of the kernels.
        # Supported only for: mm, batch mm (bmm), and GEMM with bias (linear).
        from torch.profiler import profile, ProfilerActivity

        with self._tunableop_ctx():
            # Set to a single tuning iteration to keep the test short
            # while still exercising the code path.
            torch.cuda.tunable.set_max_tuning_iterations(1)

            b = 2
            M = 10

            # MM
            A = torch.rand(M, M, device=device)
            C = torch.mm(A, A)

            # Linear - GEMM with bias
            X = torch.rand(M, M, device='cuda')
            bias = torch.rand(M, device='cuda')
            Y = torch.nn.functional.linear(X, A, bias)

            # BMM
            batch_A = torch.rand((b, M, M), device='cuda')
            batch_C = torch.bmm(batch_A, batch_A)

            kernel_count = 0
            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
                C = torch.mm(A, A)
                Y = torch.nn.functional.linear(X, A, bias)
                batch_C = torch.bmm(batch_A, batch_A)

            # Check that after tuning, there was only one kernel launched per
            # PyTorch API call. The kernel names always start with `Cijk`.
            mm_key = 'Cijk'
            events = prof.key_averages()
            for evt in events:
                if mm_key in evt.key:
                    self.assertEqual(evt.count, 1)
                    kernel_count = kernel_count + 1

            # There must be exactly three kernels, one per API call.
            self.assertEqual(kernel_count, 3)

    @dtypes(torch.float, torch.complex64)
    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
        a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)