mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ROCm][TunableOp] Stricter unit tests for online and offline tuning (#150142)
Improvements to unit tests and warnings for unsupported cases in offline tuning. Here are more details:
- Previously we only compared the OpSig for the untuned vs. tuned entries. This was not strict enough, so we now compare OpSig+ParamSig.
- The main offline and online UTs are now stricter to make sure we exercise the code paths for the four combinations of transA and transB.
- Offline tuning does not support some tensor shapes. Emit a warning and skip tuning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150142
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
parent 157bff22f7
commit ca2ffc23ab
committed by PyTorch MergeBot
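For context, every entry in a TunableOp results file starts with an operator signature (OpSig) followed by a parameter signature (ParamSig). The snippet below is a minimal sketch of the stricter comparison; the signatures, solution names, and timings are made up for illustration and are not taken from this PR, but they show why matching only the ParamSig was too weak: the same ParamSig can appear under different OpSigs, so the untuned-vs-tuned check now matches the (OpSig, ParamSig) pair.

# Hypothetical file contents, for illustration only.
untuned_lines = [
    "GemmTunableOp_Half_NN,nn_512_512_512",
    "GemmTunableOp_Half_TN,tn_512_512_512",
]
tuned_lines = [
    # A tuned results file also carries a chosen solution and a time per entry.
    "GemmTunableOp_Half_NN,nn_512_512_512,Gemm_Rocblas_1234,0.012",
    "GemmTunableOp_Half_TN,tn_512_512_512,Default,0.034",
]

untuned = {tuple(line.split(",")[:2]) for line in untuned_lines}
tuned = {tuple(line.split(",")[:2]) for line in tuned_lines}
assert not (untuned - tuned)  # every untuned (OpSig, ParamSig) has a tuned counterpart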
@@ -109,28 +109,29 @@ def find_tunableop_result(results, OpSig, ParamSig):
             return inner_tuple
     return None

-def compare_untuned_tuned_param_sig(untuned_filename, tuned_filename):
-    # Compare Param Signature of untuned and tuned Tunableop results
-    # file. Verify that for each Param Signature in the untuned file
+def compare_untuned_tuned_entries(untuned_filename, tuned_filename):
+    # Compare the entries of untuned and tuned Tunableop results
+    # file. Verify that for each Op+Param Signature in the untuned file
     # there is a matching one in the tuned results file.
     import csv
     ok = False
     with open(untuned_filename) as file1:
         with open(tuned_filename) as file2:
             untuned_reader = csv.reader(file1)
-            untuned_csv_entries = [row[1] for row in untuned_reader]
+            untuned_csv_entries = {(row[0], row[1]) for row in untuned_reader}

             tuned_reader = csv.reader(file2)
             for _ in range(5):  # Skip the first 5 lines for the validator
                 next(tuned_reader, None)

-            result_csv_entries = [row[1] for row in tuned_reader]
+            result_csv_entries = {(row[0], row[1]) for row in tuned_reader}

-            for value in untuned_csv_entries:
-                if value in result_csv_entries:
-                    ok = True
-                else:
-                    ok = False
+            missing = untuned_csv_entries - result_csv_entries
+
+            if missing:
+                ok = False
+            else:
+                ok = True

     return ok

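The five skipped lines referenced above are the Validator header that a tuned results file typically begins with (PyTorch version, ROCm/hipBLASLt/rocBLAS versions, GPU architecture); the untuned file has no such header and holds only signatures. A usage sketch, with hypothetical default-style filenames:

# Hypothetical tuned-file header that the range(5) skip steps over:
#   Validator,PT_VERSION,2.8.0
#   Validator,ROCM_VERSION,6.3.0
#   Validator,HIPBLASLT_VERSION,0.12.0
#   Validator,ROCBLAS_VERSION,4.3.0
#   Validator,GCN_ARCH_NAME,gfx942
ok = compare_untuned_tuned_entries("tunableop_untuned0.csv", "tunableop_results0.csv")
assert ok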
@@ -4676,12 +4677,13 @@ class TestLinalg(TestCase):
         # disable tunableop buffer rotation for all tests everywhere, it can be slow
-        # We set the TunableOp numerical check environment variable here because it is
-        # possible to hit some invalid numerical solutions due to the small matrix sizes.
         import os

         with self._tunableop_ctx():
             torch.cuda.tunable.set_rotating_buffer_size(0)
-            if dtype is torch.half:
-                os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
+            # Numerical check adds significant overhead, unsure if this is needed
+            # or if there was a transient problem at the time.
+            # if dtype is torch.half:
+            #     os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
             ordinal = torch.cuda.current_device()

             # set these to single iterations to keep it short but still exercise the code
@@ -4689,8 +4691,9 @@ class TestLinalg(TestCase):
             torch.cuda.tunable.set_max_tuning_iterations(1)

             make_arg = partial(make_tensor, device=device, dtype=dtype)

-            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
+            # Using gen_sizes_matmul(2) to ensure we cover
+            # 'NN', 'TN', 'NT', and 'TT' cases.
+            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)):
                 x = make_arg(size_x, noncontiguous=nctg_x)
                 y = make_arg(size_y, noncontiguous=nctg_y)
                 self.check_single_matmul(x, y)
@@ -4722,9 +4725,27 @@ class TestLinalg(TestCase):
     @dtypes(torch.half)
     def test_matmul_offline_tunableop(self, device, dtype):
         # Main offline tunableop test
-        # Tests only the main matmul GEMM API
+        # NOTE: The offline tuning does not support certain tensor
+        # shapes as noted below. Submatrices / matrix slices are
+        # not supported at all.
         import os

+        def has_any_dim_size_one(tensor: torch.Tensor):
+            """Check if any dimension of a PyTorch tensor has size 1."""
+            return any(dim == 1 for dim in tensor.shape)
+
+        def is_mm_compatible(A, B):
+            """Check if two matrices A and B are compatible for torch.mm."""
+            return A.dim() == 2 and B.dim() == 2 and A.shape[1] == B.shape[0]
+
+        def is_bmm_compatible(A, B):
+            """Check if two 3D tensors are compatible for torch.bmm."""
+            return (
+                A.dim() == 3 and B.dim() == 3 and
+                A.shape[0] == B.shape[0] and  # Batch size must match
+                A.shape[2] == B.shape[1]  # Inner dimensions must align
+            )
+
         with self._tunableop_ctx():
             torch.cuda.tunable.set_rotating_buffer_size(0)

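A quick illustration of what the three helpers accept and reject, assuming they are in scope; the shapes are arbitrary examples rather than the sizes generated by the test:

A = torch.randn(8, 4)
B = torch.randn(4, 16)
assert is_mm_compatible(A, B)           # (8, 4) @ (4, 16) is a valid torch.mm
assert not is_mm_compatible(B, A.t())   # inner dimensions 16 and 4 do not match

X = torch.randn(2, 8, 4)
Y = torch.randn(2, 4, 16)
assert is_bmm_compatible(X, Y)                     # batch and inner dimensions line up
assert has_any_dim_size_one(torch.randn(2, 1, 4))  # degenerate dim -> skipped by the test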
@@ -4737,10 +4758,65 @@ class TestLinalg(TestCase):
             self.assertTrue(torch.cuda.tunable.record_untuned_is_enabled())

             make_arg = partial(make_tensor, device=device, dtype=dtype)
-            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
-                x = make_arg(size_x, noncontiguous=nctg_x)
-                y = make_arg(size_y, noncontiguous=nctg_y)
-                self.check_single_matmul(x, y)
+            # offline tuning only handles matmuls on two dimensional tensors
+            # matmuls that require broadcasting are
+            # not supported either.
+            # Below we check the different transA and transB combinations.
+            for (size_x, size_y) in self.gen_sizes_matmul(x_dim=2, y_dim=2, matrix_size=4):
+                x = make_arg(size_x, noncontiguous=False)
+                y = make_arg(size_y, noncontiguous=False)
+
+                if is_mm_compatible(x, y):
+                    self.check_single_matmul(x, y)
+                else:
+                    continue
+
+                if is_mm_compatible(x.t(), y):
+                    self.check_single_matmul(x.t(), y)
+                else:
+                    continue
+
+                if is_mm_compatible(x, y.t()):
+                    self.check_single_matmul(x, y.t())
+                else:
+                    continue
+
+                if is_mm_compatible(x.t(), y.t()):
+                    self.check_single_matmul(x.t(), y.t())
+                else:
+                    continue
+
+            # offline tuning only handles batched matmuls on
+            # three dimensional tensors
+            # matmuls that require broadcasting are
+            # not supported either.
+            # Below we check the different transA and transB combinations.
+            for (size_x, size_y) in self.gen_sizes_matmul(x_dim=3, y_dim=3, matrix_size=4):
+                x = make_arg(size_x, noncontiguous=False)
+                y = make_arg(size_y, noncontiguous=False)
+
+                if has_any_dim_size_one(x) or has_any_dim_size_one(y):
+                    continue
+
+                if is_bmm_compatible(x, y):
+                    self.check_single_matmul(x, y)
+                else:
+                    continue
+
+                if is_bmm_compatible(x.transpose(1, 2), y):
+                    self.check_single_matmul(x.transpose(1, 2), y)
+                else:
+                    continue
+
+                if is_bmm_compatible(x, y.transpose(1, 2)):
+                    self.check_single_matmul(x, y.transpose(1, 2))
+                else:
+                    continue
+
+                if is_bmm_compatible(x.transpose(1, 2), y.transpose(1, 2)):
+                    self.check_single_matmul(x.transpose(1, 2), y.transpose(1, 2))
+                else:
+                    continue

             self.assertTrue(torch.cuda.tunable.is_enabled())
             self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False)
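Outside the test harness, the offline-tuning flow this test drives looks roughly like the sketch below: record untuned GEMMs while tuning is disabled, then tune the recorded entries in a separate pass. The filenames, sizes, and exact call sequence are assumptions based on the torch.cuda.tunable API rather than something prescribed by this PR.

import torch

# Step 1: record untuned GEMMs instead of tuning them inline.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(False)
torch.cuda.tunable.record_untuned_enable(True)

a = torch.randn(512, 512, dtype=torch.half, device="cuda")
b = torch.randn(512, 512, dtype=torch.half, device="cuda")
torch.mm(a, b)      # appends an (OpSig, ParamSig) entry to the untuned file
torch.mm(a.t(), b)  # a transA variant records a separate entry

# Step 2: tune every recorded GEMM offline and write out the results file.
torch.cuda.tunable.tuning_enable(True)
torch.cuda.tunable.tune_gemm_in_file("tunableop_untuned0.csv")  # hypothetical filename
torch.cuda.tunable.write_file()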
@@ -4768,7 +4844,7 @@ class TestLinalg(TestCase):
             self.assertGreater(os.path.getsize(result_filename), 0)

             # Compare Param Signature of untuned and tuned results
-            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
             self.assertTrue(ok)

     @onlyCUDA
@@ -4866,7 +4942,7 @@ class TestLinalg(TestCase):
             self.assertGreater(os.path.getsize(result_filename), 0)

             # Compare Param Signature of untuned and tuned results
-            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
             self.assertTrue(ok)

     @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
@@ -5282,7 +5358,7 @@ class TestLinalg(TestCase):
             self.assertGreater(os.path.getsize(result_filename), 0)

             # Compare Param Signature of untuned and tuned results
-            ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+            ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
             self.assertTrue(ok)

     @onlyCUDA
@@ -5498,7 +5574,7 @@ class TestLinalg(TestCase):
                 self.assertGreater(os.path.getsize(result_filename), 0)

                 # Compare Param Signature of untuned and tuned results
-                ok = compare_untuned_tuned_param_sig(untuned_filename, result_filename)
+                ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
                 self.assertTrue(ok)

             finally:

@@ -516,6 +516,19 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
         return

     if op_sig == "GemmTunableOp":
+        # Warnings for unsupported cases:
+        if m == 1 or n == 1 or k == 1:
+            if (not transA) and (not transB):
+                pass  # case is supported
+            elif transB and n == 1:
+                pass  # case is supported
+            else:
+                warnings.warn(
+                    "Offline tuning is not supported for this GEMM. Use online tuning instead. "
+                    + f"Skipped tuning for: {untuned_gemm[1]}"
+                )
+                return
+
         matA = (
             torch.rand(k, m, dtype=dtype, device=deviceid).t()
             if transB
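In words, the nested check above says: a degenerate GEMM (any of m, n, k equal to 1) is still tuned offline only in the plain NN layout, or when transB is set and n == 1; every other degenerate case is skipped with a warning and left to online tuning. A small stand-alone mirror of that rule, for illustration only:

def offline_gemm_tunable(m: int, n: int, k: int, transA: bool, transB: bool) -> bool:
    """Return True when the GemmTunableOp branch above would proceed with offline tuning."""
    if m == 1 or n == 1 or k == 1:
        if (not transA) and (not transB):
            return True   # NN case is supported
        if transB and n == 1:
            return True   # transB with n == 1 is supported
        return False      # skipped with a warning; use online tuning instead
    return True

assert offline_gemm_tunable(1, 64, 64, transA=False, transB=False)
assert not offline_gemm_tunable(1, 64, 64, transA=True, transB=False)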
@@ -528,6 +541,14 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
         )
         torch.mm(matA, matB)
     elif op_sig == "GemmStridedBatchedTunableOp":
+        # Warnings for unsupported cases:
+        if m == 1 or n == 1 or k == 1:
+            warnings.warn(
+                "Offline tuning is not supported for this GEMM. Use online tuning instead. "
+                + f"Skipped tuning for: {untuned_gemm[1]}"
+            )
+            return
+
         [b] = [int(g) for g in untuned_gemm_temp[5:6]]
         matA = (
             torch.rand(b, k, m, dtype=dtype, device=deviceid)