[ROCm][TunableOp] Stricter unit tests for online and offline tuning (#150142)

Improvements to unit tests, plus new warnings for unsupported cases in offline tuning. Here are more details:
- Previously we only compared the OpSig of the untuned vs. tuned entries. This was not strict enough, so we now compare OpSig+ParamSig (see the sketch after this list).
- The main offline and online UTs are now stricter, to make sure we exercise the code paths for all four combinations of transA and transB.
- Offline tuning does not support some tensor shapes; emit a warning and skip tuning for those entries.
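
A minimal sketch of the stricter comparison, mirroring the `compare_untuned_tuned_entries` helper added in the diff below (the CSV layout, with OpSig in column 0, ParamSig in column 1, and a 5-line validator header in the tuned results file, is taken from the test code):

```python
import csv

def compare_untuned_tuned_entries(untuned_filename, tuned_filename):
    """Return True only if every OpSig+ParamSig pair in the untuned file
    also appears in the tuned results file."""
    with open(untuned_filename) as f1, open(tuned_filename) as f2:
        untuned = {(row[0], row[1]) for row in csv.reader(f1)}
        tuned_reader = csv.reader(f2)
        for _ in range(5):  # skip the validator header of the tuned file
            next(tuned_reader, None)
        tuned = {(row[0], row[1]) for row in tuned_reader}
    return not (untuned - tuned)
```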

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150142
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Author: Nichols A. Romero
Date: 2025-03-31 04:12:08 +00:00
Committed by: PyTorch MergeBot
parent 157bff22f7
commit ca2ffc23ab
2 changed files with 121 additions and 24 deletions


@@ -109,28 +109,29 @@ def find_tunableop_result(results, OpSig, ParamSig):
return inner_tuple
return None
def compare_untuned_tuned_param_sig(untuned_filename, tuned_filename):
# Compare Param Signature of untuned and tuned Tunableop results
# file. Verify that for each Param Signature in the untuned file
def compare_untuned_tuned_entries(untuned_filename, tuned_filename):
# Compare the entries of the untuned and tuned TunableOp results
# files. Verify that for each Op+Param Signature in the untuned file
# there is a matching one in the tuned results file.
import csv
ok = False
with open(untuned_filename) as file1:
with open(tuned_filename) as file2:
untuned_reader = csv.reader(file1)
untuned_csv_entries = [row[1] for row in untuned_reader]
untuned_csv_entries = {(row[0], row[1]) for row in untuned_reader}
tuned_reader = csv.reader(file2)
for _ in range(5): # Skip the first 5 lines for the validator
next(tuned_reader, None)
result_csv_entries = [row[1] for row in tuned_reader]
result_csv_entries = {(row[0], row[1]) for row in tuned_reader}
for value in untuned_csv_entries:
if value in result_csv_entries:
ok = True
else:
ok = False
missing = untuned_csv_entries - result_csv_entries
if missing:
ok = False
else:
ok = True
return ok
@@ -4676,12 +4677,13 @@ class TestLinalg(TestCase):
# disable tunableop buffer rotation for all tests everywhere, it can be slow
# We set the TunableOp numerical check environment variable here because it is
# possible to hit some invalid numerical solutions due to the small matrix sizes.
import os
with self._tunableop_ctx():
torch.cuda.tunable.set_rotating_buffer_size(0)
if dtype is torch.half:
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
# Numerical check adds significant overhead; unsure if this is needed
# or if there was a transient problem at the time.
# if dtype is torch.half:
# os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
ordinal = torch.cuda.current_device()
# set these to single iterations to keep it short but still exercise the code
@@ -4689,8 +4691,9 @@ class TestLinalg(TestCase):
torch.cuda.tunable.set_max_tuning_iterations(1)
make_arg = partial(make_tensor, device=device, dtype=dtype)
for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
# Using gen_sizes_matmul(2) to ensure we cover
# the 'NN', 'NT', 'TN', and 'TT' cases.
for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)):
x = make_arg(size_x, noncontiguous=nctg_x)
y = make_arg(size_y, noncontiguous=nctg_y)
self.check_single_matmul(x, y)
@@ -4722,9 +4725,27 @@ class TestLinalg(TestCase):
@dtypes(torch.half)
def test_matmul_offline_tunableop(self, device, dtype):
# Main offline tunableop test
# Tests only the main matmul GEMM API
# NOTE: Offline tuning does not support certain tensor
# shapes, as noted below. Submatrices / matrix slices are
# not supported at all.
import os
def has_any_dim_size_one(tensor: torch.Tensor):
"""Check if any dimension of a PyTorch tensor has size 1."""
return any(dim == 1 for dim in tensor.shape)
def is_mm_compatible(A, B):
"""Check if two matrices A and B are compatible for torch.mm."""
return A.dim() == 2 and B.dim() == 2 and A.shape[1] == B.shape[0]
def is_bmm_compatible(A, B):
"""Check if two 3D tensors are compatible for torch.bmm."""
return (
A.dim() == 3 and B.dim() == 3 and
A.shape[0] == B.shape[0] and # Batch size must match
A.shape[2] == B.shape[1] # Inner dimensions must align
)
with self._tunableop_ctx():
torch.cuda.tunable.set_rotating_buffer_size(0)
@@ -4737,10 +4758,65 @@ class TestLinalg(TestCase):
self.assertTrue(torch.cuda.tunable.record_untuned_is_enabled())
make_arg = partial(make_tensor, device=device, dtype=dtype)
for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
x = make_arg(size_x, noncontiguous=nctg_x)
y = make_arg(size_y, noncontiguous=nctg_y)
self.check_single_matmul(x, y)
# Offline tuning only handles matmuls on two-dimensional tensors.
# Matmuls that require broadcasting are not supported either.
# Below we check the different transA and transB combinations.
for (size_x, size_y) in self.gen_sizes_matmul(x_dim=2, y_dim=2, matrix_size=4):
x = make_arg(size_x, noncontiguous=False)
y = make_arg(size_y, noncontiguous=False)
if is_mm_compatible(x, y):
self.check_single_matmul(x, y)
else:
continue
if is_mm_compatible(x.t(), y):
self.check_single_matmul(x.t(), y)
else:
continue
if is_mm_compatible(x, y.t()):
self.check_single_matmul(x, y.t())
else:
continue
if is_mm_compatible(x.t(), y.t()):
self.check_single_matmul(x.t(), y.t())
else:
continue
# Offline tuning only handles batched matmuls on
# three-dimensional tensors. Matmuls that require
# broadcasting are not supported either.
# Below we check the different transA and transB combinations.
for (size_x, size_y) in self.gen_sizes_matmul(x_dim=3, y_dim=3, matrix_size=4):
x = make_arg(size_x, noncontiguous=False)
y = make_arg(size_y, noncontiguous=False)
if has_any_dim_size_one(x) or has_any_dim_size_one(y):
continue
if is_bmm_compatible(x, y):
self.check_single_matmul(x, y)
else:
continue
if is_bmm_compatible(x.transpose(1, 2), y):
self.check_single_matmul(x.transpose(1, 2), y)
else:
continue
if is_bmm_compatible(x, y.transpose(1, 2)):
self.check_single_matmul(x, y.transpose(1, 2))
else:
continue
if is_bmm_compatible(x.transpose(1, 2), y.transpose(1, 2)):
self.check_single_matmul(x.transpose(1, 2), y.transpose(1, 2))
else:
continue
self.assertTrue(torch.cuda.tunable.is_enabled())
self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False)
@@ -4768,7 +4844,7 @@ class TestLinalg(TestCase):
self.assertGreater(os.path.getsize(result_filename), 0)
# Compare Param Signature of untuned and tuned results
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
self.assertTrue(ok)
@onlyCUDA
@@ -4866,7 +4942,7 @@ class TestLinalg(TestCase):
self.assertGreater(os.path.getsize(result_filename), 0)
# Compare Param Signature of untuned and tuned results
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
self.assertTrue(ok)
@unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
@@ -5282,7 +5358,7 @@ class TestLinalg(TestCase):
self.assertGreater(os.path.getsize(result_filename), 0)
# Compare Param Signature of untuned and tuned results
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
self.assertTrue(ok)
@onlyCUDA
@@ -5498,7 +5574,7 @@ class TestLinalg(TestCase):
self.assertGreater(os.path.getsize(result_filename), 0)
# Compare Param Signature of untuned and tuned results
ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
self.assertTrue(ok)
finally:
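
The second changed file adds warnings for GEMM shapes that offline tuning cannot handle. Here is a hedged sketch of that shape guard, factored into a hypothetical standalone helper for readability (the real check runs inline in `_process_single_offline_gemm`, as the hunks below show):

```python
def offline_tuning_supported(op_sig: str, m: int, n: int, k: int,
                             transA: bool, transB: bool) -> bool:
    """Return False for degenerate GEMM shapes that offline tuning skips."""
    if m == 1 or n == 1 or k == 1:
        if op_sig == "GemmStridedBatchedTunableOp":
            # No degenerate dimensions are supported for strided-batched GEMM.
            return False
        if op_sig == "GemmTunableOp":
            if (not transA) and (not transB):
                return True  # NN case is supported
            if transB and n == 1:
                return True  # transB with n == 1 is supported
            return False
    return True
```

When this returns False, the caller warns that offline tuning is unsupported for that GEMM and skips it rather than tuning, which is what the added code does.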


@@ -516,6 +516,19 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
return
if op_sig == "GemmTunableOp":
# Warnings for unsupported cases:
if m == 1 or n == 1 or k == 1:
if (not transA) and (not transB):
pass # case is supported
elif transB and n == 1:
pass # case is supported
else:
warnings.warn(
"Offline tuning is not supported for this GEMM. Use online tuning instead. "
+ f"Skipped tuning for: {untuned_gemm[1]}"
)
return
matA = (
torch.rand(k, m, dtype=dtype, device=deviceid).t()
if transB
@@ -528,6 +541,14 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
)
torch.mm(matA, matB)
elif op_sig == "GemmStridedBatchedTunableOp":
# Warnings for unsupported cases:
if m == 1 or n == 1 or k == 1:
warnings.warn(
"Offline tuning is not support for this GEMM. Use online tuning instead. "
+ f"Skipped tuning for: {untuned_gemm[1]}"
)
return
[b] = [int(g) for g in untuned_gemm_temp[5:6]]
matA = (
torch.rand(b, k, m, dtype=dtype, device=deviceid)