[ROCm][TunableOp] Fix offline tuning for ScaledGEMM. (#149677)
The main purpose of this PR is to fix offline tuning for ScaledGEMM; the previous unit test passed only because it was not strict enough. Additionally:
- All the offline tuning tests now compare against the online results to ensure that the Param Signatures match.
- We raise an error if submatrices are encountered, as these are only supported in online tuning mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149677
Approved by: https://github.com/jeffdaily
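For context, offline tuning splits TunableOp's record and tune phases: GEMM signatures are collected during a normal run and tuned later from file. A minimal sketch of the workflow the tests below exercise, assuming the torch.cuda.tunable Python API and its default filenames (both are configurable via the PYTORCH_TUNABLEOP_* environment variables):

    import torch

    # Phase 1: run the workload and record untuned GEMM signatures
    # (no tuning happens during the run).
    torch.cuda.tunable.enable(True)
    torch.cuda.tunable.tuning_enable(False)
    torch.cuda.tunable.record_untuned_enable(True)

    a = torch.rand(32, 64, dtype=torch.half, device="cuda")
    b = torch.rand(64, 16, dtype=torch.half, device="cuda")
    torch.matmul(a, b)  # signature is appended to tunableop_untuned0.csv

    # Phase 2: offline, tune every recorded GEMM and write the results.
    torch.cuda.tunable.tune_gemm_in_file("tunableop_untuned0.csv")
    torch.cuda.tunable.write_file()  # tunableop_results0.csv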
Committed by: PyTorch MergeBot
Commit: 01b1d1f91b
Parent: b9a5e1d038
@@ -107,6 +107,31 @@ def find_tunableop_result(results, OpSig, ParamSig):
                 return inner_tuple
     return None
 
+def compare_untuned_tuned_param_sig(untuned_filename, tuned_filename):
+    # Compare Param Signatures of the untuned and tuned TunableOp results
+    # files. Verify that for each Param Signature in the untuned file
+    # there is a matching one in the tuned results file.
+    import csv
+    ok = False
+    with open(untuned_filename) as file1:
+        with open(tuned_filename) as file2:
+            untuned_reader = csv.reader(file1)
+            untuned_csv_entries = [row[1] for row in untuned_reader]
+
+            tuned_reader = csv.reader(file2)
+            for _ in range(5):  # Skip the first 5 lines for the validator
+                next(tuned_reader, None)
+
+            result_csv_entries = [row[1] for row in tuned_reader]
+
+            for value in untuned_csv_entries:
+                if value in result_csv_entries:
+                    ok = True
+                else:
+                    ok = False
+
+    return ok
+
 class TestLinalg(TestCase):
     @contextlib.contextmanager
     def _hip_allow_tf32(self):
@@ -4680,6 +4705,7 @@ class TestLinalg(TestCase):
         pass
 
     @onlyCUDA
+    @skipCUDAIfNotRocm
    @dtypes(torch.half)
    def test_matmul_offline_tunableop(self, device, dtype):
        import os
@@ -4733,6 +4759,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)
 
+            # Compare Param Signature of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -4780,12 +4810,14 @@ class TestLinalg(TestCase):
             # Scaled GEMM parameters
             fillA = 0.25
             fillB = 0.75
-            m = n = k = 16
+            n = 16
+            m = 32
+            k = 64
             scaleA = torch.tensor(0.8, device=device)
             scaleB = torch.tensor(0.9, device=device)
 
             dtypeA = dtypeB = dtype
-            matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
+            matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
             matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
 
             # Summary of bias types that are supported:
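Making n, m and k distinct is what gives the test its teeth: with m = n = k every operand is square, so the old (k, m) construction of matA had the right shape for the wrong layout and the test could never fail. A quick shape check with the new values (a sketch, not part of the diff):

    matA = torch.full((32, 64), 0.25, dtype=dtypeA, device=device)      # (m, k)
    matB = torch.full((16, 64), 0.75, dtype=dtypeB, device=device).t()  # (k, n) view
    # torch._scaled_mm(matA, matB, ...) yields an (m, n) == (32, 16) result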
@@ -4813,7 +4845,7 @@ class TestLinalg(TestCase):
             # rowwise scaling, only supported for this dtype combination
             if dtype is torch.torch.float8_e4m3fnuz:
                 scaleA = torch.ones((matA.shape[0], 1), device=device)
-                scaleB = torch.ones((1, matB.shape[0]), device=device)
+                scaleB = torch.ones((1, matB.shape[1]), device=device)
             torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
 
             self.assertTrue(torch.cuda.tunable.is_enabled())
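The one-character fix follows from how matB is built in the previous hunk: after .t() it is a (k, n) view, so the per-column rowwise scale must be (1, n), i.e. (1, matB.shape[1]). The old shape[0] only happened to work while n == k.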
@@ -4850,6 +4882,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)
 
+            # Compare Param Signature of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -5328,6 +5364,7 @@ class TestLinalg(TestCase):
         pass
 
     @onlyCUDA
+    @skipCUDAIfNotRocm
    @dtypes(torch.bfloat16)
    def test_gemm_bias_offline_tunableop(self, device, dtype):
        # This test is the offline version of test_gemm_bias_tunableop
@@ -5390,6 +5427,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)
 
+            # Compare Param Signature of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -5435,12 +5476,14 @@ class TestLinalg(TestCase):
             # Scaled GEMM parameters
             fillA = 0.25
             fillB = 0.75
-            m = n = k = 32
+            n = 32
+            m = 64
+            k = 128
             scaleA = torch.tensor(0.8, device=device)
             scaleB = torch.tensor(0.9, device=device)
 
             dtypeA = dtypeB = dtype
-            matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
+            matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
             matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
 
             # Summary of bias types that are supported:
@@ -5468,7 +5511,7 @@ class TestLinalg(TestCase):
             # rowwise scaling, only supported for this dtype combination
             if dtype is torch.torch.float8_e4m3fnuz:
                 scaleA = torch.ones((matA.shape[0], 1), device=device)
-                scaleB = torch.ones((1, matB.shape[0]), device=device)
+                scaleB = torch.ones((1, matB.shape[1]), device=device)
             torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
 
             # This stores the total number of cumulative results
@@ -5645,6 +5688,16 @@ class TestLinalg(TestCase):
                                              'nn_41_41_41_ld_41_41_41')
             self.assertTrue(found_result is not None)
 
+            self.assertTrue(torch.cuda.tunable.write_file())
+
+            # Make sure the results file exists and is not empty
+            self.assertTrue(os.path.exists(result_filename))
+            self.assertGreater(os.path.getsize(result_filename), 0)
+
+            # Compare Param Signature of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # Disable TF32
             torch.backends.cuda.matmul.allow_tf32 = False
@@ -478,10 +478,11 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
         torch.backends.cuda.matmul.allow_tf32 = False
 
     else:  # ScaledGEMM
+        count = untuned_gemm[0].count("_")
+        assert count in [6, 7]
         untuned_gemm_temp = untuned_gemm[0].split("_")
         # dtypeC might not be an FP8 type, so keep track
         # of the number of underscores
-        count = untuned_gemm_temp.count("_")
         op_sig = untuned_gemm_temp[0]
         data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2]
         data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4]
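Hoisting the count above the split also fixes a latent bug: the old line counted "_" elements in the already-split token list, which is always zero, rather than underscores in the raw operator signature string.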
@@ -497,6 +498,23 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
 
     untuned_gemm_temp = untuned_gemm[1].split("_")
     [n, m, k] = [int(g) for g in untuned_gemm_temp[1:4]]
+    if op_sig == "GemmStridedBatchedTunableOp":
+        assert untuned_gemm_temp[6] == "ld"
+        [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[7:10]]
+    else:
+        assert untuned_gemm_temp[4] == "ld"
+        [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[5:8]]
+
+    # We cannot handle submatrices in offline tuning
+    if all(item in [n, m, k] for item in [lda, ldb, ldc]):
+        pass
+    else:
+        warnings.warn(
+            "Offline tuning is not supported on submatrices. Use online tuning instead. "
+            + f"Skipped tuning for: {untuned_gemm[1]}"
+        )
+        return
+
     if op_sig == "GemmTunableOp":
         matA = (
             torch.rand(k, m, dtype=dtype, device=deviceid).t()
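For intuition on the submatrix check: the leading dimensions (lda/ldb/ldc) recorded online reflect the stride of the underlying buffer, not the logical GEMM shape. A minimal sketch of how a mismatch arises (buffer sizes hypothetical):

    import torch

    big = torch.rand(128, 128, device="cuda")
    sub = big[:64, :64]   # logical 64x64 GEMM operand...
    sub @ sub.t()         # ...but its leading dimension is still 128

Offline replay materializes fresh, dense tensors from m, n and k alone, so it cannot reproduce those strides; such entries are skipped with a warning.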
@@ -525,6 +543,10 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
         matB = matB.transpose(1, 2) if transA else matB
         torch.bmm(matA, matB)
     elif op_sig == "ScaledGemmTunableOp":
+        # Only combination supported by PyTorch
+        assert transA is True
+        assert transB is False
+
         fillA = 0.25
         fillB = 0.75
         matA = (
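The two asserts pin down the operand layout torch._scaled_mm accepts: first operand row-major, second operand column-major, which is why the tests above build matB as (n, k) and transpose it. A minimal sketch of the call, assuming an FP8-capable ROCm device (sizes illustrative):

    import torch

    m, k, n = 32, 64, 16
    matA = torch.full((m, k), 0.25, dtype=torch.float8_e4m3fnuz, device="cuda")
    matB = torch.full((n, k), 0.75, dtype=torch.float8_e4m3fnuz, device="cuda").t()
    out = torch._scaled_mm(
        matA, matB,
        scale_a=torch.tensor(0.8, device="cuda"),
        scale_b=torch.tensor(0.9, device="cuda"),
        out_dtype=torch.bfloat16,
    )  # out has shape (m, n) == (32, 16)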
@@ -533,9 +555,9 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
             else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid)
         )
         matB = (
-            torch.full((n, k), fillB, dtype=dtypeB, device=deviceid)
+            torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t()
             if transA
-            else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid).t()
+            else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid)
         )
 
         assert untuned_gemm_temp[8] == "rw"
@@ -544,8 +566,16 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
         else:
             rowwise = False
         if rowwise:
-            scaleA = torch.ones((matA.shape[0], 1), device=deviceid)
-            scaleB = torch.ones((1, matB.shape[0]), device=deviceid)
+            scaleA = (
+                torch.ones((1, m), device=deviceid)
+                if transB
+                else torch.ones((m, 1), device=deviceid)
+            )
+            scaleB = (
+                torch.ones((1, n), device=deviceid)
+                if transA
+                else torch.ones((n, 1), device=deviceid)
+            )
         else:
             scaleA = torch.tensor(0.8, device=deviceid)
             scaleB = torch.tensor(0.9, device=deviceid)
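For reference, rowwise scaling pairs one scale per output row with one scale per output column. A minimal shape sketch for out = A @ B with A of shape (m, k) and B of shape (k, n) (names and sizes illustrative):

    m, k, n = 32, 64, 16
    scale_a = torch.ones((m, 1), device="cuda")  # one scale per row of A
    scale_b = torch.ones((1, n), device="cuda")  # one scale per column of B

The transA/transB branches in the hunk above adapt these shapes to however the operands were recorded in the signature.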