[ROCm][TunableOp] Fix offline tuning for ScaledGEMM. (#149677)
The main purpose of this PR is to fix offline tuning for ScaledGEMM. The previous unit test passed because it was not strict enough. Additionally:
- All the offline tuning tests now compare against the online results to ensure that the ParamSignatures match.
- We raise an error if submatrices are encountered, as these are only supported in online tuning mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149677
Approved by: https://github.com/jeffdaily
committed by: PyTorch MergeBot
parent: b9a5e1d038
commit: 01b1d1f91b
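For context on what "offline tuning" means here: instead of benchmarking every GEMM as it is first encountered at runtime, TunableOp can record the untuned GEMM ParamSignatures to a CSV file while the workload runs, and tune them later from that file. A minimal sketch of the flow, assuming the torch.cuda.tunable Python API as documented for recent PyTorch/ROCm builds; the default file names and the shapes are assumptions, not taken from this commit:

    import torch

    # Stage 1: run the workload and record untuned GEMMs (no tuning yet).
    torch.cuda.tunable.enable(True)
    torch.cuda.tunable.tuning_enable(False)
    torch.cuda.tunable.record_untuned_enable(True)

    a = torch.randn(32, 64, dtype=torch.half, device="cuda")
    b = torch.randn(64, 16, dtype=torch.half, device="cuda")
    torch.matmul(a, b)  # appends an OpSig/ParamSig row to the untuned file

    # Stage 2: tune every GEMM recorded in stage 1, then persist the results.
    torch.cuda.tunable.tuning_enable(True)
    torch.cuda.tunable.tune_gemm_in_file("tunableop_untuned0.csv")  # assumed default name
    torch.cuda.tunable.write_file()  # writes tunableop_results0.csv by default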
@@ -107,6 +107,31 @@ def find_tunableop_result(results, OpSig, ParamSig):
             return inner_tuple
     return None
 
+def compare_untuned_tuned_param_sig(untuned_filename, tuned_filename):
+    # Compare the Param Signatures of the untuned and tuned TunableOp results
+    # files. Verify that for each Param Signature in the untuned file
+    # there is a matching one in the tuned results file.
+    import csv
+    ok = False
+    with open(untuned_filename) as file1:
+        with open(tuned_filename) as file2:
+            untuned_reader = csv.reader(file1)
+            untuned_csv_entries = [row[1] for row in untuned_reader]
+
+            tuned_reader = csv.reader(file2)
+            for _ in range(5):  # Skip the first 5 lines for the validator
+                next(tuned_reader, None)
+
+            result_csv_entries = [row[1] for row in tuned_reader]
+
+            for value in untuned_csv_entries:
+                if value in result_csv_entries:
+                    ok = True
+                else:
+                    return False
+
+    return ok
+
 class TestLinalg(TestCase):
     @contextlib.contextmanager
     def _hip_allow_tf32(self):
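A note on the file formats the new helper assumes (inferred from the test code itself, not from any spec): the untuned file contains OpSig,ParamSig rows, while the tuned results file begins with validator header lines (hence the five skipped rows) followed by OpSig,ParamSig,solution,time rows, so column 1 is the ParamSignature in both. A self-contained sanity check with hypothetical file contents, assuming the helper above is in scope:

    import csv
    import os
    import tempfile

    def write_csv(rows):
        # Write illustrative rows to a temporary CSV and return its path.
        f = tempfile.NamedTemporaryFile("w", newline="", suffix=".csv", delete=False)
        csv.writer(f).writerows(rows)
        f.close()
        return f.name

    untuned = write_csv([["GemmTunableOp_float_NN", "nn_41_41_41_ld_41_41_41"]])
    tuned = write_csv([["Validator", "PT_VERSION", "2.7.0"]] * 5
                      + [["GemmTunableOp_float_NN", "nn_41_41_41_ld_41_41_41",
                          "Default", "0.0123"]])
    assert compare_untuned_tuned_param_sig(untuned, tuned)
    os.unlink(untuned)
    os.unlink(tuned)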
@@ -4680,6 +4705,7 @@ class TestLinalg(TestCase):
         pass
 
     @onlyCUDA
     @skipCUDAIfNotRocm
     @dtypes(torch.half)
     def test_matmul_offline_tunableop(self, device, dtype):
         import os
@@ -4733,6 +4759,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)
 
+            # Compare Param Signatures of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -4780,12 +4810,14 @@ class TestLinalg(TestCase):
             # Scaled GEMM parameters
             fillA = 0.25
             fillB = 0.75
-            m = n = k = 16
+            n = 16
+            m = 32
+            k = 64
             scaleA = torch.tensor(0.8, device=device)
             scaleB = torch.tensor(0.9, device=device)
 
             dtypeA = dtypeB = dtype
-            matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
+            matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
             matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
 
             # Summary of bias types that are supported:
@@ -4813,7 +4845,7 @@ class TestLinalg(TestCase):
             # rowwise scaling, only supported for this dtype combination
             if dtype is torch.float8_e4m3fnuz:
                 scaleA = torch.ones((matA.shape[0], 1), device=device)
-                scaleB = torch.ones((1, matB.shape[0]), device=device)
+                scaleB = torch.ones((1, matB.shape[1]), device=device)
             torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
 
             self.assertTrue(torch.cuda.tunable.is_enabled())
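Why the changes in the two hunks above matter: with square m = n = k, a transposed operand or a wrong scale index still produces tensors of the right shape, so the old test could not catch either mistake. A small shape check with plain CPU tensors standing in for the fp8 operands (torch._scaled_mm itself requires fp8 inputs on a supported GPU; the numbers are illustrative):

    import torch

    m, k, n = 32, 64, 16
    matA = torch.full((m, k), 0.25)      # row-major (m, k); the old (k, m) only
    matB = torch.full((n, k), 0.75).t()  # worked because m, n and k were equal
    assert (matA @ matB).shape == (m, n)

    # Rowwise scaling: one scale per row of A, one per column of B. matB is a
    # (k, n) view, so its column count is matB.shape[1]; shape[0] would give k.
    scaleA = torch.ones((matA.shape[0], 1))  # (m, 1)
    scaleB = torch.ones((1, matB.shape[1]))  # (1, n) -- the corrected index
    assert scaleB.shape == (1, n)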
@@ -4850,6 +4882,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)
 
+            # Compare Param Signatures of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -5328,6 +5364,7 @@ class TestLinalg(TestCase):
         pass
 
     @onlyCUDA
     @skipCUDAIfNotRocm
     @dtypes(torch.bfloat16)
     def test_gemm_bias_offline_tunableop(self, device, dtype):
         # This test is the offline version of test_gemm_bias_tunableop
@@ -5390,6 +5427,10 @@ class TestLinalg(TestCase):
             self.assertTrue(os.path.exists(result_filename))
             self.assertGreater(os.path.getsize(result_filename), 0)
 
+            # Compare Param Signatures of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # disable TunableOp
             torch.cuda.tunable.enable(False)
@@ -5435,12 +5476,14 @@ class TestLinalg(TestCase):
             # Scaled GEMM parameters
             fillA = 0.25
             fillB = 0.75
-            m = n = k = 32
+            n = 32
+            m = 64
+            k = 128
             scaleA = torch.tensor(0.8, device=device)
             scaleB = torch.tensor(0.9, device=device)
 
             dtypeA = dtypeB = dtype
-            matA = torch.full((k, m), fillA, dtype=dtypeA, device=device)
+            matA = torch.full((m, k), fillA, dtype=dtypeA, device=device)
             matB = torch.full((n, k), fillB, dtype=dtypeB, device=device).t()
 
             # Summary of bias types that are supported:
@@ -5468,7 +5511,7 @@ class TestLinalg(TestCase):
             # rowwise scaling, only supported for this dtype combination
             if dtype is torch.float8_e4m3fnuz:
                 scaleA = torch.ones((matA.shape[0], 1), device=device)
-                scaleB = torch.ones((1, matB.shape[0]), device=device)
+                scaleB = torch.ones((1, matB.shape[1]), device=device)
             torch._scaled_mm(matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=torch.bfloat16)
 
             # This stores total number of cumulative results
@@ -5645,6 +5688,16 @@ class TestLinalg(TestCase):
                                                  'nn_41_41_41_ld_41_41_41')
             self.assertTrue(found_result is not None)
 
+            self.assertTrue(torch.cuda.tunable.write_file())
+
+            # Make sure the results file exists and that its size is not zero
+            self.assertTrue(os.path.exists(result_filename))
+            self.assertGreater(os.path.getsize(result_filename), 0)
+
+            # Compare Param Signatures of untuned and tuned results
+            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            self.assertTrue(ok)
+
         finally:
             # Disable TF32
             torch.backends.cuda.matmul.allow_tf32 = False
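For reference on the persistence calls exercised above: as I read the torch.cuda.tunable API, write_file() flushes the in-memory tuned results to the results CSV and returns True on success, while read_file() loads a results CSV back; set_filename() overrides the default path. A minimal sketch; the filename is illustrative, not from this commit:

    import torch

    torch.cuda.tunable.set_filename("my_tunableop_results.csv")  # assumed name
    assert torch.cuda.tunable.write_file()  # persist current tuned results
    assert torch.cuda.tunable.read_file()   # reload them, e.g. in a later run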