Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[ROCm][TunableOp] Unit test to verify that there is only one kernel launch per PyTorch API invocation. (#155077)
This TunableOp unit test covers the breakage that was fixed in this PR: https://github.com/pytorch/pytorch/pull/153764. After tuning is complete, it verifies that there is only one kernel launch for each PyTorch API invocation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155077
Approved by: https://github.com/jeffdaily
Committed by: PyTorch MergeBot
Parent: 08d15d3ec1
Commit: e8d29c45e0
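For context, here is a minimal sketch of the pattern the new test exercises, outside the test harness. It assumes a ROCm build of PyTorch with TunableOp available; the tensor shape and variable names are illustrative, not taken from the commit.

    import torch
    from torch.profiler import profile, ProfilerActivity

    # Enable TunableOp and keep tuning to a single iteration so the
    # example runs quickly (the same knob the test uses).
    torch.cuda.tunable.enable(True)
    torch.cuda.tunable.set_max_tuning_iterations(1)

    A = torch.rand(10, 10, device='cuda')
    C = torch.mm(A, A)  # first call triggers tuning for this GEMM shape

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        C = torch.mm(A, A)  # post-tuning call

    # After tuning, each torch.mm call should map to exactly one GEMM
    # kernel launch; the ROCm GEMM kernel names start with 'Cijk'.
    for evt in prof.key_averages():
        if 'Cijk' in evt.key:
            assert evt.count == 1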
@@ -5915,6 +5915,56 @@ class TestLinalg(TestCase):
             delta = tuned_default_scaled_mm - ref_scaled_mm
             self.assertTrue(torch.all(delta == 0))
 
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float)
+    def test_call_count_tunableop(self, device, dtype):
+        # Test that after tuning a GEMM in TunableOp, we call the GEMM kernel
+        # only once per PyTorch API invocation.
+        # We use the torch profiler to get the call counts for the kernels.
+
+        # Supported only for: MM, batched MM, and GEMM with bias (linear)
+        from torch.profiler import profile, ProfilerActivity
+
+        with self._tunableop_ctx():
+            # Set to a single iteration to keep the test short but still exercise the code
+            torch.cuda.tunable.set_max_tuning_iterations(1)
+
+            b = 2
+            M = 10
+
+            # MM
+            A = torch.rand(M, M, device=device)
+            C = torch.mm(A, A)
+
+            # Linear - GEMM with bias
+            X = torch.rand(M, M, device='cuda')
+            bias = torch.rand(M, device='cuda')
+            Y = torch.nn.functional.linear(X, A, bias)
+
+            # BMM
+            batch_A = torch.rand((b, M, M), device='cuda')
+            batch_C = torch.bmm(batch_A, batch_A)
+
+            kernel_count = 0
+            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+                C = torch.mm(A, A)
+                Y = torch.nn.functional.linear(X, A, bias)
+                batch_C = torch.bmm(batch_A, batch_A)
+
+            # Check that after tuning, there was only one kernel
+            # launched per PyTorch API call. The kernel names
+            # always start with `Cijk`.
+            mm_key = 'Cijk'
+            events = prof.key_averages()
+            for evt in events:
+                if mm_key in evt.key:
+                    self.assertEqual(evt.count, 1)
+                    kernel_count += 1
+
+            # There must be exactly three kernels
+            self.assertEqual(kernel_count, 3)
+
     @dtypes(torch.float, torch.complex64)
     def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
         a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
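To run just this test locally, assuming the usual PyTorch test layout where TestLinalg lives in test/test_linalg.py (a ROCm machine is required, since the test is gated by @skipCUDAIfNotRocm):

    python test/test_linalg.py -k test_call_count_tunableop

The three expected kernels correspond to the three profiled API calls: torch.mm, torch.nn.functional.linear, and torch.bmm.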