Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[ROCm][TunableOp] Fix UT race condition and reduce UT duration. (#150463)
This PR fixes two race conditions that occur when the UTs are run:
- in a particular order within a single shard, or
- concurrently in multiple shards.

Each test now gets a unique filename that depends on the test name.

There were two other minor improvements to the UTs:
- matmul_offline_mgpu could occasionally fail when run on 8 GPUs; the pass criterion was relaxed.
- bmm_tunableop_rocm now checks that the rotating buffer is not zero; otherwise the test is not useful.

Additionally, several UTs took over one minute to run. Their duration was reduced by a combination of setting max tuning iterations to one, setting the rotating buffer size to zero, and/or reducing the matrix dimensions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150463
Approved by: https://github.com/jeffdaily
committed by: PyTorch MergeBot
parent: 1bc2b2b12a
commit: d0026fa138
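The gist of the race-condition fix, as a minimal sketch (illustrative only, not code from this commit): every TunableOp CSV filename is derived from the unittest id, so tests running in any order, or concurrently in separate shards, never write to the same file. The class and helper names below are hypothetical; in the commit itself this logic lives in TestLinalg._set_tunableop_defaults and get_tunableop_untuned_filename, and the device ordinal comes from torch.cuda.current_device().

# Illustrative sketch of the per-test filename scheme (hypothetical names,
# no GPU required): the last component of TestCase.id() is unique per test,
# so concurrent shards get disjoint TunableOp CSV files.
import unittest


class TunableOpFilenameDemo(unittest.TestCase):
    def _unique_tunableop_filenames(self, ordinal=0):
        # self.id() looks like "module.Class.test_name"
        unique_id = self.id().split(".")[-1]
        results_csv = f"tunableop_results_{unique_id}_{ordinal}.csv"
        # TunableOp appends the device ordinal to the untuned filename itself,
        # hence only the base name carries the unique id here.
        untuned_csv = f"tunableop_untuned_{unique_id}_.csv"
        return results_csv, untuned_csv

    def test_filenames_are_per_test(self):
        results_csv, untuned_csv = self._unique_tunableop_filenames()
        self.assertIn("test_filenames_are_per_test", results_csv)
        self.assertIn("test_filenames_are_per_test", untuned_csv)


if __name__ == "__main__":
    unittest.main()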
@@ -65,22 +65,7 @@ def blaslt_supported_device():
             return True
     return False
 
-def set_tunableop_defaults():
-    if not torch.cuda.is_available():
-        # TunableOp not supported on CPU at this time.
-        return
-
-    # disable TunableOp and restore to default values
-    torch.cuda.tunable.enable(False)
-    torch.cuda.tunable.record_untuned_enable(False)
-    torch.cuda.tunable.tuning_enable(True)
-    torch.cuda.tunable.set_max_tuning_duration(30)
-    torch.cuda.tunable.set_max_tuning_iterations(100)
-    torch.cuda.tunable.set_rotating_buffer_size(-1)
-    ordinal = torch.cuda.current_device()
-    torch.cuda.tunable.set_filename(f"tunableop_results{ordinal}.csv")
-
-def tunableop_matmul(device, dtype, offline=False):
+def tunableop_matmul(device, dtype, result_filename=None, offline=False):
     # Helper function to test TunableOp in a subprocess
     # requires helper function since lambda function
     # not supported by multiprocessing module
@@ -90,6 +75,9 @@ def tunableop_matmul(device, dtype, offline=False):
     if offline:
         torch.cuda.tunable.tuning_enable(False)
         torch.cuda.tunable.record_untuned_enable(True)
+    else:
+        if result_filename is not None:
+            torch.cuda.tunable.set_filename(result_filename)
 
     torch.cuda.tunable.set_max_tuning_duration(1)
     A = torch.randn((17, 17), device=device, dtype=dtype)
@@ -109,31 +97,13 @@ def find_tunableop_result(results, OpSig, ParamSig):
             return inner_tuple
     return None
 
-def compare_untuned_tuned_entries(untuned_filename, tuned_filename):
-    # Compare the entries of untuned and tuned Tunableop results
-    # file. Verify that for each Op+Param Signature in the untuned file
-    # there is a matching one in the tuned results file.
-    import csv
-    ok = False
-    with open(untuned_filename) as file1:
-        with open(tuned_filename) as file2:
-            untuned_reader = csv.reader(file1)
-            untuned_csv_entries = {(row[0], row[1]) for row in untuned_reader}
-
-            tuned_reader = csv.reader(file2)
-            for _ in range(5):  # Skip the first 5 lines for the validator
-                next(tuned_reader, None)
-
-            result_csv_entries = {(row[0], row[1]) for row in tuned_reader}
-
-            missing = untuned_csv_entries - result_csv_entries
-
-            if missing:
-                ok = False
-            else:
-                ok = True
-
-    return ok
+def get_tunableop_untuned_filename():
+    import os
+    ordinal = torch.cuda.current_device()
+    untuned_filename_env = os.getenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME")
+    untuned_filename_base, _, _ = untuned_filename_env.rpartition('.')
+    untuned_filename = f"{untuned_filename_base}{ordinal}.csv"
+    return untuned_filename
 
 class TestLinalg(TestCase):
     @contextlib.contextmanager
@@ -165,7 +135,7 @@ class TestLinalg(TestCase):
         # Inialize and then tear down TunableOp
         import glob
        import os
-        set_tunableop_defaults()
+        self._set_tunableop_defaults()
         torch.cuda.tunable.enable(True)
 
         try:
@@ -175,7 +145,13 @@
             torch.cuda.tunable.enable(False)
 
             # clean up, remove any files that were generated
-            for file in glob.glob("tunableop*.csv"):
+            results_filename = torch.cuda.tunable.get_filename()
+            results_filename_pattern, _, _ = results_filename.rpartition('.')
+            untuned_filename = get_tunableop_untuned_filename()
+            untuned_filename_pattern, _, _ = untuned_filename.rpartition('.')
+            patterns = [f"{results_filename_pattern[:-1]}*.csv", f"{untuned_filename_pattern[:-1]}*.csv"]
+            files = [f for pattern in patterns for f in glob.glob(pattern)]
+            for file in files:
                 try:
                     os.remove(file)
                 # NB: The file is locked on Windows
@@ -194,6 +170,59 @@ class TestLinalg(TestCase):
            except KeyError:
                pass
 
+    def _set_tunableop_defaults(self):
+        if not torch.cuda.is_available():
+            # TunableOp not supported on CPU at this time.
+            return
+
+        # disable TunableOp and restore to default values
+        torch.cuda.tunable.enable(False)
+        torch.cuda.tunable.record_untuned_enable(False)
+        torch.cuda.tunable.tuning_enable(True)
+        torch.cuda.tunable.set_max_tuning_duration(30)
+        torch.cuda.tunable.set_max_tuning_iterations(100)
+        torch.cuda.tunable.set_rotating_buffer_size(-1)
+        ordinal = torch.cuda.current_device()
+
+        # Set filenames to be unique on a per test basis
+        import os
+        unique_id = self.id().split(".")[-1]
+        torch.cuda.tunable.set_filename(f"tunableop_results_{unique_id}_{ordinal}.csv")
+        # ordinal gets automatically appended
+        os.environ["PYTORCH_TUNABLEOP_UNTUNED_FILENAME"] = f"tunableop_untuned_{unique_id}_.csv"
+
+    def _compare_untuned_tuned_entries(self, untuned_filename=None, tuned_filename=None):
+        # Compare the entries of untuned and tuned Tunableop results
+        # file. Verify that for each Op+Param Signature in the untuned file
+        # there is a matching one in the tuned results file.
+        import csv
+        ok = False
+        ordinal = torch.cuda.current_device()
+        if untuned_filename is None:
+            untuned_filename = get_tunableop_untuned_filename()
+        if tuned_filename is None:
+            tuned_filename = torch.cuda.tunable.get_filename()
+
+        with open(untuned_filename) as file1:
+            with open(tuned_filename) as file2:
+                untuned_reader = csv.reader(file1)
+                untuned_csv_entries = {(row[0], row[1]) for row in untuned_reader}
+
+                tuned_reader = csv.reader(file2)
+                for _ in range(5):  # Skip the first 5 lines for the validator
+                    next(tuned_reader, None)
+
+                result_csv_entries = {(row[0], row[1]) for row in tuned_reader}
+
+                missing = untuned_csv_entries - result_csv_entries
+
+                if missing:
+                    ok = False
+                else:
+                    ok = True
+
+        return ok
+
     exact_dtype = True
 
     @dtypes(torch.float, torch.cfloat)
@@ -4693,16 +4722,18 @@ class TestLinalg(TestCase):
             make_arg = partial(make_tensor, device=device, dtype=dtype)
             # Using gen_sizes_matmul(2) to ensure we cover
             # 'NN', 'TN', 'TT', and 'NN' cases.
-            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)):
+            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2, y_dim=3),
+                                                            (True, False), (True, False)):
                 x = make_arg(size_x, noncontiguous=nctg_x)
                 y = make_arg(size_y, noncontiguous=nctg_y)
                 self.check_single_matmul(x, y)
 
             filename1 = torch.cuda.tunable.get_filename()
-            filename2 = "tunableop_results_tmp1.csv"
-            filename3 = "tunableop_results_tmp2.csv"
+            unique_id = self.id().split(".")[-1]
+            filename2 = f"{filename1}_tmp1.csv"
+            filename3 = f"{filename1}_tmp2.csv"
             ordinal = torch.cuda.current_device()
-            assert filename1 == f"tunableop_results{ordinal}.csv"
+            assert filename1 == f"tunableop_results_{unique_id}_{ordinal}.csv"
             assert len(torch.cuda.tunable.get_results()) > 0
 
             assert torch.cuda.tunable.write_file()  # use default filename
@@ -4720,6 +4751,10 @@
             assert file1_contents == file2_contents
             assert file1_contents == file3_contents
 
+            # We need to reset the filename to the default value so we can properly
+            # clean up intermediate files
+            self._set_tunableop_defaults()
+
     @onlyCUDA
     @skipCUDAIfNotRocm
     @dtypes(torch.half)
@@ -4728,7 +4763,6 @@
         # NOTE: The offline tuning does not support certain tensor
         # shapes as noted below. Submatrics / matrix slices are
        # not supported at all.
-        import os
 
         def has_any_dim_size_one(tensor: torch.Tensor):
             """Check if any dimension of a PyTorch tensor has size 1."""
@@ -4750,7 +4784,6 @@
             torch.cuda.tunable.set_rotating_buffer_size(0)
 
             ordinal = torch.cuda.current_device()
-            result_filename = f"tunableop_results{ordinal}.csv"
 
             # record GEMM
             torch.cuda.tunable.tuning_enable(False)
@@ -4821,8 +4854,7 @@
             self.assertTrue(torch.cuda.tunable.is_enabled())
             self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False)
 
-            untuned_filename = f"tunableop_untuned{ordinal}.csv"
-            self.assertTrue(os.path.exists(untuned_filename))
+            untuned_filename = get_tunableop_untuned_filename()
 
             # tuning the untuned GEMMs in file
             torch.cuda.tunable.tuning_enable(True)
@@ -4839,12 +4871,8 @@
             self.assertGreater(new_results - ref_results, 0)
             self.assertTrue(torch.cuda.tunable.write_file())
 
-            # Make sure the results file exists and that it is not zero
-            self.assertTrue(os.path.exists(result_filename))
-            self.assertGreater(os.path.getsize(result_filename), 0)
-
             # Compare Param Signature of untuned and tuned results
-            ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
+            ok = self._compare_untuned_tuned_entries()
             self.assertTrue(ok)
 
     @onlyCUDA
@@ -4853,14 +4881,11 @@
     @dtypes(torch.torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
     def test_scaled_gemm_offline_tunableop(self, device, dtype):
         # This test is the offline version of test_scaled_gemm_tunableop
-        import os
 
         with self._tunableop_ctx():
             ordinal = torch.cuda.current_device()
             torch.cuda.tunable.set_rotating_buffer_size(0)
 
-            result_filename = f"tunableop_results{ordinal}.csv"
-
             # record GEMM
             torch.cuda.tunable.tuning_enable(False)
             torch.cuda.tunable.record_untuned_enable(True)
@@ -4910,8 +4935,7 @@
             self.assertTrue(torch.cuda.tunable.is_enabled())
             self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False)
 
-            untuned_filename = f"tunableop_untuned{ordinal}.csv"
-            self.assertTrue(os.path.exists(untuned_filename))
+            untuned_filename = get_tunableop_untuned_filename()
 
             # tuning the untuned GEMMs in file
             torch.cuda.tunable.tuning_enable(True)
@@ -4937,12 +4961,8 @@
 
             self.assertTrue(torch.cuda.tunable.write_file())
 
-            # Make sure the results file exists and that it is not zero
-            self.assertTrue(os.path.exists(result_filename))
-            self.assertGreater(os.path.getsize(result_filename), 0)
-
             # Compare Param Signature of untuned and tuned results
-            ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
+            ok = self._compare_untuned_tuned_entries()
             self.assertTrue(ok)
 
     @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
@@ -4960,7 +4980,11 @@
             total_gpus = torch.cuda.device_count()
 
             ordinal = torch.cuda.current_device()
-            untuned_filename = f"tunableop_untuned{ordinal}.csv"
+
+            # Untuned filename has unique id, but results file
+            # does not because it is executed in a subprocess
+            untuned_filename = get_tunableop_untuned_filename()
+            torch.cuda.tunable.set_filename(f"tunableop_results{ordinal}.csv")
 
             # turn on untuned GEMM recording and turn off tuning
             torch.cuda.tunable.tuning_enable(False)
@@ -4985,19 +5009,14 @@
             torch.cuda.tunable.mgpu_tune_gemm_in_file(untuned_filename, total_gpus)
 
             # check the results files where written, one per gpu
-            # get the size of the first result and make sure it
-            # greater than 100. Since the validator text should
-            # be at least that much.
-            # The other results file will have
-            # at least the size of the first results file - 80
+            # Check that the results file is not empty and store
+            # that in a local variable for the next loop.
             for i in range(total_gpus):
                 result_filename = f"tunableop_results{i}.csv"
                 self.assertTrue(os.path.exists(result_filename))
+                self.assertGreater(os.path.getsize(result_filename), 0)
                 if i == 0:  # Store for next loop
                     result_size = os.path.getsize(result_filename)
-                self.assertGreater(os.path.getsize(result_filename), 0)
-                self.assertGreater(os.path.getsize(result_filename), result_size - 80)
-
 
             # Check the full results files was written, one per gpu
             # check that the size of the full results file for
@@ -5018,6 +5037,7 @@
     def test_rotating_buffer_tunableop(self, device, dtype):
         # Test the TunableOp rotating buffer API
         # Test the default value, will return the l2_cache_size
+        self._set_tunableop_defaults()
         l2_cache_size = torch.cuda.tunable.get_rotating_buffer_size()
         self.assertGreater(l2_cache_size, 0)
         # Test zero
@@ -5038,6 +5058,9 @@
         # buffer rotation (on by default) with strided batched gemm tunableop was causing a mem fault
         with self._tunableop_ctx():
             torch.cuda.tunable.set_max_tuning_iterations(10)
+            # Make sure the rotating buffer is not zero, otherwise this test does nothing useful.
+            rotating_buffer = torch.cuda.tunable.get_rotating_buffer_size()
+            self.assertGreater(rotating_buffer, 0)
             # the following 3 cases cover all previous failure cases and are here to catch regressions
             B = 16
             N = M = K = 256
@@ -5082,21 +5105,21 @@
 
     @onlyCUDA
     @skipCUDAIfNotRocm
-    @dtypes(torch.float)
+    @dtypes(torch.bfloat16)
     def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
         import os
         from torch.testing._internal.common_utils import CudaMemoryLeakCheck
         # run operator first without tuning to ensure all rocm libs are loaded,
         # otherwise false positive mem leak
-        B = 16
-        N = M = K = 256
-        dtype = torch.bfloat16
+        B = 5
+        N = M = K = 29
         device = torch.device("cuda:0")
         i1 = torch.randn((B, N, M), device=device, dtype=dtype)
         i2 = torch.randn((B, M, K), device=device, dtype=dtype)
         out = torch.bmm(i1, i2)
 
         with self._tunableop_ctx():
+            torch.cuda.tunable.set_rotating_buffer_size(0)
             # enable tunableop numeric check via env variable.
             os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
 
@@ -5213,9 +5236,9 @@
             ref_num_results = len(torch.cuda.tunable.get_results())
 
             # Tune one GEMMs to make sure TunableOp is enabled
-            M = 3
-            N = 3
-            K = 3
+            M = 11
+            N = 13
+            K = 17
             A = torch.randn(N, K, device=device, dtype=dtype)
             B = torch.randn(K, M, device=device, dtype=dtype)
             C = torch.matmul(A, B)
@@ -5234,9 +5257,9 @@
             torch.cuda.tunable.tuning_enable(False)
 
             # Try to tune one more GEMM
-            M = 3
-            N = 3
-            K = 4
+            M = 11
+            N = 13
+            K = 18
             A = torch.randn(N, K, device=device, dtype=dtype)
             B = torch.randn(K, M, device=device, dtype=dtype)
             C = torch.matmul(A, B)
@@ -5257,8 +5280,7 @@
         import multiprocessing as mp
 
         with self._tunableop_ctx():
-            ordinal = torch.cuda.current_device()
-            filename = f"tunableop_results{ordinal}.csv"
+            filename = torch.cuda.tunable.get_filename()
 
             # force=True needed according to:
             # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method
@@ -5266,7 +5288,7 @@
             # already set the start method
             mp.set_start_method("spawn", force=True)
 
-            p = mp.Process(target=tunableop_matmul, args=(device, dtype))
+            p = mp.Process(target=tunableop_matmul, args=(device, dtype, filename, False))
             p.start()
             p.join()
 
@@ -5305,14 +5327,11 @@
     @dtypes(torch.bfloat16)
     def test_gemm_bias_offline_tunableop(self, device, dtype):
         # This test is the offline version of test_gemm_bias_tunableop
-        import os
-        ordinal = torch.cuda.current_device()
 
         with self._tunableop_ctx():
             torch.cuda.tunable.set_rotating_buffer_size(0)
 
-            result_filename = f"tunableop_results{ordinal}.csv"
 
             # record GEMM
             torch.cuda.tunable.tuning_enable(False)
             torch.cuda.tunable.record_untuned_enable(True)
@@ -5330,8 +5349,7 @@
             self.assertTrue(torch.cuda.tunable.is_enabled())
             self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False)
 
-            untuned_filename = f"tunableop_untuned{ordinal}.csv"
-            self.assertTrue(os.path.exists(untuned_filename))
+            untuned_filename = get_tunableop_untuned_filename()
 
             # tuning the untuned GEMMs in file
             torch.cuda.tunable.tuning_enable(True)
@@ -5353,12 +5371,8 @@
 
             self.assertTrue(torch.cuda.tunable.write_file())
 
-            # Make sure the results file exists and that it is not zero
-            self.assertTrue(os.path.exists(result_filename))
-            self.assertGreater(os.path.getsize(result_filename), 0)
-
             # Compare Param Signature of untuned and tuned results
-            ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
+            ok = self._compare_untuned_tuned_entries()
             self.assertTrue(ok)
 
     @onlyCUDA
@@ -5378,6 +5392,7 @@
         # tested by PyTorch
         with self._tunableop_ctx():
             # set these to single iterations to keep it short but still exercise the code
+            torch.cuda.tunable.set_rotating_buffer_size(0)
             torch.cuda.tunable.set_max_tuning_iterations(1)
 
             # Reference number of results
@@ -5386,9 +5401,9 @@
             # Scaled GEMM parameters
             fillA = 0.25
             fillB = 0.75
-            n = 32
-            m = 64
-            k = 128
+            n = 64
+            m = 16
+            k = 32
             scaleA = torch.tensor(0.8, device=device)
             scaleB = torch.tensor(0.9, device=device)
 
@@ -5519,8 +5534,6 @@
             ordinal = torch.cuda.current_device()
             torch.cuda.tunable.set_rotating_buffer_size(0)
 
-            result_filename = f"tunableop_results{ordinal}.csv"
-
             # record GEMM
             torch.cuda.tunable.tuning_enable(False)
             torch.cuda.tunable.record_untuned_enable(True)
@@ -5535,7 +5548,7 @@
             torch.backends.cuda.matmul.allow_tf32 = False
             C = torch.matmul(A, B)
 
-            untuned_filename = f"tunableop_untuned{ordinal}.csv"
+            untuned_filename = get_tunableop_untuned_filename()
             self.assertTrue(os.path.exists(untuned_filename))
 
             # tuning the untuned GEMMs in file
@@ -5569,12 +5582,8 @@
 
             self.assertTrue(torch.cuda.tunable.write_file())
 
-            # Make sure the results file exists and that it is not zero
-            self.assertTrue(os.path.exists(result_filename))
-            self.assertGreater(os.path.getsize(result_filename), 0)
-
             # Compare Param Signature of untuned and tuned results
-            ok = compare_untuned_tuned_entries(untuned_filename, result_filename)
+            ok = self._compare_untuned_tuned_entries()
             self.assertTrue(ok)
 
         finally:
@@ -5606,10 +5615,11 @@
         with self._tunableop_ctx():
             os.putenv("PYTORCH_TUNABLEOP_BLAS_LOG", "1")
 
-            ordinal = torch.cuda.current_device()
-
-            result_filename = f"tunableop_results{ordinal}.csv"
-            untuned_filename = f"tunableop_untuned{ordinal}.csv"
+            # TunableOp is running in a subprocess
+            # online tuning needs filename set through API
+            # offline tuning needs filename set through environment variableq
+            result_filename = torch.cuda.tunable.get_filename()
+            untuned_filename = get_tunableop_untuned_filename()
 
             # Offline Tuning case in a subprocess
 
@@ -5619,7 +5629,7 @@
             # already set the start method
             mp.set_start_method("spawn", force=True)
 
-            p = mp.Process(target=tunableop_matmul, args=(device, dtype, True))
+            p = mp.Process(target=tunableop_matmul, args=(device, dtype, None, True))
             p.start()
             p.join()
 
@@ -5646,7 +5656,7 @@
             # already set the start method
             mp.set_start_method("spawn", force=True)
 
-            p = mp.Process(target=tunableop_matmul, args=(device, dtype, False))
+            p = mp.Process(target=tunableop_matmul, args=(device, dtype, result_filename, False))
             p.start()
             p.join()
 
@@ -6868,7 +6878,8 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
     @bf32_on_and_off(0.05)
     def test_addmm_relu_tunableop_rocm(self, device, dtype):
         with self._tunableop_ctx():
-            torch.cuda.tunable.set_max_tuning_iterations(10)
+            torch.cuda.tunable.set_rotating_buffer_size(0)
+            torch.cuda.tunable.set_max_tuning_iterations(1)
             self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
 
 