From 7f9b74549485bf48a9e5d68dc4cbcb96b01a33dd Mon Sep 17 00:00:00 2001
From: Sarthak Tandon
Date: Wed, 15 Oct 2025 20:02:27 +0000
Subject: [PATCH] [ROCm][tunableop] Modified Online Tuning Mode to add Instant
 Logging (#163965)

- Added instant logging to online tuning mode, so that each tuned GEMM is
  written to the results file as soon as it is tuned
- Preserves already-tuned configurations if the application crashes before
  exiting cleanly

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163965
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
---
 aten/src/ATen/cuda/tunable/README.md   |   2 -
 aten/src/ATen/cuda/tunable/Tunable.cpp | 143 ++++++++++++++++---------
 aten/src/ATen/cuda/tunable/Tunable.h   |  18 +++-
 docs/source/cuda.tunable.md            |   8 --
 test/test_linalg.py                    | 123 +++++++++++++++++----
 torch/_C/__init__.pyi.in               |   2 -
 torch/csrc/cuda/Module.cpp             |  48 ---------
 torch/cuda/tunable.py                  |  29 -----
 8 files changed, 209 insertions(+), 164 deletions(-)

diff --git a/aten/src/ATen/cuda/tunable/README.md b/aten/src/ATen/cuda/tunable/README.md
index b30040b7e284..4816886ecc86 100644
--- a/aten/src/ATen/cuda/tunable/README.md
+++ b/aten/src/ATen/cuda/tunable/README.md
@@ -175,8 +175,6 @@ All python APIs exist in the `torch.cuda.tunable` module.
 | get_filename() -> str | |
 | get_results() -> Tuple[str, str, str, float] | |
 | get_validators() -> Tuple[str, str] | |
-| write_file_on_exit(val: bool) -> None | Default is True. |
-| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
 | read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
 | tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
 | mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
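
With instant logging, an online tuning session no longer needs an explicit
write step. A minimal sketch of the new workflow (the filename below is an
arbitrary example, not a default):

    import torch

    torch.cuda.tunable.enable(True)
    torch.cuda.tunable.set_filename("tunableop_results_example.csv")

    a = torch.randn(128, 64, device="cuda", dtype=torch.half)
    b = torch.randn(64, 32, device="cuda", dtype=torch.half)
    c = a @ b  # tuned here; the winning config is appended to the CSV immediately

    # No write_file() call is needed: the result line is already on disk,
    # even if the process crashes after this point.
    assert len(torch.cuda.tunable.get_results()) > 0
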
diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp
index 6b19a738ec4a..c4d5fa261fc2 100644
--- a/aten/src/ATen/cuda/tunable/Tunable.cpp
+++ b/aten/src/ATen/cuda/tunable/Tunable.cpp
@@ -107,14 +107,30 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
 }
 
 void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
-  std::scoped_lock l{lock_};
+  bool is_new = false;
+  ResultEntry inserted = ResultEntry::Null();
 
-  auto it = results_.find(op_signature);
-  if (it == results_.end()) {
-    it = results_.insert({op_signature, {}}).first;
+  // ---- mutate maps under results lock ----
+  {
+    std::scoped_lock l{lock_};
+    auto& km = results_[op_signature];  // creates the inner map if missing
+    is_new = (km.find(params_signature) == km.end());
+    AddImpl(op_signature, params_signature, std::move(best), km);
+    if (is_new) {
+      inserted = km.at(params_signature);  // snapshot for I/O after unlocking
+    }
+  }
+
+  if (!is_new) return;  // only write once per unique (op, params)
+
+  TuningContext* ctx = getTuningContext();
+  if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
+    InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());
+
+    if (realtime_out_ && realtime_out_->good()) {
+      AppendResultLine(op_signature, params_signature, inserted);
+    }
   }
-  AddImpl(op_signature, params_signature, std::move(best), it->second);
 }
 
 void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
@@ -150,6 +166,77 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
   }
 }
 
+void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
+  std::scoped_lock fl{realtime_file_mutex_};
+
+  if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
+    return;
+  }
+
+  if (realtime_out_ && realtime_filename_ != filename) {
+    realtime_out_->flush();
+    realtime_out_->close();
+    realtime_out_.reset();
+    validators_written_ = false;
+  }
+
+  bool file_exists = false;
+  bool file_empty = true;
+
+  {
+    std::ifstream check_file(filename);
+    if (check_file.good()) {
+      file_exists = true;
+      file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
+    }
+  }
+
+  realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
+
+  if (!realtime_out_->good()) {
+    TORCH_WARN("TunableOp realtime append: failed to open '", filename, "'");
+    realtime_out_.reset();
+    return;
+  }
+
+  if (!file_exists || file_empty) {
+    for (const auto& [key, val] : validators) {
+      (*realtime_out_) << "Validator," << key << "," << val << std::endl;
+      realtime_out_->flush();
+    }
+    validators_written_ = true;
+
+    TUNABLE_LOG2("Wrote validators to realtime output file");
+  }
+
+  realtime_filename_ = filename;
+}
+
+void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
+  std::scoped_lock fl{realtime_file_mutex_};
+
+  if (!realtime_out_ || !realtime_out_->good()) {
+    return;
+  }
+
+  (*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
+  realtime_out_->flush();  // ensure immediate write to disk
+
+  TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
+}
+
+void TuningResultsManager::CloseRealtimeAppend() {
+  std::scoped_lock fl{realtime_file_mutex_};
+
+  if (realtime_out_) {
+    realtime_out_->flush();
+    realtime_out_->close();
+    realtime_out_.reset();
+    TUNABLE_LOG2("Closed realtime output file");
+  }
+}
+
 void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
   std::scoped_lock l{lock_};
 
@@ -396,7 +483,6 @@ TuningContext::TuningContext() :
     tuning_enable_{true},
     record_untuned_enable_{false},
     manager_initialized_{false},
-    write_file_on_exit_{true},
     numerics_check_enable_{false},
     max_tuning_duration_ms_{30},
     max_tuning_iterations_{100},
@@ -417,20 +503,8 @@ TuningContext::~TuningContext() {
     // but doesn't do any computation itself.
     return;
   }
-  auto filename = GetFilename();
-  if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
-    if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
-      if (results_count_from_input_file_ > 0) {
-        TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
-      }
-      else {
-        TUNABLE_LOG1("writing file ", filename);
-      }
-      if (!WriteFile(filename)) {
-        TUNABLE_LOG1("failed to write file ", filename);
-      }
-    }
-  }
+  TUNABLE_LOG1("Closing File");
+  GetTuningResultsManager().CloseRealtimeAppend();  // results were already written incrementally; just close the stream

   if (untuned_file_.good()) {
     untuned_file_.close();
@@ -511,9 +585,6 @@ std::ofstream& TuningContext::GetUntunedFile(){
   return untuned_file_;
 }
 
-void TuningContext::WriteFileOnExit(bool value) {
-  write_file_on_exit_ = value;
-}
 
 void TuningContext::EnableNumericsCheck(bool value) {
   numerics_check_enable_ = value;
@@ -634,11 +705,6 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
     auto filename = GetFilename();
     if (!filename.empty() && !IsRecordUntunedEnabled()) {
       ReadFile(filename);
-      // attempt immediately to open file for writing to catch errors early
-      std::ofstream file(filename, std::ios::out | std::ios::app);
-      if (!file.good()) {
-        TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
-      }
     }
   });
   return manager_;
@@ -744,27 +810,6 @@ bool TuningContext::ReadFile(const std::string& filename_) {
   return true;
 }
 
-bool TuningContext::WriteFile(const std::string& filename_) {
-  std::string filename = filename_.empty() ? GetFilename() : filename_;
-  std::ofstream file(filename, std::ios::out | std::ios::trunc);
-  if (!file.good()) {
-    TUNABLE_LOG1("error opening tuning results file for writing ", filename);
-    return false;
-  }
-  auto validators = GetTuningResultsValidator().GetAllValidators();
-  for (const auto& [key, val] : validators) {
-    file << "Validator," << key << "," << val << std::endl;
-  }
-  auto results = GetTuningResultsManager().Dump();
-  for (const auto& [op_sig, kernelmap] : results) {
-    for (const auto& [param_sig, result] : kernelmap) {
-      file << op_sig << "," << param_sig << "," << result << std::endl;
-    }
-  }
-  file.close();
-  return true;
-}
-
 namespace {
 
 struct MaybeDelete {
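
The crash-safety above boils down to one pattern: a long-lived append-mode
handle, a dedicated lock serializing writers, and a flush after every record.
A distilled standalone sketch of that pattern (hypothetical RealtimeLog class,
not part of this patch):

    import threading

    class RealtimeLog:
        """Append-mode handle; one lock serializes writers; flush per record."""

        def __init__(self, path: str) -> None:
            self._lock = threading.Lock()
            self._out = open(path, "a")

        def append(self, line: str) -> None:
            with self._lock:       # mirrors realtime_file_mutex_ above
                self._out.write(line + "\n")
                self._out.flush()  # record reaches the OS immediately,
                                   # so it survives a later crash

The per-record flush trades a little I/O overhead for durability, which is the
point of instant logging: a result is never held only in process memory.
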
diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h
index 5e885d4764d2..95b00ceaa4ca 100644
--- a/aten/src/ATen/cuda/tunable/Tunable.h
+++ b/aten/src/ATen/cuda/tunable/Tunable.h
@@ -103,10 +103,24 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
   void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
     const std::string& params_signature, const std::string& blas_signature);
+
+  void InitRealtimeAppend(
+      const std::string& filename,
+      const std::unordered_map<std::string, std::string>& validators);
+
+  void AppendResultLine(const std::string& op_sig,
+      const std::string& param_sig,
+      const ResultEntry& result);
+
+  void CloseRealtimeAppend();  // for clean shutdown
 
  private:
   std::mutex lock_;
+  std::mutex realtime_file_mutex_;
+  std::unique_ptr<std::ofstream> realtime_out_;
+  std::string realtime_filename_;
   ResultsMap results_;
   UntunedMap untuned_results_;
+  bool validators_written_ = false;
 };
 
@@ -185,10 +199,7 @@ class TORCH_CUDA_CPP_API TuningContext {
   void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
   std::string GetFilename() const;
 
-  void WriteFileOnExit(bool value);
-
   bool ReadFile(const std::string& filename={});
-  bool WriteFile(const std::string& filename={});
 
   template <typename... Types>
   void Log(int level, Types... args) {
@@ -207,7 +218,6 @@ class TORCH_CUDA_CPP_API TuningContext {
   bool tuning_enable_;
   bool record_untuned_enable_;
   bool manager_initialized_;
-  bool write_file_on_exit_;
   bool numerics_check_enable_;
   int max_tuning_duration_ms_;
   int max_tuning_iterations_;
diff --git a/docs/source/cuda.tunable.md b/docs/source/cuda.tunable.md
index 565633fe1881..55c0b5ec9fd7 100644
--- a/docs/source/cuda.tunable.md
+++ b/docs/source/cuda.tunable.md
@@ -68,14 +68,6 @@
 .. autofunction:: get_validators
 ```
 
-```{eval-rst}
-.. autofunction:: write_file_on_exit
-```
-
-```{eval-rst}
-.. autofunction:: write_file
-```
-
 ```{eval-rst}
 .. autofunction:: read_file
 ```
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 31b9b680aa84..3cee906a8c42 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -4750,6 +4750,7 @@ class TestLinalg(TestCase):
     @dtypes(*floating_types_and(torch.half))
     @precisionOverride({torch.float16: 1e-1})  # TunableOp may occasionally find less precise solution
     def test_matmul_small_brute_force_tunableop(self, device, dtype):
+        import os
         # disable tunableop buffer rotation for all tests everywhere, it can be slow
         # We set the TunableOp numerical check environment variable here because it is
         # possible to hit some invalid numerical solutions due to the small matrix sizes.
@@ -4777,27 +4778,11 @@ class TestLinalg(TestCase):
 
             filename1 = torch.cuda.tunable.get_filename()
             unique_id = self.id().split(".")[-1]
-            filename2 = f"{filename1}_tmp1.csv"
-            filename3 = f"{filename1}_tmp2.csv"
             ordinal = torch.cuda.current_device()
             assert filename1 == f"tunableop_results_{unique_id}_{ordinal}.csv"
 
             assert len(torch.cuda.tunable.get_results()) > 0
-            assert torch.cuda.tunable.write_file()  # use default filename
-            assert torch.cuda.tunable.write_file(filename2)  # use custom, one-time filename
-            torch.cuda.tunable.set_filename(filename3)
-            assert torch.cuda.tunable.write_file()  # use previously set filename
-            assert torch.cuda.tunable.read_file()  # use previously set filename, will ignore duplicates and return True
-
-            with open(filename1) as file1:
-                file1_contents = file1.read()
-            with open(filename2) as file2:
-                file2_contents = file2.read()
-            with open(filename3) as file3:
-                file3_contents = file3.read()
-            assert file1_contents == file2_contents
-            assert file1_contents == file3_contents
-
+            self.assertTrue(os.path.exists(filename1))
 
             # We need to reset the filename to the default value so we can properly
             # clean up intermediate files
             self._set_tunableop_defaults()
@@ -4806,6 +4791,7 @@ class TestLinalg(TestCase):
     @skipCUDAIfNotRocm
     @dtypes(torch.half)
     def test_matmul_offline_tunableop(self, device, dtype):
+        import os
         # Main offline tunableop test
         # NOTE: The offline tuning does not support certain tensor
         # shapes as noted below. Submatrics / matrix slices are
@@ -4916,7 +4902,9 @@ class TestLinalg(TestCase):
 
             new_results = len(torch.cuda.tunable.get_results())
             self.assertGreater(new_results - ref_results, 0)
-            self.assertTrue(torch.cuda.tunable.write_file())
+
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
@@ -4927,6 +4915,7 @@ class TestLinalg(TestCase):
     @runOnRocmArch(MI300_ARCH)
     @dtypes(torch.torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
     def test_scaled_gemm_offline_tunableop(self, device, dtype):
+        import os
         # This test is the offline version of test_scaled_gemm_tunableop
 
         with self._tunableop_ctx():
@@ -5006,7 +4995,8 @@ class TestLinalg(TestCase):
                 count = 6
             self.assertEqual(total_num_results, count)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
@@ -5381,6 +5371,7 @@ class TestLinalg(TestCase):
     @skipCUDAIfNotRocm
     @dtypes(torch.bfloat16)
     def test_gemm_bias_offline_tunableop(self, device, dtype):
+        import os
         # This test is the offline version of test_gemm_bias_tunableop
 
         ordinal = torch.cuda.current_device()
@@ -5431,7 +5422,8 @@ class TestLinalg(TestCase):
             # There must be a new tuning results
             self.assertEqual(total_num_results, 2)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
@@ -5632,7 +5624,8 @@ class TestLinalg(TestCase):
                                                 'nn_41_41_41_ld_41_41_41')
             self.assertTrue(found_result is not None)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
@@ -5732,6 +5725,7 @@ class TestLinalg(TestCase):
     @skipCUDAIfNotRocm
     @dtypes(torch.float)
     def test_mm_submatrix_offline_tunableop(self, device, dtype):
+        import os
         # Test offline tuning with submatrices
         # Covers GEMM, ScaledGEMM, and GEMM+bias.
         ordinal = torch.cuda.current_device()
@@ -5862,12 +5856,97 @@ class TestLinalg(TestCase):
             # There must be a new tuning results
             self.assertEqual(total_num_results, 10)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
+
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
             self.assertTrue(ok)
 
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float32)
+    def test_ops_append_to_existing_file_tunableop(self, device, dtype):
+        """If a TunableOp results file already exists (with a matching Validator),
+        new results should be appended, not overwritten."""
+
+        with self._tunableop_ctx():
+            torch.cuda.tunable.set_rotating_buffer_size(0)
+
+            # Seed the existing results file with Validator lines + 1 result line
+            results_filename = torch.cuda.tunable.get_filename()
+            validators = torch.cuda.tunable.get_validators()  # Iterable[Tuple[str, str]]
+
+            seed_lines = []
+            # Each (k, v) becomes a "Validator" line
+            for k, v in validators:
+                seed_lines.append(f"Validator,{k},{v}")
+
+            # One arbitrary, plausible matmul result line
+            seed_lines.append(
+                "GemmAndBiasTunableOp_float_TN,tn_768_32_1024_ld_1024_1024_768,"
+                "Gemm_Hipblaslt_220580,0.0103395"
+            )
+
+            with open(results_filename, "w") as f:
+                f.write("\n".join(seed_lines) + "\n")
+
+            # Count initial (non-Validator) lines
+            with open(results_filename) as f:
+                initial_content = f.read()
+            initial_lines = [
+                l for l in initial_content.split("\n")
+                if l and not l.startswith("Validator")
+            ]
+            initial_count = len(initial_lines)
+            self.assertGreater(initial_count, 0)  # we seeded 1 result line
+
+            # Perform ONE simple matmul
+            A = torch.randn(37, 53, device=device, dtype=dtype)
+            B = torch.randn(53, 29, device=device, dtype=dtype)
+            _ = torch.matmul(A, B)
+
+            # Verify that new results were appended to the same file
+            with open(results_filename) as f:
+                final_content = f.read()
+            final_lines = [
+                l for l in final_content.split("\n")
+                if l and not l.startswith("Validator")
+            ]
+            final_count = len(final_lines)
+
+            self.assertGreater(final_count, initial_count)
+
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float32)
+    def test_matmul_empty_existing_file_tunableop(self, device, dtype):
+        """Test that if an existing results file is empty or corrupted, the
+        default behavior (write validators, then results) still holds."""
+        with self._tunableop_ctx():
+            torch.cuda.tunable.set_rotating_buffer_size(0)
+            results_filename = torch.cuda.tunable.get_filename()
+
+            # Pre-create an empty results file
+            with open(results_filename, 'w') as f:
+                pass  # Empty file
+
+            # Use unique random inputs for this test
+            A = torch.randn(37, 53, device=device, dtype=dtype)
+            B = torch.randn(53, 29, device=device, dtype=dtype)
+
+            # Direct matmul
+            C = torch.matmul(A, B)
+
+            with open(results_filename) as f:
+                content = f.read()
+            self.assertIn("Validator", content)
+            result_lines = [l for l in content.split('\n')
+                            if l and not l.startswith('Validator')]
+            self.assertGreater(len(result_lines), 0)
+
     @onlyCUDA
     @skipCUDAIfNotRocm
     @runOnRocmArch(MI300_ARCH)
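
On disk, the instantly-logged results file is plain CSV: Validator header
lines followed by one line per tuned GEMM (op signature, params signature,
chosen kernel, measured time). An illustrative file, reusing the seeded entry
from the test above — the Validator values and the second result line here are
made-up examples:

    Validator,PT_VERSION,2.10.0
    Validator,HIPBLASLT_VERSION,0.12.0
    GemmAndBiasTunableOp_float_TN,tn_768_32_1024_ld_1024_1024_768,Gemm_Hipblaslt_220580,0.0103395
    GemmTunableOp_float_NN,nn_37_29_53_ld_53_29_53,Gemm_Hipblaslt_112233,0.0041200
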
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 9597690fd28d..7f0f80e77a55 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2197,9 +2197,7 @@ def _cuda_tunableop_set_filename(
     insert_device_ordinal: _bool | None,
 ) -> None: ...
 def _cuda_tunableop_get_filename() -> str: ...
-def _cuda_tunableop_write_file(filename: str | None) -> _bool: ...
 def _cuda_tunableop_read_file(filename: str | None) -> _bool: ...
-def _cuda_tunableop_write_file_on_exit(val: _bool) -> None: ...
 def _cuda_tunableop_get_results() -> tuple[str, str, str, _float]: ...
 def _cuda_tunableop_get_validators() -> tuple[str, str]: ...
 def _cuda_tunableop_set_rotating_buffer_size(buffer_size: _int) -> None: ...
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index c7b80c35c803..41b8de8e78f6 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1653,20 +1653,6 @@ PyObject* THCPModule_cuda_record_untuned_is_enabled(
   END_HANDLE_TH_ERRORS
 }
 
-PyObject* THCPModule_cuda_tunableop_write_file_on_exit(
-    PyObject* _unused,
-    PyObject* arg) {
-  HANDLE_TH_ERRORS
-  TORCH_CHECK(
-      THPUtils_checkBool(arg),
-      "cuda_tunableop_write_file_on_exit expects a bool, but got ",
-      THPUtils_typename(arg));
-  at::cuda::tunable::getTuningContext()->WriteFileOnExit(
-      THPUtils_unpackBool(arg));
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
 PyObject* THCPModule_cuda_tunableop_set_max_tuning_duration(
     PyObject* _unused,
     PyObject* arg) {
@@ -1748,32 +1734,6 @@ PyObject* THCPModule_cuda_tunableop_get_filename(
   END_HANDLE_TH_ERRORS
 }
 
-PyObject* THCPModule_cuda_tunableop_write_file(
-    PyObject* _unused,
-    PyObject* args) {
-  HANDLE_TH_ERRORS
-  PyObject* str = nullptr;
-  bool success = false;
-  if (!PyArg_ParseTuple(args, "|O", &str)) {
-  }
-  if (str) {
-    TORCH_CHECK(
-        THPUtils_checkString(str),
-        "cuda_tunableop_write_file expects a string, but got ",
-        THPUtils_typename(str));
-    auto filename = THPUtils_unpackString(str);
-    success = at::cuda::tunable::getTuningContext()->WriteFile(filename);
-  } else {
-    success = at::cuda::tunable::getTuningContext()->WriteFile();
-  }
-  if (success) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
 PyObject* THCPModule_cuda_tunableop_read_file(
     PyObject* _unused,
     PyObject* args) {
@@ -2127,10 +2087,6 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cuda_record_untuned_is_enabled,
      METH_NOARGS,
      nullptr},
-    {"_cuda_tunableop_write_file_on_exit",
-     THCPModule_cuda_tunableop_write_file_on_exit,
-     METH_O,
-     nullptr},
     {"_cuda_tunableop_set_max_tuning_duration",
      THCPModule_cuda_tunableop_set_max_tuning_duration,
      METH_O,
@@ -2155,10 +2111,6 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cuda_tunableop_get_filename,
      METH_NOARGS,
      nullptr},
-    {"_cuda_tunableop_write_file",
-     THCPModule_cuda_tunableop_write_file,
-     METH_VARARGS,
-     nullptr},
     {"_cuda_tunableop_read_file",
      THCPModule_cuda_tunableop_read_file,
      METH_VARARGS,
diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py
index a1fbd4fdddc2..6b99ea1f8cff 100644
--- a/torch/cuda/tunable.py
+++ b/torch/cuda/tunable.py
@@ -206,8 +206,6 @@ __all__ = [
     "get_filename",
     "get_results",
     "get_validators",
-    "write_file_on_exit",
-    "write_file",
     "read_file",
     "tune_gemm_in_file",
     "mgpu_tune_gemm_in_file",
@@ -306,25 +304,6 @@ def get_validators() -> tuple[str, str]:
     return torch._C._cuda_tunableop_get_validators()  # type: ignore[attr-defined]
 
 
-def write_file_on_exit(val: bool) -> None:
-    r"""During Tuning Context destruction, write file to disk.
-
-    This is useful as a final flush of your results to disk if your application
-    terminates as result of normal operation or an error. Manual flushing of
-    your results can be achieved by manually calling ``write_file()``."""
-    torch._C._cuda_tunableop_write_file_on_exit(val)  # type: ignore[attr-defined]
-
-
-def write_file(filename: Optional[str] = None) -> bool:
-    r"""Write results to a CSV file.
-
-    If :attr:`filename` is not given, ``get_filename()`` is called.
-    """
-    if filename is None:
-        filename = get_filename()
-    return torch._C._cuda_tunableop_write_file(filename)  # type: ignore[attr-defined]
-
-
 def read_file(filename: Optional[str] = None) -> bool:
     r"""Read results from a TunableOp CSV file.
 
@@ -787,7 +766,6 @@ def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
     mp_context = mp.get_context("spawn")
 
     futures = []  # empty list to hold futures
-    flush_results = []  # empty list to hold futures
 
     # GEMM are assigned to GPUs in a round robin manner
     h = 0
@@ -809,13 +787,6 @@ def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
     for future in concurrent.futures.as_completed(futures):
         future.result()
 
-    for g in range(num_gpus):
-        flush_result = executor.submit(write_file)
-        flush_results.append(flush_result)
-
-    for flush_result in concurrent.futures.as_completed(flush_results):
-        flush_result.result()
-
     torch.cuda.synchronize()
 
     _gather_tunableop_results()
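
Because each worker process now writes its results file incrementally, the
multi-GPU path no longer needs the final flush pass removed above; callers
are unchanged. A hedged usage sketch (the untuned filename pattern is an
example, assuming untuned files were previously recorded):

    import torch

    # Tunes all unique GEMMs found in the untuned files across 4 GPUs; each
    # worker's results land in its own CSV as soon as they are found, and the
    # per-GPU files are gathered at the end as before.
    torch.cuda.tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", 4)
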