[ROCm] Fix unit test: matmul_offline_mgpu_gpu_tunableop (#142269)

Fixes #141652

This PR fixes (at least in part) the unit test failure. However, we may also need a separate flush of the untuned results; if this test continues to be flaky, another PR will be needed to flush the untuned results as well.

Tested locally and it seems to be working.

Also fixes code in the unit test that was accidentally commented out in the prior multi-GPU offline tuning PR https://github.com/pytorch/pytorch/pull/139673

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142269
Approved by: https://github.com/jeffdaily
This commit is contained in:
Nichols A. Romero
2024-12-08 02:18:00 +00:00
committed by PyTorch MergeBot
parent b1bb860d3c
commit 2fc8bac091
2 changed files with 11 additions and 3 deletions

View File

@ -4753,9 +4753,9 @@ class TestLinalg(TestCase):
self.assertTrue(os.path.exists(result_filename))
# Check the full results files was written, one per gpu
# for i in range(total_gpus):
# result_full_filename = os.path.join(tmp_dir, f"tunableop_results_full{i}.csv")
# self.assertTrue(os.path.exists(result_full_filename))
for i in range(total_gpus):
result_full_filename = os.path.join(tmp_dir, f"tunableop_results_full{i}.csv")
self.assertTrue(os.path.exists(result_full_filename))
finally:
# disables TunableOp

View File

@ -488,6 +488,7 @@ def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
checks = [] # empty list to hold futures
futures = [] # empty list to hold futures
flush_results = [] # empty list to hold futures
# GEMM are assigned to GPUs in a round robin manner
h = 0
@ -515,6 +516,13 @@ def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
for future in concurrent.futures.as_completed(futures):
future.result()
for g in range(num_gpus):
flush_result = executor.submit(write_file)
flush_results.append(flush_result)
for flush_result in concurrent.futures.as_completed(flush_results):
flush_result.result()
torch.cuda.synchronize()
_gather_tunableop_results()