[ROCm][TunableOp] Fix offline tuning for ScaledGEMM. (#149677)

The main purpose of this PR is to fix offline tuning for ScaledGEMM. The previous UT passed because it was not strict enough. Additionally:
- All the offline tuning tests now do a comparison with the online results to ensure that ParamSignature match.
- We raise an error if submatrices are encountered as this is only supported in online tuning mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149677
Approved by: https://github.com/jeffdaily
This commit is contained in:
Nichols A. Romero
2025-03-22 02:22:10 +00:00
committed by PyTorch MergeBot
parent b9a5e1d038
commit 01b1d1f91b
2 changed files with 94 additions and 11 deletions

View File

@ -107,6 +107,31 @@ def find_tunableop_result(results, OpSig, ParamSig):
return inner_tuple return inner_tuple
return None return None
def compare_untuned_tuned_param_sig(untuned_filename, tuned_filename):
# Compare Param Signature of untuned and tuned Tunableop results
# file. Verify that for each Param Signature in the untuned file
# there is a matching one in the tuned results file.
import csv
ok = False
with open(untuned_filename) as file1:
with open(tuned_filename) as file2:
untuned_reader = csv.reader(file1)
untuned_csv_entries = [row[1] for row in untuned_reader]
tuned_reader = csv.reader(file2)
for _ in range(5): # Skip the first 5 lines for the validator
next(tuned_reader, None)
result_csv_entries = [row[1] for row in tuned_reader]
for value in untuned_csv_entries:
if value in result_csv_entries:
ok = True
else:
ok = False
return ok
class TestLinalg(TestCase): class TestLinalg(TestCase):
@contextlib.contextmanager @contextlib.contextmanager
def _hip_allow_tf32(self): def _hip_allow_tf32(self):
@ -4680,6 +4705,7 @@ class TestLinalg(TestCase):
pass pass
@onlyCUDA @onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.half) @dtypes(torch.half)
def test_matmul_offline_tunableop(self, device, dtype): def test_matmul_offline_tunableop(self, device, dtype):
import os import os
@ -4733,6 +4759,10 @@ class TestLinalg(TestCase):
self.assertTrue(os.path.exists(result_filename)) self.assertTrue(os.path.exists(result_filename))
self.assertGreater(os.path.getsize(result_filename), 0) self.assertGreater(os.path.getsize(result_filename), 0)
# Compare Param Signature of untuned and tuned results
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
self.assertTrue(ok)
finally: finally:
# disable TunableOp # disable TunableOp
torch.cuda.tunable.enable(False) torch.cuda.tunable.enable(False)
@ -4780,12 +4810,14 @@ class TestLinalg(TestCase):
# Scaled GEMM parameters # Scaled GEMM parameters
fillA = 0.25 fillA = 0.25
fillB = 0.75 fillB = 0.75
m = n = k = 16 n = 16
m = 32
k = 64
scaleA = torch.tensor(0.8, device=device) scaleA = torch.tensor(0.8, device=device)
scaleB = torch.tensor(0.9, device=device) scaleB = torch.tensor(0.9, device=device)
dtypeA = dtypeB = dtype dtypeA = dtypeB = dtype
matA = torch.full((k, m), fillA, dtype=dtypeA, device=device) matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t() matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
# Summary of bias types that are supported: # Summary of bias types that are supported:
@ -4813,7 +4845,7 @@ class TestLinalg(TestCase):
# rowwise scaling, only supported for this dtype combination # rowwise scaling, only supported for this dtype combination
if dtype is torch.torch.float8_e4m3fnuz: if dtype is torch.torch.float8_e4m3fnuz:
scaleA = torch.ones((matA.shape[0], 1), device=device) scaleA = torch.ones((matA.shape[0], 1), device=device)
scaleB = torch.ones((1, matB.shape[0]), device=device) scaleB = torch.ones((1, matB.shape[1]), device=device)
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16) torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
self.assertTrue(torch.cuda.tunable.is_enabled()) self.assertTrue(torch.cuda.tunable.is_enabled())
@ -4850,6 +4882,10 @@ class TestLinalg(TestCase):
self.assertTrue(os.path.exists(result_filename)) self.assertTrue(os.path.exists(result_filename))
self.assertGreater(os.path.getsize(result_filename), 0) self.assertGreater(os.path.getsize(result_filename), 0)
# Compare Param Signature of untuned and tuned results
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
self.assertTrue(ok)
finally: finally:
# disable TunableOp # disable TunableOp
torch.cuda.tunable.enable(False) torch.cuda.tunable.enable(False)
@ -5328,6 +5364,7 @@ class TestLinalg(TestCase):
pass pass
@onlyCUDA @onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.bfloat16) @dtypes(torch.bfloat16)
def test_gemm_bias_offline_tunableop(self, device, dtype): def test_gemm_bias_offline_tunableop(self, device, dtype):
# This test is the offline version of test_gemm_bias_tunableop # This test is the offline version of test_gemm_bias_tunableop
@ -5390,6 +5427,10 @@ class TestLinalg(TestCase):
self.assertTrue(os.path.exists(result_filename)) self.assertTrue(os.path.exists(result_filename))
self.assertGreater(os.path.getsize(result_filename), 0) self.assertGreater(os.path.getsize(result_filename), 0)
# Compare Param Signature of untuned and tuned results
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
self.assertTrue(ok)
finally: finally:
# disable TunableOp # disable TunableOp
torch.cuda.tunable.enable(False) torch.cuda.tunable.enable(False)
@ -5435,12 +5476,14 @@ class TestLinalg(TestCase):
# Scaled GEMM parameters # Scaled GEMM parameters
fillA = 0.25 fillA = 0.25
fillB = 0.75 fillB = 0.75
m = n = k = 32 n = 32
m = 64
k = 128
scaleA = torch.tensor(0.8, device=device) scaleA = torch.tensor(0.8, device=device)
scaleB = torch.tensor(0.9, device=device) scaleB = torch.tensor(0.9, device=device)
dtypeA = dtypeB = dtype dtypeA = dtypeB = dtype
matA = torch.full((k, m), fillA, dtype=dtypeA, device=device) matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t() matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
# Summary of bias types that are supported: # Summary of bias types that are supported:
@ -5468,7 +5511,7 @@ class TestLinalg(TestCase):
# rowwise scaling, only supported for this dtype combination # rowwise scaling, only supported for this dtype combination
if dtype is torch.torch.float8_e4m3fnuz: if dtype is torch.torch.float8_e4m3fnuz:
scaleA = torch.ones((matA.shape[0], 1), device=device) scaleA = torch.ones((matA.shape[0], 1), device=device)
scaleB = torch.ones((1, matB.shape[0]), device=device) scaleB = torch.ones((1, matB.shape[1]), device=device)
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16) torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
# This stores total number of cummulative results # This stores total number of cummulative results
@ -5645,6 +5688,16 @@ class TestLinalg(TestCase):
'nn_41_41_41_ld_41_41_41') 'nn_41_41_41_ld_41_41_41')
self.assertTrue(found_result is not None) self.assertTrue(found_result is not None)
self.assertTrue(torch.cuda.tunable.write_file())
# Make sure the results file exists and that it is not zero
self.assertTrue(os.path.exists(result_filename))
self.assertGreater(os.path.getsize(result_filename), 0)
# Compare Param Signature of untuned and tuned results
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
self.assertTrue(ok)
finally: finally:
# Disable TF32 # Disable TF32
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False

View File

@ -478,10 +478,11 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
else: # ScaledGEMM else: # ScaledGEMM
count = untuned_gemm[0].count("_")
assert count in [6, 7]
untuned_gemm_temp = untuned_gemm[0].split("_") untuned_gemm_temp = untuned_gemm[0].split("_")
# dtypeC = might not be FP8 type, keep track # dtypeC = might not be FP8 type, keep track
# of the the number of underscores # of the the number of underscores
count = untuned_gemm_temp.count("_")
op_sig = untuned_gemm_temp[0] op_sig = untuned_gemm_temp[0]
data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2] data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2]
data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4] data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4]
@ -497,6 +498,23 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
untuned_gemm_temp = untuned_gemm[1].split("_") untuned_gemm_temp = untuned_gemm[1].split("_")
[n, m, k] = [int(g) for g in untuned_gemm_temp[1:4]] [n, m, k] = [int(g) for g in untuned_gemm_temp[1:4]]
if op_sig == "GemmStridedBatchedTunableOp":
assert untuned_gemm_temp[6] == "ld"
[ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[7:10]]
else:
assert untuned_gemm_temp[4] == "ld"
[ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[5:8]]
# We cannot handle submatrices in offline tuning
if all(item in [n, m, k] for item in [lda, ldb, ldc]):
pass
else:
warnings.warn(
"Offline tuning is not supported on submatrices. Use online tuning instead. "
+ f"Skipped tuning for: {untuned_gemm[1]}"
)
return
if op_sig == "GemmTunableOp": if op_sig == "GemmTunableOp":
matA = ( matA = (
torch.rand(k, m, dtype=dtype, device=deviceid).t() torch.rand(k, m, dtype=dtype, device=deviceid).t()
@ -525,6 +543,10 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
matB = matB.transpose(1, 2) if transA else matB matB = matB.transpose(1, 2) if transA else matB
torch.bmm(matA, matB) torch.bmm(matA, matB)
elif op_sig == "ScaledGemmTunableOp": elif op_sig == "ScaledGemmTunableOp":
# Only combination supported by PyTorch
assert transA is True
assert transB is False
fillA = 0.25 fillA = 0.25
fillB = 0.75 fillB = 0.75
matA = ( matA = (
@ -533,9 +555,9 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid) else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid)
) )
matB = ( matB = (
torch.full((n, k), fillB, dtype=dtypeB, device=deviceid) torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t()
if transA if transA
else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid).t() else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid)
) )
assert untuned_gemm_temp[8] == "rw" assert untuned_gemm_temp[8] == "rw"
@ -544,8 +566,16 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
else: else:
rowwise = False rowwise = False
if rowwise: if rowwise:
scaleA = torch.ones((matA.shape[0], 1), device=deviceid) scaleA = (
scaleB = torch.ones((1, matB.shape[0]), device=deviceid) torch.ones((1, m), device=deviceid)
if transB
else torch.ones((m, 1), device=deviceid)
)
scaleB = (
torch.ones((1, n), device=deviceid)
if transA
else torch.ones((n, 1), device=deviceid)
)
else: else:
scaleA = torch.tensor(0.8, device=deviceid) scaleA = torch.tensor(0.8, device=deviceid)
scaleB = torch.tensor(0.9, device=deviceid) scaleB = torch.tensor(0.9, device=deviceid)