[ROCm][tunableop] Modified Online Tuning Mode to add Instant Logging (#163965)

- Added instant logging in online tuning mode, so that each tuned GEMM is written to the results file as soon as it is tuned.
- This preserves the tuning configs found so far if the application crashes mid-run.
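
For illustration, a minimal online-tuning sketch (assuming a ROCm build with TunableOp available; `tunableop_results.csv` is a placeholder filename):

```python
import torch

# Enable TunableOp online tuning; tuning itself is on by default once enabled.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_filename("tunableop_results.csv")  # placeholder name

a = torch.randn(1024, 512, device="cuda")
b = torch.randn(512, 256, device="cuda")
c = a @ b  # first occurrence of this GEMM shape triggers tuning

# With instant logging, the tuned entry for this shape is already appended
# to the results file here; no explicit flush call is needed, and a later
# crash would not lose it.
```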

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163965
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
Sarthak Tandon
2025-10-15 20:02:27 +00:00
committed by PyTorch MergeBot
parent 83f9baf413
commit 7f9b745494
8 changed files with 209 additions and 164 deletions

View File

@@ -175,8 +175,6 @@ All python APIs exist in the `torch.cuda.tunable` module.
 | get_filename() -> str | |
 | get_results() -> Tuple[str, str, str, float] | |
 | get_validators() -> Tuple[str, str] | |
-| write_file_on_exit(val: bool) -> None | Default is True. |
-| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
 | read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
 | tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
 | mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |
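
Since `write_file_on_exit` and `write_file` are removed, a hedged before/after migration sketch for callers:

```python
import os
import torch

torch.cuda.tunable.enable(True)
# ... run GEMM workload ...

# Before this PR, results had to be flushed explicitly:
#   torch.cuda.tunable.write_file_on_exit(True)
#   torch.cuda.tunable.write_file()
# After this PR, each result is appended as soon as it is tuned, so it is
# enough to check that the results file exists:
assert os.path.exists(torch.cuda.tunable.get_filename())
```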

View File

@@ -107,14 +107,30 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
 }
 
 void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
-  std::scoped_lock l{lock_};
-
-  auto it = results_.find(op_signature);
-  if (it == results_.end()) {
-    it = results_.insert({op_signature, {}}).first;
-  }
-
-  AddImpl(op_signature, params_signature, std::move(best), it->second);
+  bool is_new = false;
+  ResultEntry inserted = ResultEntry::Null();
+
+  // ---- mutate maps under results lock ----
+  {
+    std::scoped_lock l{lock_};
+    auto& km = results_[op_signature]; // creates if missing
+    is_new = (km.find(params_signature) == km.end());
+    AddImpl(op_signature, params_signature, std::move(best), km);
+    if (is_new) {
+      inserted = km.at(params_signature); // snapshot for I/O after unlocking
+    }
+  }
+
+  if (!is_new) return; // only write once per unique (op, params)
+
+  TuningContext* ctx = getTuningContext();
+  if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
+    InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());
+
+    if (is_new && realtime_out_ && realtime_out_->good()) {
+      AppendResultLine(op_signature, params_signature, inserted);
+    }
+  }
 }
 
 void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
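
The rewritten `Add()` keeps the critical section short: the in-memory map is mutated and the new entry snapshotted under `lock_`, then the file write happens afterwards under a separate file mutex. A rough Python analogue of that pattern (hypothetical `ResultsStore`, not PyTorch API):

```python
import threading

class ResultsStore:
    """Sketch: mutate and snapshot under one lock, do file I/O outside it."""

    def __init__(self, path: str):
        self._map_lock = threading.Lock()   # guards the in-memory results map
        self._file_lock = threading.Lock()  # guards the append-only file
        self._results: dict[tuple[str, str], str] = {}
        self._path = path

    def add(self, op_sig: str, param_sig: str, entry: str) -> None:
        with self._map_lock:                # short critical section: map only
            is_new = (op_sig, param_sig) not in self._results
            self._results[(op_sig, param_sig)] = entry
            snapshot = entry                # taken while still consistent
        if not is_new:
            return                          # write once per unique (op, params)
        with self._file_lock:               # slow disk write outside map lock
            with open(self._path, "a") as f:
                f.write(f"{op_sig},{param_sig},{snapshot}\n")
```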
@@ -150,6 +166,77 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
   }
 }
 
+void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
+  std::scoped_lock fl{realtime_file_mutex_};
+
+  if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
+    return;
+  }
+
+  if (realtime_out_ && realtime_filename_ != filename) {
+    realtime_out_->flush();
+    realtime_out_->close();
+    realtime_out_.reset();
+    validators_written_ = false;
+  }
+
+  bool file_exists = false;
+  bool file_empty = true;
+  {
+    std::ifstream check_file(filename);
+    if (check_file.good()) {
+      file_exists = true;
+      file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
+    }
+  }
+
+  realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
+  if (!realtime_out_->good()) {
+    TORCH_WARN("TunableOp realtime append: failed to open '", filename, "'");
+    realtime_out_.reset();
+    return;
+  }
+
+  if (!file_exists || file_empty) {
+    for (const auto& [key, val] : validators) {
+      (*realtime_out_) << "Validator," << key << "," << val << std::endl;
+      realtime_out_->flush();
+    }
+    validators_written_ = true;
+    TUNABLE_LOG2("Wrote validators to realtime output file");
+  }
+
+  realtime_filename_ = filename;
+}
+
+void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
+  std::scoped_lock fl{realtime_file_mutex_};
+
+  if (!realtime_out_ || !realtime_out_->good()) {
+    return;
+  }
+
+  (*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
+  realtime_out_->flush(); // ensure immediate write to disk
+
+  TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
+}
+
+void TuningResultsManager::CloseRealtimeAppend() {
+  std::scoped_lock fl{realtime_file_mutex_};
+
+  if (realtime_out_) {
+    realtime_out_->flush();
+    realtime_out_->close();
+    realtime_out_.reset();
+    TUNABLE_LOG2("Closed realtime output file");
+  }
+}
+
 void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
   std::scoped_lock l{lock_};
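
`InitRealtimeAppend` writes the `Validator` header only when the target file is missing or empty, so re-opening or appending to an already seeded file never duplicates the header. A hedged Python sketch of that decision (hypothetical helper, mirroring the C++ logic above):

```python
import os

def open_results_for_append(path: str, validators: dict[str, str]):
    # Header lines go in only if the file doesn't exist yet or is empty;
    # otherwise we assume it was already seeded with a Validator block.
    need_header = not os.path.exists(path) or os.path.getsize(path) == 0
    f = open(path, "a")
    if need_header:
        for key, val in validators.items():
            f.write(f"Validator,{key},{val}\n")
        f.flush()  # match the C++ flush-per-line behaviour
    return f
```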
@@ -396,7 +483,6 @@ TuningContext::TuningContext() :
     tuning_enable_{true},
     record_untuned_enable_{false},
     manager_initialized_{false},
-    write_file_on_exit_{true},
     numerics_check_enable_{false},
     max_tuning_duration_ms_{30},
     max_tuning_iterations_{100},
@@ -417,20 +503,8 @@ TuningContext::~TuningContext() {
     // but doesn't do any computation itself.
     return;
   }
 
-  auto filename = GetFilename();
-  if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
-    if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
-      if (results_count_from_input_file_ > 0) {
-        TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
-      }
-      else {
-        TUNABLE_LOG1("writing file ", filename);
-      }
-      if (!WriteFile(filename)) {
-        TUNABLE_LOG1("failed to write file ", filename);
-      }
-    }
-  }
+  TUNABLE_LOG1("Closing File");
+  GetTuningResultsManager().CloseRealtimeAppend(); // instant logging is on by default now
 
   if (untuned_file_.good()) {
     untuned_file_.close();
@@ -511,9 +585,6 @@ std::ofstream& TuningContext::GetUntunedFile(){
   return untuned_file_;
 }
 
-void TuningContext::WriteFileOnExit(bool value) {
-  write_file_on_exit_ = value;
-}
-
 void TuningContext::EnableNumericsCheck(bool value) {
   numerics_check_enable_ = value;
@@ -634,11 +705,6 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
     auto filename = GetFilename();
     if (!filename.empty() && !IsRecordUntunedEnabled()) {
       ReadFile(filename);
-      // attempt immediately to open file for writing to catch errors early
-      std::ofstream file(filename, std::ios::out | std::ios::app);
-      if (!file.good()) {
-        TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
-      }
     }
   });
   return manager_;
@@ -744,27 +810,6 @@ bool TuningContext::ReadFile(const std::string& filename_) {
   return true;
 }
 
-bool TuningContext::WriteFile(const std::string& filename_) {
-  std::string filename = filename_.empty() ? GetFilename() : filename_;
-  std::ofstream file(filename, std::ios::out | std::ios::trunc);
-  if (!file.good()) {
-    TUNABLE_LOG1("error opening tuning results file for writing ", filename);
-    return false;
-  }
-  auto validators = GetTuningResultsValidator().GetAllValidators();
-  for (const auto& [key, val] : validators) {
-    file << "Validator," << key << "," << val << std::endl;
-  }
-  auto results = GetTuningResultsManager().Dump();
-  for (const auto& [op_sig, kernelmap] : results) {
-    for (const auto& [param_sig, result] : kernelmap) {
-      file << op_sig << "," << param_sig << "," << result << std::endl;
-    }
-  }
-  file.close();
-  return true;
-}
-
 namespace {
 struct MaybeDelete {

View File

@@ -103,10 +103,24 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
   void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
     const std::string& params_signature, const std::string& blas_signature);
 
+  void InitRealtimeAppend(
+      const std::string& filename,
+      const std::unordered_map<std::string, std::string>& validators);
+
+  void AppendResultLine(const std::string& op_sig,
+                        const std::string& param_sig,
+                        const ResultEntry& result);
+
+  void CloseRealtimeAppend(); // for clean shutdown
+
  private:
   std::mutex lock_;
+  std::mutex realtime_file_mutex_;
+  std::unique_ptr<std::ofstream> realtime_out_;
+  std::string realtime_filename_;
   ResultsMap results_;
   UntunedMap untuned_results_;
+  bool validators_written_ = false;
 };
@@ -185,10 +199,7 @@ class TORCH_CUDA_CPP_API TuningContext {
   void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
   std::string GetFilename() const;
 
-  void WriteFileOnExit(bool value);
-
   bool ReadFile(const std::string& filename={});
-  bool WriteFile(const std::string& filename={});
 
   template<class... Types>
   void Log(int level, Types... args) {
@@ -207,7 +218,6 @@ class TORCH_CUDA_CPP_API TuningContext {
   bool tuning_enable_;
   bool record_untuned_enable_;
   bool manager_initialized_;
-  bool write_file_on_exit_;
   bool numerics_check_enable_;
   int max_tuning_duration_ms_;
   int max_tuning_iterations_;

View File

@@ -68,14 +68,6 @@
 .. autofunction:: get_validators
 ```
 
-```{eval-rst}
-.. autofunction:: write_file_on_exit
-```
-
-```{eval-rst}
-.. autofunction:: write_file
-```
-
 ```{eval-rst}
 .. autofunction:: read_file
 ```

View File

@@ -4750,6 +4750,7 @@ class TestLinalg(TestCase):
     @dtypes(*floating_types_and(torch.half))
     @precisionOverride({torch.float16: 1e-1})  # TunableOp may occasionally find less precise solution
     def test_matmul_small_brute_force_tunableop(self, device, dtype):
+        import os
         # disable tunableop buffer rotation for all tests everywhere, it can be slow
         # We set the TunableOp numerical check environment variable here because it is
         # possible to hit some invalid numerical solutions due to the small matrix sizes.
@@ -4777,27 +4778,11 @@ class TestLinalg(TestCase):
             filename1 = torch.cuda.tunable.get_filename()
             unique_id = self.id().split(".")[-1]
-            filename2 = f"{filename1}_tmp1.csv"
-            filename3 = f"{filename1}_tmp2.csv"
             ordinal = torch.cuda.current_device()
             assert filename1 == f"tunableop_results_{unique_id}_{ordinal}.csv"
 
             assert len(torch.cuda.tunable.get_results()) > 0
-            assert torch.cuda.tunable.write_file()  # use default filename
-            assert torch.cuda.tunable.write_file(filename2)  # use custom, one-time filename
-            torch.cuda.tunable.set_filename(filename3)
-            assert torch.cuda.tunable.write_file()  # use previously set filename
-            assert torch.cuda.tunable.read_file()  # use previously set filename, will ignore duplicates and return True
-
-            with open(filename1) as file1:
-                file1_contents = file1.read()
-            with open(filename2) as file2:
-                file2_contents = file2.read()
-            with open(filename3) as file3:
-                file3_contents = file3.read()
-            assert file1_contents == file2_contents
-            assert file1_contents == file3_contents
+            self.assertTrue(os.path.exists(filename1))
 
             # We need to reset the filename to the default value so we can properly
             # clean up intermediate files
             self._set_tunableop_defaults()
@@ -4806,6 +4791,7 @@ class TestLinalg(TestCase):
     @skipCUDAIfNotRocm
     @dtypes(torch.half)
     def test_matmul_offline_tunableop(self, device, dtype):
+        import os
        # Main offline tunableop test
        # NOTE: The offline tuning does not support certain tensor
        # shapes as noted below. Submatrics / matrix slices are
@@ -4916,7 +4902,9 @@ class TestLinalg(TestCase):
             new_results = len(torch.cuda.tunable.get_results())
             self.assertGreater(new_results - ref_results, 0)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
@@ -4927,6 +4915,7 @@ class TestLinalg(TestCase):
     @runOnRocmArch(MI300_ARCH)
     @dtypes(torch.torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
     def test_scaled_gemm_offline_tunableop(self, device, dtype):
+        import os
         # This test is the offline version of test_scaled_gemm_tunableop
 
         with self._tunableop_ctx():
@@ -5006,7 +4995,8 @@ class TestLinalg(TestCase):
             count = 6
             self.assertEqual(total_num_results, count)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
@@ -5381,6 +5371,7 @@ class TestLinalg(TestCase):
     @skipCUDAIfNotRocm
     @dtypes(torch.bfloat16)
     def test_gemm_bias_offline_tunableop(self, device, dtype):
+        import os
         # This test is the offline version of test_gemm_bias_tunableop
 
         ordinal = torch.cuda.current_device()
@@ -5431,7 +5422,8 @@ class TestLinalg(TestCase):
             # There must be a new tuning results
             self.assertEqual(total_num_results, 2)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
@@ -5632,7 +5624,8 @@ class TestLinalg(TestCase):
                                    'nn_41_41_41_ld_41_41_41')
             self.assertTrue(found_result is not None)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
@@ -5732,6 +5725,7 @@ class TestLinalg(TestCase):
     @skipCUDAIfNotRocm
     @dtypes(torch.float)
     def test_mm_submatrix_offline_tunableop(self, device, dtype):
+        import os
         # Test offline tuning with submatrices
         # Covers GEMM, ScaledGEMM, and GEMM+bias.
         ordinal = torch.cuda.current_device()
@@ -5862,12 +5856,97 @@ class TestLinalg(TestCase):
             # There must be a new tuning results
             self.assertEqual(total_num_results, 10)
 
-            self.assertTrue(torch.cuda.tunable.write_file())
+            results_filename = torch.cuda.tunable.get_filename()
+            self.assertTrue(os.path.exists(results_filename))
 
             # Compare Param Signature of untuned and tuned results
             ok = self._compare_untuned_tuned_entries()
             self.assertTrue(ok)
 
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float32)
+    def test_ops_append_to_existing_file_tunableop(self, device, dtype):
+        """If a TunableOp results file already exists (with a matching Validator),
+        new results should be appended, not overwritten."""
+        with self._tunableop_ctx():
+            torch.cuda.tunable.set_rotating_buffer_size(0)
+
+            # Seed the existing results file with Validator lines + 1 result line
+            results_filename = torch.cuda.tunable.get_filename()
+            validators = torch.cuda.tunable.get_validators()  # Iterable[Tuple[str, str]]
+            seed_lines = []
+            # Each (k, v) becomes a "Validator" line
+            for k, v in validators:
+                seed_lines.append(f"Validator,{k},{v}")
+            # One arbitrary, plausible matmul result line
+            seed_lines.append(
+                "GemmAndBiasTunableOp_float_TN,tn_768_32_1024_ld_1024_1024_768,"
+                "Gemm_Hipblaslt_220580,0.0103395"
+            )
+            with open(results_filename, "w") as f:
+                f.write("\n".join(seed_lines) + "\n")
+
+            # Count initial (non-Validator) lines
+            with open(results_filename) as f:
+                initial_content = f.read()
+            initial_lines = [
+                l for l in initial_content.split("\n")
+                if l and not l.startswith("Validator")
+            ]
+            initial_count = len(initial_lines)
+            self.assertGreater(initial_count, 0)  # we seeded 1 result line
+
+            # Perform ONE simple matmul
+            A = torch.randn(37, 53, device=device, dtype=dtype)
+            B = torch.randn(53, 29, device=device, dtype=dtype)
+            _ = torch.matmul(A, B)
+
+            # Verify that new results were appended to the same file
+            with open(results_filename) as f:
+                final_content = f.read()
+            final_lines = [
+                l for l in final_content.split("\n")
+                if l and not l.startswith("Validator")
+            ]
+            final_count = len(final_lines)
+            self.assertGreater(final_count, initial_count)
+
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float32)
+    def test_matmul_empty_existing_file_tunableop(self, device, dtype):
+        """Test that if an existing results file is empty/corrupted, the default
+        behaviour should hold."""
+        with self._tunableop_ctx():
+            torch.cuda.tunable.set_rotating_buffer_size(0)
+
+            results_filename = torch.cuda.tunable.get_filename()
+
+            # Pre-create an empty results file
+            with open(results_filename, 'w') as f:
+                pass  # Empty file
+
+            # Use unique random inputs for this test
+            A = torch.randn(37, 53, device=device, dtype=dtype)
+            B = torch.randn(53, 29, device=device, dtype=dtype)
+
+            # Direct matmul
+            C = torch.matmul(A, B)
+
+            with open(results_filename) as f:
+                content = f.read()
+            self.assertIn("Validator", content)
+            result_lines = [l for l in content.split('\n')
+                            if l and not l.startswith('Validator')]
+            self.assertGreater(len(result_lines), 0)
+
     @onlyCUDA
     @skipCUDAIfNotRocm
     @runOnRocmArch(MI300_ARCH)

View File

@@ -2197,9 +2197,7 @@ def _cuda_tunableop_set_filename(
     insert_device_ordinal: _bool | None,
 ) -> None: ...
 def _cuda_tunableop_get_filename() -> str: ...
-def _cuda_tunableop_write_file(filename: str | None) -> _bool: ...
 def _cuda_tunableop_read_file(filename: str | None) -> _bool: ...
-def _cuda_tunableop_write_file_on_exit(val: _bool) -> None: ...
 def _cuda_tunableop_get_results() -> tuple[str, str, str, _float]: ...
 def _cuda_tunableop_get_validators() -> tuple[str, str]: ...
 def _cuda_tunableop_set_rotating_buffer_size(buffer_size: _int) -> None: ...

View File

@@ -1653,20 +1653,6 @@ PyObject* THCPModule_cuda_record_untuned_is_enabled(
   END_HANDLE_TH_ERRORS
 }
 
-PyObject* THCPModule_cuda_tunableop_write_file_on_exit(
-    PyObject* _unused,
-    PyObject* arg) {
-  HANDLE_TH_ERRORS
-  TORCH_CHECK(
-      THPUtils_checkBool(arg),
-      "cuda_tunableop_write_file_on_exit expects a bool, but got ",
-      THPUtils_typename(arg));
-  at::cuda::tunable::getTuningContext()->WriteFileOnExit(
-      THPUtils_unpackBool(arg));
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
 PyObject* THCPModule_cuda_tunableop_set_max_tuning_duration(
     PyObject* _unused,
     PyObject* arg) {
@@ -1748,32 +1734,6 @@ PyObject* THCPModule_cuda_tunableop_get_filename(
   END_HANDLE_TH_ERRORS
 }
 
-PyObject* THCPModule_cuda_tunableop_write_file(
-    PyObject* _unused,
-    PyObject* args) {
-  HANDLE_TH_ERRORS
-  PyObject* str = nullptr;
-  bool success = false;
-  if (!PyArg_ParseTuple(args, "|O", &str)) {
-  }
-  if (str) {
-    TORCH_CHECK(
-        THPUtils_checkString(str),
-        "cuda_tunableop_write_file expects a string, but got ",
-        THPUtils_typename(str));
-    auto filename = THPUtils_unpackString(str);
-    success = at::cuda::tunable::getTuningContext()->WriteFile(filename);
-  } else {
-    success = at::cuda::tunable::getTuningContext()->WriteFile();
-  }
-  if (success) {
-    Py_RETURN_TRUE;
-  } else {
-    Py_RETURN_FALSE;
-  }
-  END_HANDLE_TH_ERRORS
-}
-
 PyObject* THCPModule_cuda_tunableop_read_file(
     PyObject* _unused,
     PyObject* args) {
@@ -2127,10 +2087,6 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cuda_record_untuned_is_enabled,
      METH_NOARGS,
      nullptr},
-    {"_cuda_tunableop_write_file_on_exit",
-     THCPModule_cuda_tunableop_write_file_on_exit,
-     METH_O,
-     nullptr},
     {"_cuda_tunableop_set_max_tuning_duration",
      THCPModule_cuda_tunableop_set_max_tuning_duration,
      METH_O,
@@ -2155,10 +2111,6 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cuda_tunableop_get_filename,
      METH_NOARGS,
      nullptr},
-    {"_cuda_tunableop_write_file",
-     THCPModule_cuda_tunableop_write_file,
-     METH_VARARGS,
-     nullptr},
     {"_cuda_tunableop_read_file",
      THCPModule_cuda_tunableop_read_file,
      METH_VARARGS,

View File

@@ -206,8 +206,6 @@ __all__ = [
     "get_filename",
     "get_results",
     "get_validators",
-    "write_file_on_exit",
-    "write_file",
     "read_file",
     "tune_gemm_in_file",
     "mgpu_tune_gemm_in_file",
@@ -306,25 +304,6 @@ def get_validators() -> tuple[str, str]:
     return torch._C._cuda_tunableop_get_validators()  # type: ignore[attr-defined]
 
 
-def write_file_on_exit(val: bool) -> None:
-    r"""During Tuning Context destruction, write file to disk.
-
-    This is useful as a final flush of your results to disk if your application
-    terminates as result of normal operation or an error. Manual flushing of
-    your results can be achieved by manually calling ``write_file()``."""
-    torch._C._cuda_tunableop_write_file_on_exit(val)  # type: ignore[attr-defined]
-
-
-def write_file(filename: Optional[str] = None) -> bool:
-    r"""Write results to a CSV file.
-
-    If :attr:`filename` is not given, ``get_filename()`` is called.
-    """
-    if filename is None:
-        filename = get_filename()
-    return torch._C._cuda_tunableop_write_file(filename)  # type: ignore[attr-defined]
-
-
 def read_file(filename: Optional[str] = None) -> bool:
     r"""Read results from a TunableOp CSV file.
@@ -787,7 +766,6 @@ def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
     mp_context = mp.get_context("spawn")
 
     futures = []  # empty list to hold futures
-    flush_results = []  # empty list to hold futures
 
     # GEMM are assigned to GPUs in a round robin manner
     h = 0
@@ -809,13 +787,6 @@ def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
         for future in concurrent.futures.as_completed(futures):
             future.result()
 
-        for g in range(num_gpus):
-            flush_result = executor.submit(write_file)
-            flush_results.append(flush_result)
-
-        for flush_result in concurrent.futures.as_completed(flush_results):
-            flush_result.result()
-
     torch.cuda.synchronize()
 
     _gather_tunableop_results()
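
With instant logging, each worker process appends its results as it tunes, so the fan-out of `write_file` futures removed above is redundant. Caller usage is unchanged; a sketch (assuming per-device untuned files such as `tunableop_untuned0.csv` matching the pattern):

```python
import torch

# Tune every unique GEMM found in the untuned files across 4 GPUs.
# Each worker now writes its tuned entries to disk immediately, so no
# final flush step is required before the results are gathered.
torch.cuda.tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus=4)
```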