[ROCm][tunableop] Improvements to tunableop Numerical Check (#163079)

Modified the flag PYTORCH_TUNABLEOP_NUMERICAL_CHECK, so that it accepts the numerical tolerances in the format atol_rtol as compared to the previous 0 and 1. Retains previous functionality with default values as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163079
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
This commit is contained in:
Sarthak Tandon
2025-10-15 22:26:47 +00:00
committed by PyTorch MergeBot
parent e787d532b6
commit 66ea76ec44
10 changed files with 227 additions and 56 deletions

View File

@ -13,6 +13,7 @@
#include <c10/core/ScalarType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
BLASType = "unknown";
}
return BLASType;
}
// Similar to Compute Type in GemmRocblas.h
@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio
namespace detail {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
if (!config.enabled) {
return true; // skip when disabled
}
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {size}, options);
at::Tensor oth = at::from_blob(other_c, {size}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
if (last_succeed_atol == 1) {
return false;
}
else {
TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}
return true;
const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
if (ok) {
TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
} else {
TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
}
return ok;
}
}
@ -355,8 +349,10 @@ struct GemmParams : OpParams {
}
TuningStatus NumericalCheck(GemmParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
}
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
}
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
}
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};

View File

@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@ -173,6 +173,7 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5.
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |

View File

@ -590,12 +590,49 @@ void TuningContext::EnableNumericsCheck(bool value) {
numerics_check_enable_ = value;
}
bool TuningContext::IsNumericsCheckEnabled() const {
const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (env == "1") {
return true;
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (!env_opt.has_value()) {
return numerics_cfg_;
}
return numerics_check_enable_;
const std::string& env = env_opt.value();
if (env == "0") {
return NumericalCheckConfig(false, 1e-5, 1e-5);
}
const size_t underscore = env.find('_');
TORCH_CHECK(
underscore != std::string::npos,
"Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
"Expected 'atol_rtol', got: ",
env);
double atol = 0.0;
double rtol = 0.0;
try {
atol = std::stod(env.substr(0, underscore));
rtol = std::stod(env.substr(underscore + 1));
} catch (const std::exception& e) {
TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
}
TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
return NumericalCheckConfig(true, atol, rtol);
}
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
numerics_cfg_ = {enabled, atol, rtol};
}
bool TuningContext::IsNumericsCheckEnabled() const {
const auto cfg = GetNumericalCheckConfig();
return cfg.enabled || numerics_check_enable_;
}
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {

View File

@ -148,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
GetValidateFuncs validators_;
};
struct NumericalCheckConfig {
bool enabled{false};
double atol{1e-5};
double rtol{1e-5};
NumericalCheckConfig() = default;
NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
};
class TORCH_CUDA_CPP_API TuningContext {
public:
TuningContext();
@ -169,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {
void EnableNumericsCheck(bool value);
bool IsNumericsCheckEnabled() const;
void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
NumericalCheckConfig GetNumericalCheckConfig() const;
void SetMaxTuningDurationMs(int max_duration_ms);
int GetMaxTuningDurationMs() const;
@ -232,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
std::ofstream untuned_file_;
size_t results_count_from_input_file_;
bool is_shutting_down_;
NumericalCheckConfig numerics_cfg_{};
};
TORCH_CUDA_CPP_API TuningContext* getTuningContext();

View File

@ -267,27 +267,10 @@ class TunableOp {
for (size_t i = 0; i < op_names_.size(); i++) {
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
else {
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
// collect a small profile
@ -310,6 +293,22 @@ class TunableOp {
continue;
}
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
// for warmup does user set max duration, max iters, or both?
// warmup is skipped by default, i.e. warmup_iter = 0
// warmup will be set to the non-zero value of max_warmup_duration

View File

@ -87,3 +87,7 @@
```{eval-rst}
.. autofunction:: get_rotating_buffer_size
```
```{eval-rst}
.. autofunction:: set_numerical_check_tolerances
```

View File

@ -148,7 +148,6 @@ class TestLinalg(TestCase):
# loop through a list of potentially used
# environment variables.
env_list = ["PYTORCH_TUNABLEOP_BLAS_LOG",
"PYTORCH_TUNABLEOP_NUMERICAL_CHECK",
"PYTORCH_TUNABLEOP_UNTUNED_FILENAME"]
for env in env_list:
try:
@ -168,6 +167,7 @@ class TestLinalg(TestCase):
torch.cuda.tunable.set_max_tuning_duration(30)
torch.cuda.tunable.set_max_tuning_iterations(100)
torch.cuda.tunable.set_rotating_buffer_size(-1)
torch.cuda.tunable.set_numerical_check_tolerances(False)
ordinal = torch.cuda.current_device()
# Set filenames to be unique on a per test basis
@ -5144,7 +5144,6 @@ class TestLinalg(TestCase):
@skipCUDAIfNotRocm
@dtypes(torch.bfloat16)
def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
import os
from torch.testing._internal.common_utils import CudaMemoryLeakCheck
# run operator first without tuning to ensure all rocm libs are loaded,
# otherwise false positive mem leak
@ -5157,8 +5156,8 @@ class TestLinalg(TestCase):
with self._tunableop_ctx():
torch.cuda.tunable.set_rotating_buffer_size(0)
# enable tunableop numeric check via env variable.
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
# enable tunableop numeric check via API.
torch.cuda.tunable.set_numerical_check_tolerances(True, 0.1, 0.1)
ordinal = torch.cuda.current_device()
@ -6023,6 +6022,48 @@ class TestLinalg(TestCase):
# There must be exactly three kernels only
self.assertEqual(kernel_count, 3)
@onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.float16)
def test_numerical_check_python_binding_tunableop(self, device, dtype):
with self._tunableop_ctx():
torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_numerical_check_tolerances(True)
a = torch.randn(128, 128, device='cuda')
b = torch.randn(128, 128, device='cuda')
_ = a @ b
with self._tunableop_ctx():
torch.cuda.tunable.enable(True)
with self.assertRaisesRegex(RuntimeError, r"positive"):
torch.cuda.tunable.set_numerical_check_tolerances(True, -1e-5, 1e5)
with self.assertRaisesRegex(RuntimeError, r"positive"):
torch.cuda.tunable.set_numerical_check_tolerances(True, 1e-5, -1e5)
with self.assertRaisesRegex(RuntimeError, r"positive"):
torch.cuda.tunable.set_numerical_check_tolerances(True, -1e-5, -1e5)
@onlyCUDA
@skipCUDAIfNotRocm
@dtypes(torch.float16, torch.float32)
def test_numerical_check_accuracy_tunableop(self, device, dtype):
shapes = [(127, 193, 61), (251, 317, 73), (89, 149, 41)]
atol, rtol = 1e-2, 1e-1
for (m, k, n) in shapes:
a = torch.randn(m, k, device='cuda')
b = torch.randn(k, n, device='cuda')
torch.cuda.tunable.enable(False)
torch.cuda.tunable.set_numerical_check_tolerances(False)
C_baseline = a @ b
with self._tunableop_ctx():
torch.cuda.tunable.enable(True)
torch.cuda.tunable.set_numerical_check_tolerances(True, atol, rtol)
C_numeric = a @ b
self.assertTrue(torch.allclose(C_baseline, C_numeric, atol=atol, rtol=rtol))
@dtypes(torch.float, torch.complex64)
def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)

View File

@ -2202,6 +2202,9 @@ def _cuda_tunableop_get_results() -> tuple[str, str, str, _float]: ...
def _cuda_tunableop_get_validators() -> tuple[str, str]: ...
def _cuda_tunableop_set_rotating_buffer_size(buffer_size: _int) -> None: ...
def _cuda_tunableop_get_rotation_buffer_size() -> _int: ...
def _cuda_tunableop_set_numerical_check_tolerances(
enabled: _bool, atol: _float = 1e-5, rtol: _float = 1e-5
) -> None: ...
class _CudaDeviceProperties:
name: str

View File

@ -1857,6 +1857,64 @@ PyObject* THCPModule_cuda_tunableop_get_rotating_buffer_size(
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_cuda_tunableop_set_numerical_check_tolerances(
PyObject* unused,
PyObject* args) {
HANDLE_TH_ERRORS
PyObject* enabled_obj;
PyObject* atol_obj = NULL;
PyObject* rtol_obj = NULL;
// Parse: required bool, optional float, optional float
if (!PyArg_ParseTuple(args, "O|OO", &enabled_obj, &atol_obj, &rtol_obj)) {
TORCH_CHECK(
false,
"cuda_tunableop_set_numerical_check_tolerances expects (bool[, float[, float]])");
}
TORCH_CHECK(
PyBool_Check(enabled_obj),
"First argument must be a boolean, got ",
THPUtils_typename(enabled_obj));
bool enabled = THPUtils_unpackBool(enabled_obj);
double atol = 1e-5;
double rtol = 1e-5;
if (atol_obj != NULL) {
TORCH_CHECK(
PyFloat_Check(atol_obj),
"Second argument (atol) must be a float, got ",
THPUtils_typename(atol_obj));
atol = PyFloat_AsDouble(atol_obj);
}
if (rtol_obj != NULL) {
TORCH_CHECK(
PyFloat_Check(rtol_obj),
"Third argument (rtol) must be a float, got ",
THPUtils_typename(rtol_obj));
rtol = PyFloat_AsDouble(rtol_obj);
}
TORCH_CHECK(
atol > 0.0 && rtol > 0.0,
"Numerical check tolerances must be positive. Got atol=",
atol,
", rtol=",
rtol);
at::cuda::tunable::getTuningContext()->SetNumericalCheckConfig(
enabled, atol, rtol);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* THCPModule_isCurrentStreamCapturing_wrap(
PyObject* self,
PyObject* noargs) {
@ -2131,6 +2189,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_cuda_tunableop_get_rotating_buffer_size,
METH_NOARGS,
nullptr},
{"_cuda_tunableop_set_numerical_check_tolerances",
THCPModule_cuda_tunableop_set_numerical_check_tolerances,
METH_VARARGS,
nullptr},
{nullptr}};
PyMethodDef* THCPModule_methods() {

View File

@ -211,6 +211,7 @@ __all__ = [
"mgpu_tune_gemm_in_file",
"set_rotating_buffer_size",
"get_rotating_buffer_size",
"set_numerical_check_tolerances",
]
@ -327,6 +328,13 @@ def get_rotating_buffer_size() -> int:
return torch._C._cuda_tunableop_get_rotating_buffer_size() # type: ignore[attr-defined]
def set_numerical_check_tolerances(
enable: bool, atol: float = 1e-5, rtol: float = 1e-5
) -> None:
r"""Set the atol and rtol values in numeric check"""
return torch._C._cuda_tunableop_set_numerical_check_tolerances(enable, atol, rtol) # type: ignore[attr-defined]
def tune_gemm_in_file(filename: str) -> None:
r"""tune GEMM in file."""