From 66ea76ec44c0cfd0499f9544201b1cdce6d5cb4e Mon Sep 17 00:00:00 2001
From: Sarthak Tandon
Date: Wed, 15 Oct 2025 22:26:47 +0000
Subject: [PATCH] [ROCm][tunableop] Improvements to tunableop Numerical Check (#163079)

Modified the PYTORCH_TUNABLEOP_NUMERICAL_CHECK flag so that it accepts
numerical tolerances in the format 'atol_rtol' instead of the previous
0 and 1. The previous behavior is retained via the default tolerances
(atol=rtol=1e-5).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163079
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
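Editorial note (not part of the patch): a minimal usage sketch of the new
environment-variable format described above. The tolerance values are
arbitrary examples; PYTORCH_TUNABLEOP_ENABLED is the existing switch
documented in the TunableOp README.

```python
import os

# Enable TunableOp numerical checking with atol=1e-4, rtol=1e-3.
# The variable must be set before tuning runs; "0" disables checking.
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-4_1e-3"
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"

import torch  # import after configuring the environment

a = torch.randn(64, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 64, device="cuda", dtype=torch.float16)
c = a @ b  # each candidate GEMM solution is validated against the reference
```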
---
 aten/src/ATen/cuda/tunable/GemmCommon.h | 52 +++++++++++----------
 aten/src/ATen/cuda/tunable/README.md    |  3 +-
 aten/src/ATen/cuda/tunable/Tunable.cpp  | 47 +++++++++++++++++--
 aten/src/ATen/cuda/tunable/Tunable.h    | 14 ++++++
 aten/src/ATen/cuda/tunable/TunableOp.h  | 41 ++++++++--------
 docs/source/cuda.tunable.md             |  4 ++
 test/test_linalg.py                     | 49 +++++++++++++++++--
 torch/_C/__init__.pyi.in                |  3 ++
 torch/csrc/cuda/Module.cpp              | 62 +++++++++++++++++++++++++
 torch/cuda/tunable.py                   |  8 ++++
 10 files changed, 227 insertions(+), 56 deletions(-)

diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h
index 8478aa4d4cf4..5d9e33b2b5b2 100644
--- a/aten/src/ATen/cuda/tunable/GemmCommon.h
+++ b/aten/src/ATen/cuda/tunable/GemmCommon.h
@@ -13,6 +13,7 @@
 #include
 #include
+#include <ATen/cuda/tunable/Tunable.h>
 #include
 #include
 #include
@@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
     BLASType = "unknown";
   }
   return BLASType;
+
 }

 // Similar to Compute Type in GemmRocblas.h
@@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio

 namespace detail {

-static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
+static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
+
+  if (!config.enabled) {
+    return true; // skip when disabled
+  }
+
   auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
-  // comparison done as 1D tensor
   at::Tensor ref = at::from_blob(c, {size}, options);
   at::Tensor oth = at::from_blob(other_c, {size}, options);
   at::Tensor ref_float = ref.to(at::kFloat);
   at::Tensor oth_float = oth.to(at::kFloat);
-  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
-  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
-  double last_succeed_atol = 1;
-  double last_succeed_rtol = 1;
-  for (auto& atol : atols) {
-    for (auto& rtol : rtols) {
-      if (at::allclose(ref_float, oth_float, rtol, atol)) {
-        last_succeed_atol = atol;
-        last_succeed_rtol = rtol;
-      }
-    }
-  }
-  if (last_succeed_atol == 1) {
-    return false;
-  }
-  else {
-    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
-  }
-  return true;
+  const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
+  if (ok) {
+    TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
+  } else {
+    TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
+  }
+  return ok;
 }

 }
@@ -355,8 +349,10 @@ struct GemmParams : OpParams<GemmParams<T>> {
   }

   TuningStatus NumericalCheck(GemmParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<T>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }

   char transa{};
@@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams<GemmAndBiasParams<T>> {
   }

   TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<T>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }

   char transa{};
@@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams<GemmStridedBatchedParams<T>> {
   }

   TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<T>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }

   char transa{};
@@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams<ScaledGemmParams<T>> {
   }

   TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }

   char transa{};
diff --git a/aten/src/ATen/cuda/tunable/README.md b/aten/src/ATen/cuda/tunable/README.md
index 4816886ecc86..db31af9259a5 100644
--- a/aten/src/ATen/cuda/tunable/README.md
+++ b/aten/src/ATen/cuda/tunable/README.md
@@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs instead.
 | PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
 | PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
 | PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
-| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
+| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set to 'atol_rtol' to enable, for example "1e-5_1e-5". |
 | PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
 | PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
 | PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@@ -173,6 +173,7 @@ All python APIs exist in the `torch.cuda.tunable` module.
 | get_max_tuning_iterations() -> int | |
 | set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
 | get_filename() -> str | |
+| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
 | get_results() -> Tuple[str, str, str, float] | |
 | get_validators() -> Tuple[str, str] | |
 | read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
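Editorial aside (not part of the patch): the values accepted by
PYTORCH_TUNABLEOP_NUMERICAL_CHECK, as implemented by GetNumericalCheckConfig()
in the Tunable.cpp hunk below, can be modeled in Python roughly as follows.
The function name is illustrative only; the real parser is C++.

```python
from typing import Optional, Tuple

def parse_numerical_check(env: Optional[str]) -> Tuple[bool, float, float]:
    """Illustrative model of the 'atol_rtol' parsing; returns (enabled, atol, rtol).

    When the variable is unset, the C++ code falls back to the config set via
    the Python API; this sketch returns the defaults instead.
    """
    if env is None or env == "0":
        return (False, 1e-5, 1e-5)
    atol_str, sep, rtol_str = env.partition("_")
    if not sep:
        raise ValueError(f"Expected 'atol_rtol', got: {env}")
    atol, rtol = float(atol_str), float(rtol_str)
    if atol <= 0.0 or rtol <= 0.0:
        raise ValueError(f"Tolerance values must be positive. atol={atol}, rtol={rtol}")
    return (True, atol, rtol)

assert parse_numerical_check("1e-4_1e-3") == (True, 1e-4, 1e-3)
assert parse_numerical_check("0") == (False, 1e-5, 1e-5)
```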
diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp
index c4d5fa261fc2..c5ea0c6dd17c 100644
--- a/aten/src/ATen/cuda/tunable/Tunable.cpp
+++ b/aten/src/ATen/cuda/tunable/Tunable.cpp
@@ -590,12 +590,49 @@ void TuningContext::EnableNumericsCheck(bool value) {
   numerics_check_enable_ = value;
 }

-bool TuningContext::IsNumericsCheckEnabled() const {
-  const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
-  if (env == "1") {
-    return true;
+NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
+  const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
+
+  if (!env_opt.has_value()) {
+    return numerics_cfg_;
   }
-  return numerics_check_enable_;
+
+  const std::string& env = env_opt.value();
+
+  if (env == "0") {
+    return NumericalCheckConfig(false, 1e-5, 1e-5);
+  }
+
+  const size_t underscore = env.find('_');
+
+  TORCH_CHECK(
+      underscore != std::string::npos,
+      "Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
+      "Expected 'atol_rtol', got: ",
+      env);
+
+  double atol = 0.0;
+  double rtol = 0.0;
+
+  try {
+    atol = std::stod(env.substr(0, underscore));
+    rtol = std::stod(env.substr(underscore + 1));
+  } catch (const std::exception& e) {
+    TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
+  }
+
+  TORCH_CHECK(
+      atol > 0.0 && rtol > 0.0,
+      "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
+  return NumericalCheckConfig(true, atol, rtol);
+}
+
+void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
+  TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
+  numerics_cfg_ = {enabled, atol, rtol};
+}
+
+bool TuningContext::IsNumericsCheckEnabled() const {
+  const auto cfg = GetNumericalCheckConfig();
+  return cfg.enabled || numerics_check_enable_;
 }

 void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h
index 95b00ceaa4ca..17b4ea34ddf6 100644
--- a/aten/src/ATen/cuda/tunable/Tunable.h
+++ b/aten/src/ATen/cuda/tunable/Tunable.h
@@ -148,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
   GetValidateFuncs validators_;
 };

+struct NumericalCheckConfig {
+  bool enabled{false};
+  double atol{1e-5};
+  double rtol{1e-5};
+
+  NumericalCheckConfig() = default;
+  NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
+};
+
+
 class TORCH_CUDA_CPP_API TuningContext {
 public:
   TuningContext();
@@ -169,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {

   void EnableNumericsCheck(bool value);
   bool IsNumericsCheckEnabled() const;
+  void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
+  NumericalCheckConfig GetNumericalCheckConfig() const;

   void SetMaxTuningDurationMs(int max_duration_ms);
   int GetMaxTuningDurationMs() const;
@@ -232,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
   std::ofstream untuned_file_;
   size_t results_count_from_input_file_;
   bool is_shutting_down_;
+
+  NumericalCheckConfig numerics_cfg_{};
 };

 TORCH_CUDA_CPP_API TuningContext* getTuningContext();
diff --git a/aten/src/ATen/cuda/tunable/TunableOp.h b/aten/src/ATen/cuda/tunable/TunableOp.h
index b4b983dc739c..d7bf0e6d93d8 100644
--- a/aten/src/ATen/cuda/tunable/TunableOp.h
+++ b/aten/src/ATen/cuda/tunable/TunableOp.h
@@ -267,27 +267,10 @@ class TunableOp {

     for (size_t i = 0; i < op_names_.size(); i++) {
       auto* candidate = ops_[op_names_[i]].get(); // borrow pointer

-      if (do_numerics_check) {
-        ParamsT* numerical_params = params->DeepCopy(false);
-        auto status = candidate->Call(numerical_params);
-        if (status != OK) {
-          numerical_params->Delete();
-          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
-          continue;
-        }
-        status = reference_params->NumericalCheck(numerical_params);
-        numerical_params->Delete();
-        if (status != OK) {
-          TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
-          continue;
-        }
-      }
-      else {
-        auto status = candidate->Call(reusable_params[0]);
-        if (status != OK) {
-          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
-          continue;
-        }
+      auto status = candidate->Call(reusable_params[0]);
+      if (status != OK) {
+        TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+        continue;
       }

       // collect a small profile
@@ -310,6 +293,22 @@ class TunableOp {
         continue;
       }

+      if (do_numerics_check) {
+        ParamsT* numerical_params = params->DeepCopy(false);
+        auto status = candidate->Call(numerical_params);
+        if (status != OK) {
+          numerical_params->Delete();
+          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+          continue;
+        }
+        status = reference_params->NumericalCheck(numerical_params);
+        numerical_params->Delete();
+        if (status != OK) {
+          TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+          continue;
+        }
+      }
+
       // for warmup does user set max duration, max iters, or both?
       // warmup is skipped by default, i.e. warmup_iter = 0
       // warmup will be set to the non-zero value of max_warmup_duration
diff --git a/docs/source/cuda.tunable.md b/docs/source/cuda.tunable.md
index 55c0b5ec9fd7..6d877e05397b 100644
--- a/docs/source/cuda.tunable.md
+++ b/docs/source/cuda.tunable.md
@@ -87,3 +87,7 @@
 ```{eval-rst}
 .. autofunction:: get_rotating_buffer_size
 ```
+
+```{eval-rst}
+.. autofunction:: set_numerical_check_tolerances
+```
\ No newline at end of file
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 3cee906a8c42..31ece7df7a79 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -148,7 +148,6 @@ class TestLinalg(TestCase):
         # loop through a list of potentially used
         # environment variables.
         env_list = ["PYTORCH_TUNABLEOP_BLAS_LOG",
-                    "PYTORCH_TUNABLEOP_NUMERICAL_CHECK",
                     "PYTORCH_TUNABLEOP_UNTUNED_FILENAME"]
         for env in env_list:
             try:
@@ -168,6 +167,7 @@ class TestLinalg(TestCase):
         torch.cuda.tunable.set_max_tuning_duration(30)
         torch.cuda.tunable.set_max_tuning_iterations(100)
         torch.cuda.tunable.set_rotating_buffer_size(-1)
+        torch.cuda.tunable.set_numerical_check_tolerances(False)
         ordinal = torch.cuda.current_device()

         # Set filenames to be unique on a per test basis
@@ -5144,7 +5144,6 @@ class TestLinalg(TestCase):
     @skipCUDAIfNotRocm
     @dtypes(torch.bfloat16)
     def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
-        import os
        from torch.testing._internal.common_utils import CudaMemoryLeakCheck
        # run operator first without tuning to ensure all rocm libs are loaded,
        # otherwise false positive mem leak
@@ -5157,8 +5156,8 @@ class TestLinalg(TestCase):

         with self._tunableop_ctx():
             torch.cuda.tunable.set_rotating_buffer_size(0)
-            # enable tunableop numeric check via env variable.
-            os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
+            # enable tunableop numeric check via API.
+            torch.cuda.tunable.set_numerical_check_tolerances(True, 0.1, 0.1)

             ordinal = torch.cuda.current_device()

@@ -6023,6 +6022,48 @@ class TestLinalg(TestCase):
         # There must be exactly three kernels only
         self.assertEqual(kernel_count, 3)

+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float16)
+    def test_numerical_check_python_binding_tunableop(self, device, dtype):
+        with self._tunableop_ctx():
+            torch.cuda.tunable.enable(True)
+            torch.cuda.tunable.set_numerical_check_tolerances(True)
+
+            a = torch.randn(128, 128, device='cuda', dtype=dtype)
+            b = torch.randn(128, 128, device='cuda', dtype=dtype)
+
+            _ = a @ b
+
+        with self._tunableop_ctx():
+            torch.cuda.tunable.enable(True)
+            with self.assertRaisesRegex(RuntimeError, r"positive"):
+                torch.cuda.tunable.set_numerical_check_tolerances(True, -1e-5, 1e5)
+            with self.assertRaisesRegex(RuntimeError, r"positive"):
+                torch.cuda.tunable.set_numerical_check_tolerances(True, 1e-5, -1e5)
+            with self.assertRaisesRegex(RuntimeError, r"positive"):
+                torch.cuda.tunable.set_numerical_check_tolerances(True, -1e-5, -1e5)
+
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float16, torch.float32)
+    def test_numerical_check_accuracy_tunableop(self, device, dtype):
+        shapes = [(127, 193, 61), (251, 317, 73), (89, 149, 41)]
+        atol, rtol = 1e-2, 1e-1
+
+        for (m, k, n) in shapes:
+            a = torch.randn(m, k, device='cuda', dtype=dtype)
+            b = torch.randn(k, n, device='cuda', dtype=dtype)
+            torch.cuda.tunable.enable(False)
+            torch.cuda.tunable.set_numerical_check_tolerances(False)
+            C_baseline = a @ b
+            with self._tunableop_ctx():
+                torch.cuda.tunable.enable(True)
+                torch.cuda.tunable.set_numerical_check_tolerances(True, atol, rtol)
+                C_numeric = a @ b
+            self.assertTrue(torch.allclose(C_baseline, C_numeric, atol=atol, rtol=rtol))
+
+
     @dtypes(torch.float, torch.complex64)
     def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
         a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 7f0f80e77a55..c7e2c608ab53 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2202,6 +2202,9 @@ def _cuda_tunableop_get_results() -> tuple[str, str, str, _float]: ...
 def _cuda_tunableop_get_validators() -> tuple[str, str]: ...
 def _cuda_tunableop_set_rotating_buffer_size(buffer_size: _int) -> None: ...
 def _cuda_tunableop_get_rotation_buffer_size() -> _int: ...
+def _cuda_tunableop_set_numerical_check_tolerances(
+    enable: _bool, atol: _float = 1e-5, rtol: _float = 1e-5
+) -> None: ...
class _CudaDeviceProperties:
    name: str

diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 41b8de8e78f6..0950192457d6 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1857,6 +1857,64 @@ PyObject* THCPModule_cuda_tunableop_get_rotating_buffer_size(
   END_HANDLE_TH_ERRORS
 }

+PyObject* THCPModule_cuda_tunableop_set_numerical_check_tolerances(
+    PyObject* unused,
+    PyObject* args) {
+  HANDLE_TH_ERRORS
+
+  PyObject* enabled_obj = nullptr;
+  PyObject* atol_obj = nullptr;
+  PyObject* rtol_obj = nullptr;
+
+  // Parse: required bool, optional float, optional float
+  if (!PyArg_ParseTuple(args, "O|OO", &enabled_obj, &atol_obj, &rtol_obj)) {
+    TORCH_CHECK(
+        false,
+        "cuda_tunableop_set_numerical_check_tolerances expects (bool[, float[, float]])");
+  }
+
+  TORCH_CHECK(
+      PyBool_Check(enabled_obj),
+      "First argument must be a boolean, got ",
+      THPUtils_typename(enabled_obj));
+
+  bool enabled = THPUtils_unpackBool(enabled_obj);
+
+  double atol = 1e-5;
+  double rtol = 1e-5;
+
+  if (atol_obj != nullptr) {
+    TORCH_CHECK(
+        PyFloat_Check(atol_obj),
+        "Second argument (atol) must be a float, got ",
+        THPUtils_typename(atol_obj));
+
+    atol = PyFloat_AsDouble(atol_obj);
+  }
+
+  if (rtol_obj != nullptr) {
+    TORCH_CHECK(
+        PyFloat_Check(rtol_obj),
+        "Third argument (rtol) must be a float, got ",
+        THPUtils_typename(rtol_obj));
+
+    rtol = PyFloat_AsDouble(rtol_obj);
+  }
+
+  TORCH_CHECK(
+      atol > 0.0 && rtol > 0.0,
+      "Numerical check tolerances must be positive. Got atol=",
+      atol,
+      ", rtol=",
+      rtol);
+
+  at::cuda::tunable::getTuningContext()->SetNumericalCheckConfig(
+      enabled, atol, rtol);
+
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* THCPModule_isCurrentStreamCapturing_wrap(
     PyObject* self,
     PyObject* noargs) {
@@ -2131,6 +2189,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cuda_tunableop_get_rotating_buffer_size,
      METH_NOARGS,
      nullptr},
+    {"_cuda_tunableop_set_numerical_check_tolerances",
+     THCPModule_cuda_tunableop_set_numerical_check_tolerances,
+     METH_VARARGS,
+     nullptr},
     {nullptr}};

 PyMethodDef* THCPModule_methods() {
diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py
index 6b99ea1f8cff..262c6870d400 100644
--- a/torch/cuda/tunable.py
+++ b/torch/cuda/tunable.py
@@ -211,6 +211,7 @@ __all__ = [
     "mgpu_tune_gemm_in_file",
     "set_rotating_buffer_size",
     "get_rotating_buffer_size",
+    "set_numerical_check_tolerances",
 ]


@@ -327,6 +328,13 @@ def get_rotating_buffer_size() -> int:
     return torch._C._cuda_tunableop_get_rotating_buffer_size()  # type: ignore[attr-defined]


+def set_numerical_check_tolerances(
+    enable: bool, atol: float = 1e-5, rtol: float = 1e-5
+) -> None:
+    r"""Enable or disable numerical checking and set its absolute and relative tolerances."""
+    return torch._C._cuda_tunableop_set_numerical_check_tolerances(enable, atol, rtol)  # type: ignore[attr-defined]
+
+
 def tune_gemm_in_file(filename: str) -> None:
     r"""tune GEMM in file."""
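Editorial aside (not part of the patch): an end-to-end sketch of the new
Python API. Shapes, dtypes, and tolerances are arbitrary examples; note that
a set PYTORCH_TUNABLEOP_NUMERICAL_CHECK environment variable takes precedence
over this programmatic configuration, per GetNumericalCheckConfig() above.

```python
import torch

# Drive numerical checking through the API instead of the env variable.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_numerical_check_tolerances(True, 1e-4, 1e-4)

a = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
c = a @ b  # candidates failing the allclose(atol, rtol) check are rejected

torch.cuda.tunable.set_numerical_check_tolerances(False)  # checking off again
```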