[ROCm][tunableop] Improvements to tunableop Numerical Check (#163079)
Modified the flag PYTORCH_TUNABLEOP_NUMERICAL_CHECK so that it accepts numerical tolerances in the format atol_rtol, replacing the previous 0/1 switch. Previous behavior is retained via the default tolerance values.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163079
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
committed by PyTorch MergeBot
parent e787d532b6
commit 66ea76ec44
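For illustration, the new environment-variable format can be exercised as follows (a usage sketch with arbitrary tolerance values, not part of the diff):

```python
import os

# New format: "atol_rtol" enables the numerical check with explicit tolerances.
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-4_1e-3"

# Setting "0" still disables the check, as before.
# os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "0"
```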
@@ -13,6 +13,7 @@
 #include <c10/core/ScalarType.h>
 
 #include <ATen/cuda/tunable/TunableOp.h>
+#include <ATen/cuda/tunable/Tunable.h>
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
 #include <c10/util/StringUtil.h>
@@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
     BLASType = "unknown";
   }
   return BLASType;
+
 }
 
 // Similar to Compute Type in GemmRocblas.h
@@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivationEpilogue& epilogue) {
 
 namespace detail {
 
-static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
+static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
+
+  if (!config.enabled) {
+    return true; // skip when disabled
+  }
+
   auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
-  // comparison done as 1D tensor
   at::Tensor ref = at::from_blob(c, {size}, options);
   at::Tensor oth = at::from_blob(other_c, {size}, options);
   at::Tensor ref_float = ref.to(at::kFloat);
   at::Tensor oth_float = oth.to(at::kFloat);
-  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
-  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
-  double last_succeed_atol = 1;
-  double last_succeed_rtol = 1;
-  for (auto& atol : atols) {
-    for (auto& rtol : rtols) {
-      if (at::allclose(ref_float, oth_float, rtol, atol)) {
-        last_succeed_atol = atol;
-        last_succeed_rtol = rtol;
-      }
-    }
-  }
-  if (last_succeed_atol == 1) {
-    return false;
-  }
-  else {
-    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
-  }
 
-  return true;
+  const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
+  if (ok) {
+    TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
+  } else {
+    TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
+  }
+  return ok;
 }
 
 }
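As a rough Python sketch of what the rewritten detail::NumericalCheck now does (the shipped check runs in C++ over the raw output buffers; this is only an illustration): both outputs are viewed as flat tensors, upcast to float32, and compared once against the configured tolerances instead of sweeping a grid of tolerance pairs.

```python
import torch

def numerical_check_sketch(ref: torch.Tensor, oth: torch.Tensor,
                           atol: float, rtol: float, enabled: bool = True) -> bool:
    if not enabled:
        return True  # skip when disabled, mirroring config.enabled
    # comparison done on flattened float32 views of both outputs
    ref_f = ref.reshape(-1).float()
    oth_f = oth.reshape(-1).float()
    # allclose tests |ref - oth| <= atol + rtol * |oth| elementwise
    return torch.allclose(ref_f, oth_f, rtol=rtol, atol=atol)
```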
@@ -355,8 +349,10 @@ struct GemmParams : OpParams {
   }
 
   TuningStatus NumericalCheck(GemmParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<T>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }
 
   char transa{};
@@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
   }
 
   TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<T>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }
 
   char transa{};
@@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
   }
 
   TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
     auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }
 
   char transa{};
@@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
   }
 
   TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
-    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
+    auto* ctx = getTuningContext();
+    auto cfg = ctx->GetNumericalCheckConfig();
+    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
   }
 
   char transa{};
@@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs instead.
 | PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
 | PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
 | PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
-| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
+| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
 | PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
 | PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
 | PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@@ -173,6 +173,7 @@ All python APIs exist in the `torch.cuda.tunable` module.
 | get_max_tuning_iterations() -> int | |
 | set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
 | get_filename() -> str | |
+| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5. |
 | get_results() -> Tuple[str, str, str, float] | |
 | get_validators() -> Tuple[str, str] | |
 | read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
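A minimal usage sketch of the API documented above (values are illustrative). Note one behavioral detail from the implementation: when PYTORCH_TUNABLEOP_NUMERICAL_CHECK is set in the environment, it takes precedence over tolerances stored through this API (see GetNumericalCheckConfig in the Tunable.cpp hunk below).

```python
import torch

torch.cuda.tunable.enable(True)
# enable the check with the 1e-5 defaults...
torch.cuda.tunable.set_numerical_check_tolerances(True)
# ...or with explicit tolerances
torch.cuda.tunable.set_numerical_check_tolerances(True, atol=1e-4, rtol=1e-3)

a = torch.randn(512, 512, device="cuda")
b = torch.randn(512, 512, device="cuda")
c = a @ b  # candidates that fail the numerics check are skipped during tuning
```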
@@ -590,12 +590,49 @@ void TuningContext::EnableNumericsCheck(bool value) {
   numerics_check_enable_ = value;
 }
 
-bool TuningContext::IsNumericsCheckEnabled() const {
-  const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
-  if (env == "1") {
-    return true;
+NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
+  const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
+  if (!env_opt.has_value()) {
+    return numerics_cfg_;
   }
-  return numerics_check_enable_;
+  const std::string& env = env_opt.value();
+
+  if (env == "0") {
+    return NumericalCheckConfig(false, 1e-5, 1e-5);
+  }
+
+  const size_t underscore = env.find('_');
+
+  TORCH_CHECK(
+      underscore != std::string::npos,
+      "Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
+      "Expected 'atol_rtol', got: ",
+      env);
+
+  double atol = 0.0;
+  double rtol = 0.0;
+
+  try {
+    atol = std::stod(env.substr(0, underscore));
+    rtol = std::stod(env.substr(underscore + 1));
+  } catch (const std::exception& e) {
+    TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
+  }
+
+  TORCH_CHECK(
+      atol > 0.0 && rtol > 0.0,
+      "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
+  return NumericalCheckConfig(true, atol, rtol);
+}
+
+void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
+  TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
+  numerics_cfg_ = {enabled, atol, rtol};
+}
+
+bool TuningContext::IsNumericsCheckEnabled() const {
+  const auto cfg = GetNumericalCheckConfig();
+  return cfg.enabled || numerics_check_enable_;
 }
 
 void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
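For readers who prefer it, here is a Python sketch of the same parsing rules (an illustration, not the shipped parser; edge-case behavior of float() and std::stod differs slightly):

```python
from typing import Optional, Tuple

def parse_numerical_check(env: Optional[str]) -> Optional[Tuple[bool, float, float]]:
    """Sketch of GetNumericalCheckConfig's env handling: (enabled, atol, rtol)."""
    if env is None:
        # unset: fall back to the config stored via SetNumericalCheckConfig
        return None
    if env == "0":
        return (False, 1e-5, 1e-5)  # explicit disable
    head, sep, tail = env.partition("_")
    if not sep:
        raise ValueError(f"Expected 'atol_rtol', got: {env}")
    atol, rtol = float(head), float(tail)
    if atol <= 0.0 or rtol <= 0.0:
        raise ValueError(f"Tolerance values must be positive. atol={atol}, rtol={rtol}")
    return (True, atol, rtol)

assert parse_numerical_check("1e-4_1e-3") == (True, 1e-4, 1e-3)
assert parse_numerical_check("0") == (False, 1e-5, 1e-5)
```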
@@ -148,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
   GetValidateFuncs validators_;
 };
 
+struct NumericalCheckConfig {
+  bool enabled{false};
+  double atol{1e-5};
+  double rtol{1e-5};
+
+  NumericalCheckConfig() = default;
+  NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
+};
+
+
 class TORCH_CUDA_CPP_API TuningContext {
  public:
   TuningContext();
@@ -169,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {
 
   void EnableNumericsCheck(bool value);
   bool IsNumericsCheckEnabled() const;
+  void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
+  NumericalCheckConfig GetNumericalCheckConfig() const;
 
   void SetMaxTuningDurationMs(int max_duration_ms);
   int GetMaxTuningDurationMs() const;
@@ -232,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
   std::ofstream untuned_file_;
   size_t results_count_from_input_file_;
   bool is_shutting_down_;
+
+  NumericalCheckConfig numerics_cfg_{};
 };
 
 TORCH_CUDA_CPP_API TuningContext* getTuningContext();
@@ -267,27 +267,10 @@ class TunableOp {
     for (size_t i = 0; i < op_names_.size(); i++) {
       auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
 
-      if (do_numerics_check) {
-        ParamsT* numerical_params = params->DeepCopy(false);
-        auto status = candidate->Call(numerical_params);
-        if (status != OK) {
-          numerical_params->Delete();
-          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
-          continue;
-        }
-        status = reference_params->NumericalCheck(numerical_params);
-        numerical_params->Delete();
-        if (status != OK) {
-          TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
-          continue;
-        }
-      }
-      else {
-        auto status = candidate->Call(reusable_params[0]);
-        if (status != OK) {
-          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
-          continue;
-        }
+      auto status = candidate->Call(reusable_params[0]);
+      if (status != OK) {
+        TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+        continue;
       }
 
       // collect a small profile
@@ -310,6 +293,22 @@ class TunableOp {
         continue;
       }
 
+      if (do_numerics_check) {
+        ParamsT* numerical_params = params->DeepCopy(false);
+        auto status = candidate->Call(numerical_params);
+        if (status != OK) {
+          numerical_params->Delete();
+          TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+          continue;
+        }
+        status = reference_params->NumericalCheck(numerical_params);
+        numerical_params->Delete();
+        if (status != OK) {
+          TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
+          continue;
+        }
+      }
+
       // for warmup does user set max duration, max iters, or both?
       // warmup is skipped by default, i.e. warmup_iter = 0
       // warmup will be set to the non-zero value of max_warmup_duration
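Taken together, the two TunableOp.h hunks reorder the candidate loop: the plain support check now always runs first on the reusable parameters, and the numerics check, on its own deep copy, runs afterwards and only for candidates that passed. A Python-flavored paraphrase of the new flow, with names mirroring the C++ (a pseudocode-level sketch, not the implementation):

```python
def find_fastest_sketch(op_names, ops, params, reusable_params,
                        reference_params, do_numerics_check, OK="OK"):
    for i, name in enumerate(op_names):
        candidate = ops[name]
        # 1) support check always runs first, on the reusable params
        if candidate.call(reusable_params[0]) != OK:
            continue  # unsupported
        # 2) ... a small profile is collected here ...
        # 3) numerics check only for supported candidates, on a dedicated copy
        if do_numerics_check:
            numerical_params = params.deep_copy(False)
            if candidate.call(numerical_params) != OK:
                numerical_params.delete()
                continue  # unsupported
            status = reference_params.numerical_check(numerical_params)
            numerical_params.delete()
            if status != OK:
                continue  # numerics check failed
        # ... warmup and timing decisions follow ...
```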
@@ -87,3 +87,7 @@
 ```{eval-rst}
 .. autofunction:: get_rotating_buffer_size
 ```
+
+```{eval-rst}
+.. autofunction:: set_numerical_check_tolerances
+```
@@ -148,7 +148,6 @@ class TestLinalg(TestCase):
         # loop through a list of potentially used
         # environment variables.
         env_list = ["PYTORCH_TUNABLEOP_BLAS_LOG",
-                    "PYTORCH_TUNABLEOP_NUMERICAL_CHECK",
                     "PYTORCH_TUNABLEOP_UNTUNED_FILENAME"]
         for env in env_list:
             try:
@@ -168,6 +167,7 @@ class TestLinalg(TestCase):
         torch.cuda.tunable.set_max_tuning_duration(30)
         torch.cuda.tunable.set_max_tuning_iterations(100)
         torch.cuda.tunable.set_rotating_buffer_size(-1)
+        torch.cuda.tunable.set_numerical_check_tolerances(False)
         ordinal = torch.cuda.current_device()
 
         # Set filenames to be unique on a per test basis
@@ -5144,7 +5144,6 @@ class TestLinalg(TestCase):
     @skipCUDAIfNotRocm
     @dtypes(torch.bfloat16)
     def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
-        import os
         from torch.testing._internal.common_utils import CudaMemoryLeakCheck
         # run operator first without tuning to ensure all rocm libs are loaded,
         # otherwise false positive mem leak
@@ -5157,8 +5156,8 @@ class TestLinalg(TestCase):
 
         with self._tunableop_ctx():
             torch.cuda.tunable.set_rotating_buffer_size(0)
-            # enable tunableop numeric check via env variable.
-            os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1"
+            # enable tunableop numeric check via API.
+            torch.cuda.tunable.set_numerical_check_tolerances(True, 0.1, 0.1)
 
             ordinal = torch.cuda.current_device()
 
@@ -6023,6 +6022,48 @@ class TestLinalg(TestCase):
         # There must be exactly three kernels only
         self.assertEqual(kernel_count, 3)
 
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float16)
+    def test_numerical_check_python_binding_tunableop(self, device, dtype):
+        with self._tunableop_ctx():
+            torch.cuda.tunable.enable(True)
+            torch.cuda.tunable.set_numerical_check_tolerances(True)
+
+            a = torch.randn(128, 128, device='cuda')
+            b = torch.randn(128, 128, device='cuda')
+
+            _ = a @ b
+
+        with self._tunableop_ctx():
+            torch.cuda.tunable.enable(True)
+            with self.assertRaisesRegex(RuntimeError, r"positive"):
+                torch.cuda.tunable.set_numerical_check_tolerances(True, -1e-5, 1e5)
+            with self.assertRaisesRegex(RuntimeError, r"positive"):
+                torch.cuda.tunable.set_numerical_check_tolerances(True, 1e-5, -1e5)
+            with self.assertRaisesRegex(RuntimeError, r"positive"):
+                torch.cuda.tunable.set_numerical_check_tolerances(True, -1e-5, -1e5)
+
+    @onlyCUDA
+    @skipCUDAIfNotRocm
+    @dtypes(torch.float16, torch.float32)
+    def test_numerical_check_accuracy_tunableop(self, device, dtype):
+        shapes = [(127, 193, 61), (251, 317, 73), (89, 149, 41)]
+        atol, rtol = 1e-2, 1e-1
+
+        for (m, k, n) in shapes:
+            a = torch.randn(m, k, device='cuda')
+            b = torch.randn(k, n, device='cuda')
+            torch.cuda.tunable.enable(False)
+            torch.cuda.tunable.set_numerical_check_tolerances(False)
+            C_baseline = a @ b
+            with self._tunableop_ctx():
+                torch.cuda.tunable.enable(True)
+                torch.cuda.tunable.set_numerical_check_tolerances(True, atol, rtol)
+                C_numeric = a @ b
+            self.assertTrue(torch.allclose(C_baseline, C_numeric, atol=atol, rtol=rtol))
+
+
     @dtypes(torch.float, torch.complex64)
     def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
         a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
@@ -2202,6 +2202,9 @@ def _cuda_tunableop_get_results() -> tuple[str, str, str, _float]: ...
 def _cuda_tunableop_get_validators() -> tuple[str, str]: ...
 def _cuda_tunableop_set_rotating_buffer_size(buffer_size: _int) -> None: ...
 def _cuda_tunableop_get_rotation_buffer_size() -> _int: ...
+def _cuda_tunableop_set_numerical_check_tolerances(
+    enabled: _bool, atol: _float = 1e-5, rtol: _float = 1e-5
+) -> None: ...
 
 class _CudaDeviceProperties:
     name: str
@@ -1857,6 +1857,64 @@ PyObject* THCPModule_cuda_tunableop_get_rotating_buffer_size(
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THCPModule_cuda_tunableop_set_numerical_check_tolerances(
+    PyObject* unused,
+    PyObject* args) {
+  HANDLE_TH_ERRORS
+
+  PyObject* enabled_obj;
+  PyObject* atol_obj = NULL;
+  PyObject* rtol_obj = NULL;
+
+  // Parse: required bool, optional float, optional float
+  if (!PyArg_ParseTuple(args, "O|OO", &enabled_obj, &atol_obj, &rtol_obj)) {
+    TORCH_CHECK(
+        false,
+        "cuda_tunableop_set_numerical_check_tolerances expects (bool[, float[, float]])");
+  }
+
+  TORCH_CHECK(
+      PyBool_Check(enabled_obj),
+      "First argument must be a boolean, got ",
+      THPUtils_typename(enabled_obj));
+
+  bool enabled = THPUtils_unpackBool(enabled_obj);
+
+  double atol = 1e-5;
+  double rtol = 1e-5;
+
+  if (atol_obj != NULL) {
+    TORCH_CHECK(
+        PyFloat_Check(atol_obj),
+        "Second argument (atol) must be a float, got ",
+        THPUtils_typename(atol_obj));
+
+    atol = PyFloat_AsDouble(atol_obj);
+  }
+
+  if (rtol_obj != NULL) {
+    TORCH_CHECK(
+        PyFloat_Check(rtol_obj),
+        "Third argument (rtol) must be a float, got ",
+        THPUtils_typename(rtol_obj));
+
+    rtol = PyFloat_AsDouble(rtol_obj);
+  }
+
+  TORCH_CHECK(
+      atol > 0.0 && rtol > 0.0,
+      "Numerical check tolerances must be positive. Got atol=",
+      atol,
+      ", rtol=",
+      rtol);
+
+  at::cuda::tunable::getTuningContext()->SetNumericalCheckConfig(
+      enabled, atol, rtol);
+
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* THCPModule_isCurrentStreamCapturing_wrap(
     PyObject* self,
     PyObject* noargs) {
@@ -2131,6 +2189,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cuda_tunableop_get_rotating_buffer_size,
      METH_NOARGS,
      nullptr},
+    {"_cuda_tunableop_set_numerical_check_tolerances",
+     THCPModule_cuda_tunableop_set_numerical_check_tolerances,
+     METH_VARARGS,
+     nullptr},
     {nullptr}};
 
 PyMethodDef* THCPModule_methods() {
@@ -211,6 +211,7 @@ __all__ = [
     "mgpu_tune_gemm_in_file",
     "set_rotating_buffer_size",
     "get_rotating_buffer_size",
+    "set_numerical_check_tolerances",
 ]
 
 
@@ -327,6 +328,13 @@ def get_rotating_buffer_size() -> int:
     return torch._C._cuda_tunableop_get_rotating_buffer_size()  # type: ignore[attr-defined]
 
 
+def set_numerical_check_tolerances(
+    enable: bool, atol: float = 1e-5, rtol: float = 1e-5
+) -> None:
+    r"""Set the atol and rtol values in numeric check"""
+    return torch._C._cuda_tunableop_set_numerical_check_tolerances(enable, atol, rtol)  # type: ignore[attr-defined]
+
+
 def tune_gemm_in_file(filename: str) -> None:
     r"""tune GEMM in file."""