[ROCm][TunableOp] Fix offline tuning for ScaledGEMM. (#149677)

The main purpose of this PR is to fix offline tuning for ScaledGEMM. The previous unit test passed only because it was not strict enough. Additionally:
- All offline tuning tests now compare against the online tuning results to verify that the ParamSignatures match (see the workflow sketch below).
- We now raise an error when submatrices are encountered, since submatrices are only supported in online tuning mode.
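For context, here is a minimal sketch of the record-then-tune workflow these offline tests exercise. It assumes the `torch.cuda.tunable` offline-tuning API (`record_untuned_enable`, `tune_gemm_in_csv`) and its default file names; treat it as illustrative rather than as the exact test code.

```python
import torch

# Step 1: run the workload once with online tuning off, recording each
# untuned GEMM's ParamSignature to a CSV instead of tuning it inline.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(False)          # no online tuning
torch.cuda.tunable.record_untuned_enable(True)   # log GEMM shapes to CSV

a = torch.randn(32, 64, dtype=torch.half, device="cuda")
b = torch.randn(64, 16, dtype=torch.half, device="cuda")
torch.matmul(a, b)  # appends an entry to tunableop_untuned0.csv

# Step 2 (possibly in a separate process): replay the recorded GEMMs,
# tune each one, and write the results file.
torch.cuda.tunable.tuning_enable(True)
torch.cuda.tunable.tune_gemm_in_csv("tunableop_untuned0.csv")
torch.cuda.tunable.write_file()  # tunableop_results0.csv
```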

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149677
Approved by: https://github.com/jeffdaily
Author: Nichols A. Romero
Date: 2025-03-22 02:22:10 +00:00
Committed by: PyTorch MergeBot
Parent: b9a5e1d038
Commit: 01b1d1f91b

2 changed files with 94 additions and 11 deletions


@@ -107,6 +107,31 @@ def find_tunableop_result(results, OpSig, ParamSig):
             return inner_tuple
     return None

+def compare_untuned_tuned_param_sig(untuned_filename, tuned_filename):
+    # Compare the Param Signatures in the untuned and tuned TunableOp results
+    # files. Verify that for each Param Signature in the untuned file
+    # there is a matching one in the tuned results file.
+    import csv
+    ok = False
+    with open(untuned_filename) as file1:
+        with open(tuned_filename) as file2:
+            untuned_reader = csv.reader(file1)
+            untuned_csv_entries = [row[1] for row in untuned_reader]
+
+            tuned_reader = csv.reader(file2)
+            for _ in range(5):  # Skip the first 5 lines for the validator
+                next(tuned_reader, None)
+
+            result_csv_entries = [row[1] for row in tuned_reader]
+
+            for value in untuned_csv_entries:
+                if value in result_csv_entries:
+                    ok = True
+                else:
+                    # Bail out early: one missing signature fails the check.
+                    return False
+
+    return ok
+
 class TestLinalg(TestCase):
     @contextlib.contextmanager
     def _hip_allow_tf32(self):
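To make the helper above concrete, here is a hedged sketch of the two file layouts it assumes: the untuned file holds `OpSig,ParamSig` rows, while the tuned results file starts with five validator lines before its `OpSig,ParamSig,solution,time` rows. The entries and validator keys below are made up for illustration, not captured output.

```python
import csv, os, tempfile

tmp = tempfile.mkdtemp()
untuned = os.path.join(tmp, "tunableop_untuned0.csv")
tuned = os.path.join(tmp, "tunableop_results0.csv")

# Untuned recording: one "OpSig,ParamSig" row per GEMM encountered.
with open(untuned, "w", newline="") as f:
    csv.writer(f).writerow(["GemmTunableOp_Half_NN", "nn_32_16_64_ld_32_64_32"])

# Tuned results: validator header lines first, then one row per tuned GEMM.
with open(tuned, "w", newline="") as f:
    w = csv.writer(f)
    for key in ("PT_VERSION", "ROCM_VERSION", "HIPBLASLT_VERSION",
                "ROCBLAS_VERSION", "GCN_ARCH_NAME"):  # illustrative keys
        w.writerow(["Validator", key, "..."])
    w.writerow(["GemmTunableOp_Half_NN", "nn_32_16_64_ld_32_64_32",
                "Gemm_Hipblaslt_1234", "0.021"])

assert compare_untuned_tuned_param_sig(untuned, tuned)
```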
@@ -4680,6 +4705,7 @@ class TestLinalg(TestCase):
         pass

     @onlyCUDA
     @skipCUDAIfNotRocm
+    @dtypes(torch.half)
     def test_matmul_offline_tunableop(self, device, dtype):
         import os
@@ -4733,6 +4759,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)

+            # Compare Param Signature of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -4780,12 +4810,14 @@ class TestLinalg(TestCase):
             # Scaled GEMM parameters
             fillA = 0.25
             fillB = 0.75
-            m = n = k = 16
+            n = 16
+            m = 32
+            k = 64

             scaleA = torch.tensor(0.8, device=device)
             scaleB = torch.tensor(0.9, device=device)
             dtypeA = dtypeB = dtype
-            matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
+            matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
             matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()

             # Summary of bias types that are supported:
@@ -4813,7 +4845,7 @@ class TestLinalg(TestCase):
             # rowwise scaling, only supported for this dtype combination
             if dtype is torch.float8_e4m3fnuz:
                 scaleA = torch.ones((matA.shape[0], 1), device=device)
-                scaleB = torch.ones((1, matB.shape[0]), device=device)
+                scaleB = torch.ones((1, matB.shape[1]), device=device)
             torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)

             self.assertTrue(torch.cuda.tunable.is_enabled())
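The one-line fix above is purely about shapes: `matB` is created as `(n, k)` and then transposed, so `matB.shape[1]` is `n`, and rowwise scaling wants one scale per row of A and one per column of B. A minimal sketch, assuming a ROCm device that supports `torch.float8_e4m3fnuz`:

```python
import torch

m, k, n = 32, 64, 16
dtype = torch.float8_e4m3fnuz  # assumption: the GPU supports this FP8 format
matA = torch.full((m, k), 0.25, dtype=dtype, device="cuda")
# B is built (n, k) and transposed so torch._scaled_mm sees a
# column-major (k, n) operand.
matB = torch.full((n, k), 0.75, dtype=dtype, device="cuda").t()

# Rowwise scaling: (m, 1) scales for A, (1, n) scales for B.
# matB.shape[1] == n after the transpose, hence the fix above.
scaleA = torch.ones((matA.shape[0], 1), device="cuda")
scaleB = torch.ones((1, matB.shape[1]), device="cuda")
out = torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB,
                       out_dtype=torch.bfloat16)
assert out.shape == (m, n)
```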
@@ -4850,6 +4882,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)

+            # Compare Param Signature of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -5328,6 +5364,7 @@ class TestLinalg(TestCase):
         pass

     @onlyCUDA
     @skipCUDAIfNotRocm
+    @dtypes(torch.bfloat16)
     def test_gemm_bias_offline_tunableop(self, device, dtype):
         # This test is the offline version of test_gemm_bias_tunableop
@@ -5390,6 +5427,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)

+            # Compare Param Signature of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -5435,12 +5476,14 @@ class TestLinalg(TestCase):
             # Scaled GEMM parameters
             fillA = 0.25
             fillB = 0.75
-            m = n = k = 32
+            n = 32
+            m = 64
+            k = 128

             scaleA = torch.tensor(0.8, device=device)
             scaleB = torch.tensor(0.9, device=device)
             dtypeA = dtypeB = dtype
-            matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
+            matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
             matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()

             # Summary of bias types that are supported:
@@ -5468,7 +5511,7 @@ class TestLinalg(TestCase):
             # rowwise scaling, only supported for this dtype combination
             if dtype is torch.float8_e4m3fnuz:
                 scaleA = torch.ones((matA.shape[0], 1), device=device)
-                scaleB = torch.ones((1, matB.shape[0]), device=device)
+                scaleB = torch.ones((1, matB.shape[1]), device=device)
             torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)

             # This stores the total number of cumulative results
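For context on the comment that closes this hunk: tuning results accumulate in-process, and the `torch.cuda.tunable` introspection calls let a test count them. A hedged sketch, with names as documented for the `torch.cuda.tunable` module:

```python
# Each tuned GEMM contributes one entry, so offline-tuning N recorded
# GEMMs should grow this count by N.
total_num_results = len(torch.cuda.tunable.get_results())
# The validators are the header lines the CSV comparison skips.
validators = torch.cuda.tunable.get_validators()
```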
@@ -5645,6 +5688,16 @@ class TestLinalg(TestCase):
                                                  'nn_41_41_41_ld_41_41_41')
             self.assertTrue(found_result is not None)

             self.assertTrue(torch.cuda.tunable.write_file())

+            # Make sure the results file exists and that it is not empty
+            self.assertTrue(os.path.exists(result_filename))
+            self.assertGreater(os.path.getsize(result_filename), 0)
+
+            # Compare Param Signature of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # Disable TF32
             torch.backends.cuda.matmul.allow_tf32 = False
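Finally, a hedged reading of the ParamSignature asserted in the last hunk: TunableOp appears to encode a GEMM as `{transA}{transB}_{m}_{n}_{k}_ld_{lda}_{ldb}_{ldc}`, which would make `'nn_41_41_41_ld_41_41_41'` a square, untransposed GEMM with matching leading dimensions:

```python
sig = "nn_41_41_41_ld_41_41_41"
trans, m, n, k, _ld, lda, ldb, ldc = sig.split("_")
assert trans == "nn"                          # neither operand transposed
assert (m, n, k) == ("41", "41", "41")        # problem sizes
assert (lda, ldb, ldc) == ("41", "41", "41")  # leading dimensions
```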