mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[ROCm][TunableOp] Fix offline tuning for ScaledGEMM. (#149677)
The main purpose of this PR is to fix offline tuning for ScaledGEMM. The previous UT passed because it was not strict enough. Additionally: - All the offline tuning tests now do a comparison with the online results to ensure that ParamSignature match. - We raise an error if submatrices are encountered as this is only supported in online tuning mode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149677 Approved by: https://github.com/jeffdaily
This commit is contained in:
committed by
PyTorch MergeBot
parent
b9a5e1d038
commit
01b1d1f91b
@ -107,6 +107,31 @@ def find_tunableop_result(results, OpSig, ParamSig):
|
||||
return inner_tuple
|
||||
return None
|
||||
|
||||
def compare_untuned_tuned_param_sig(untuned_filename, tuned_filename):
|
||||
# Compare Param Signature of untuned and tuned Tunableop results
|
||||
# file. Verify that for each Param Signature in the untuned file
|
||||
# there is a matching one in the tuned results file.
|
||||
import csv
|
||||
ok = False
|
||||
with open(untuned_filename) as file1:
|
||||
with open(tuned_filename) as file2:
|
||||
untuned_reader = csv.reader(file1)
|
||||
untuned_csv_entries = [row[1] for row in untuned_reader]
|
||||
|
||||
tuned_reader = csv.reader(file2)
|
||||
for _ in range(5): # Skip the first 5 lines for the validator
|
||||
next(tuned_reader, None)
|
||||
|
||||
result_csv_entries = [row[1] for row in tuned_reader]
|
||||
|
||||
for value in untuned_csv_entries:
|
||||
if value in result_csv_entries:
|
||||
ok = True
|
||||
else:
|
||||
ok = False
|
||||
|
||||
return ok
|
||||
|
||||
class TestLinalg(TestCase):
|
||||
@contextlib.contextmanager
|
||||
def _hip_allow_tf32(self):
|
||||
@ -4680,6 +4705,7 @@ class TestLinalg(TestCase):
|
||||
pass
|
||||
|
||||
@onlyCUDA
|
||||
@skipCUDAIfNotRocm
|
||||
@dtypes(torch.half)
|
||||
def test_matmul_offline_tunableop(self, device, dtype):
|
||||
import os
|
||||
@ -4733,6 +4759,10 @@ class TestLinalg(TestCase):
|
||||
self.assertTrue(os.path.exists(result_filename))
|
||||
self.assertGreater(os.path.getsize(result_filename), 0)
|
||||
|
||||
# Compare Param Signature of untuned and tuned results
|
||||
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
|
||||
self.assertTrue(ok)
|
||||
|
||||
finally:
|
||||
# disable TunableOp
|
||||
torch.cuda.tunable.enable(False)
|
||||
@ -4780,12 +4810,14 @@ class TestLinalg(TestCase):
|
||||
# Scaled GEMM parameters
|
||||
fillA = 0.25
|
||||
fillB = 0.75
|
||||
m = n = k = 16
|
||||
n = 16
|
||||
m = 32
|
||||
k = 64
|
||||
scaleA = torch.tensor(0.8, device=device)
|
||||
scaleB = torch.tensor(0.9, device=device)
|
||||
|
||||
dtypeA = dtypeB = dtype
|
||||
matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
|
||||
matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
|
||||
matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
|
||||
|
||||
# Summary of bias types that are supported:
|
||||
@ -4813,7 +4845,7 @@ class TestLinalg(TestCase):
|
||||
# rowwise scaling, only supported for this dtype combination
|
||||
if dtype is torch.torch.float8_e4m3fnuz:
|
||||
scaleA = torch.ones((matA.shape[0], 1), device=device)
|
||||
scaleB = torch.ones((1, matB.shape[0]), device=device)
|
||||
scaleB = torch.ones((1, matB.shape[1]), device=device)
|
||||
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
|
||||
|
||||
self.assertTrue(torch.cuda.tunable.is_enabled())
|
||||
@ -4850,6 +4882,10 @@ class TestLinalg(TestCase):
|
||||
self.assertTrue(os.path.exists(result_filename))
|
||||
self.assertGreater(os.path.getsize(result_filename), 0)
|
||||
|
||||
# Compare Param Signature of untuned and tuned results
|
||||
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
|
||||
self.assertTrue(ok)
|
||||
|
||||
finally:
|
||||
# disable TunableOp
|
||||
torch.cuda.tunable.enable(False)
|
||||
@ -5328,6 +5364,7 @@ class TestLinalg(TestCase):
|
||||
pass
|
||||
|
||||
@onlyCUDA
|
||||
@skipCUDAIfNotRocm
|
||||
@dtypes(torch.bfloat16)
|
||||
def test_gemm_bias_offline_tunableop(self, device, dtype):
|
||||
# This test is the offline version of test_gemm_bias_tunableop
|
||||
@ -5390,6 +5427,10 @@ class TestLinalg(TestCase):
|
||||
self.assertTrue(os.path.exists(result_filename))
|
||||
self.assertGreater(os.path.getsize(result_filename), 0)
|
||||
|
||||
# Compare Param Signature of untuned and tuned results
|
||||
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
|
||||
self.assertTrue(ok)
|
||||
|
||||
finally:
|
||||
# disable TunableOp
|
||||
torch.cuda.tunable.enable(False)
|
||||
@ -5435,12 +5476,14 @@ class TestLinalg(TestCase):
|
||||
# Scaled GEMM parameters
|
||||
fillA = 0.25
|
||||
fillB = 0.75
|
||||
m = n = k = 32
|
||||
n = 32
|
||||
m = 64
|
||||
k = 128
|
||||
scaleA = torch.tensor(0.8, device=device)
|
||||
scaleB = torch.tensor(0.9, device=device)
|
||||
|
||||
dtypeA = dtypeB = dtype
|
||||
matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
|
||||
matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
|
||||
matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
|
||||
|
||||
# Summary of bias types that are supported:
|
||||
@ -5468,7 +5511,7 @@ class TestLinalg(TestCase):
|
||||
# rowwise scaling, only supported for this dtype combination
|
||||
if dtype is torch.torch.float8_e4m3fnuz:
|
||||
scaleA = torch.ones((matA.shape[0], 1), device=device)
|
||||
scaleB = torch.ones((1, matB.shape[0]), device=device)
|
||||
scaleB = torch.ones((1, matB.shape[1]), device=device)
|
||||
torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
|
||||
|
||||
# This stores total number of cummulative results
|
||||
@ -5645,6 +5688,16 @@ class TestLinalg(TestCase):
|
||||
'nn_41_41_41_ld_41_41_41')
|
||||
self.assertTrue(found_result is not None)
|
||||
|
||||
self.assertTrue(torch.cuda.tunable.write_file())
|
||||
|
||||
# Make sure the results file exists and that it is not zero
|
||||
self.assertTrue(os.path.exists(result_filename))
|
||||
self.assertGreater(os.path.getsize(result_filename), 0)
|
||||
|
||||
# Compare Param Signature of untuned and tuned results
|
||||
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
|
||||
self.assertTrue(ok)
|
||||
|
||||
finally:
|
||||
# Disable TF32
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
@ -478,10 +478,11 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
else: # ScaledGEMM
|
||||
count = untuned_gemm[0].count("_")
|
||||
assert count in [6, 7]
|
||||
untuned_gemm_temp = untuned_gemm[0].split("_")
|
||||
# dtypeC = might not be FP8 type, keep track
|
||||
# of the the number of underscores
|
||||
count = untuned_gemm_temp.count("_")
|
||||
op_sig = untuned_gemm_temp[0]
|
||||
data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2]
|
||||
data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4]
|
||||
@ -497,6 +498,23 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
|
||||
|
||||
untuned_gemm_temp = untuned_gemm[1].split("_")
|
||||
[n, m, k] = [int(g) for g in untuned_gemm_temp[1:4]]
|
||||
if op_sig == "GemmStridedBatchedTunableOp":
|
||||
assert untuned_gemm_temp[6] == "ld"
|
||||
[ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[7:10]]
|
||||
else:
|
||||
assert untuned_gemm_temp[4] == "ld"
|
||||
[ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[5:8]]
|
||||
|
||||
# We cannot handle submatrices in offline tuning
|
||||
if all(item in [n, m, k] for item in [lda, ldb, ldc]):
|
||||
pass
|
||||
else:
|
||||
warnings.warn(
|
||||
"Offline tuning is not supported on submatrices. Use online tuning instead. "
|
||||
+ f"Skipped tuning for: {untuned_gemm[1]}"
|
||||
)
|
||||
return
|
||||
|
||||
if op_sig == "GemmTunableOp":
|
||||
matA = (
|
||||
torch.rand(k, m, dtype=dtype, device=deviceid).t()
|
||||
@ -525,6 +543,10 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
|
||||
matB = matB.transpose(1, 2) if transA else matB
|
||||
torch.bmm(matA, matB)
|
||||
elif op_sig == "ScaledGemmTunableOp":
|
||||
# Only combination supported by PyTorch
|
||||
assert transA is True
|
||||
assert transB is False
|
||||
|
||||
fillA = 0.25
|
||||
fillB = 0.75
|
||||
matA = (
|
||||
@ -533,9 +555,9 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
|
||||
else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid)
|
||||
)
|
||||
matB = (
|
||||
torch.full((n, k), fillB, dtype=dtypeB, device=deviceid)
|
||||
torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t()
|
||||
if transA
|
||||
else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid).t()
|
||||
else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid)
|
||||
)
|
||||
|
||||
assert untuned_gemm_temp[8] == "rw"
|
||||
@ -544,8 +566,16 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
|
||||
else:
|
||||
rowwise = False
|
||||
if rowwise:
|
||||
scaleA = torch.ones((matA.shape[0], 1), device=deviceid)
|
||||
scaleB = torch.ones((1, matB.shape[0]), device=deviceid)
|
||||
scaleA = (
|
||||
torch.ones((1, m), device=deviceid)
|
||||
if transB
|
||||
else torch.ones((m, 1), device=deviceid)
|
||||
)
|
||||
scaleB = (
|
||||
torch.ones((1, n), device=deviceid)
|
||||
if transA
|
||||
else torch.ones((n, 1), device=deviceid)
|
||||
)
|
||||
else:
|
||||
scaleA = torch.tensor(0.8, device=deviceid)
|
||||
scaleB = torch.tensor(0.9, device=deviceid)
|
||||
|
Reference in New Issue
Block a user