mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ROCm][tunableop] Improvements to tunableop Numerical Check (#163079)
Modified the flag PYTORCH_TUNABLEOP_NUMERICAL_CHECK, so that it accepts the numerical tolerances in the format atol_rtol as compared to the previous 0 and 1. Retains previous functionality with default values as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163079 Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
This commit is contained in:
committed by
PyTorch MergeBot
parent
e787d532b6
commit
66ea76ec44
@ -13,6 +13,7 @@
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
||||
#include <ATen/cuda/tunable/TunableOp.h>
|
||||
#include <ATen/cuda/tunable/Tunable.h>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
|
||||
BLASType = "unknown";
|
||||
}
|
||||
return BLASType;
|
||||
|
||||
}
|
||||
|
||||
// Similar to Compute Type in GemmRocblas.h
|
||||
@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio
|
||||
|
||||
namespace detail {
|
||||
|
||||
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
|
||||
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
|
||||
|
||||
if (!config.enabled) {
|
||||
return true; // skip when disabled
|
||||
}
|
||||
|
||||
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
|
||||
// comparison done as 1D tensor
|
||||
at::Tensor ref = at::from_blob(c, {size}, options);
|
||||
at::Tensor oth = at::from_blob(other_c, {size}, options);
|
||||
at::Tensor ref_float = ref.to(at::kFloat);
|
||||
at::Tensor oth_float = oth.to(at::kFloat);
|
||||
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
|
||||
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
|
||||
double last_succeed_atol = 1;
|
||||
double last_succeed_rtol = 1;
|
||||
for (auto& atol : atols) {
|
||||
for (auto& rtol : rtols) {
|
||||
if (at::allclose(ref_float, oth_float, rtol, atol)) {
|
||||
last_succeed_atol = atol;
|
||||
last_succeed_rtol = rtol;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_succeed_atol == 1) {
|
||||
return false;
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
|
||||
}
|
||||
|
||||
return true;
|
||||
const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
|
||||
if (ok) {
|
||||
TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
|
||||
} else {
|
||||
TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
|
||||
}
|
||||
return ok;
|
||||
}
|
||||
|
||||
}
|
||||
@ -355,8 +349,10 @@ struct GemmParams : OpParams {
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(GemmParams<T> *other) {
|
||||
auto* ctx = getTuningContext();
|
||||
auto cfg = ctx->GetNumericalCheckConfig();
|
||||
auto c_dtype = c10::CppTypeToScalarType<T>::value;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
||||
}
|
||||
|
||||
char transa{};
|
||||
@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
|
||||
auto* ctx = getTuningContext();
|
||||
auto cfg = ctx->GetNumericalCheckConfig();
|
||||
auto c_dtype = c10::CppTypeToScalarType<T>::value;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
||||
}
|
||||
|
||||
char transa{};
|
||||
@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
|
||||
auto* ctx = getTuningContext();
|
||||
auto cfg = ctx->GetNumericalCheckConfig();
|
||||
auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
||||
}
|
||||
|
||||
char transa{};
|
||||
@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
|
||||
}
|
||||
|
||||
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
|
||||
auto* ctx = getTuningContext();
|
||||
auto cfg = ctx->GetNumericalCheckConfig();
|
||||
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
|
||||
}
|
||||
|
||||
char transa{};
|
||||
|
@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
|
||||
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
|
||||
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
|
||||
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
|
||||
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
|
||||
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
|
||||
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
|
||||
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
|
||||
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
|
||||
@ -173,6 +173,7 @@ All python APIs exist in the `torch.cuda.tunable` module.
|
||||
| get_max_tuning_iterations() -> int | |
|
||||
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
|
||||
| get_filename() -> str | |
|
||||
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5.
|
||||
| get_results() -> Tuple[str, str, str, float] | |
|
||||
| get_validators() -> Tuple[str, str] | |
|
||||
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
|
||||
|
@ -590,12 +590,49 @@ void TuningContext::EnableNumericsCheck(bool value) {
|
||||
numerics_check_enable_ = value;
|
||||
}
|
||||
|
||||
bool TuningContext::IsNumericsCheckEnabled() const {
|
||||
const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
|
||||
if (env == "1") {
|
||||
return true;
|
||||
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
|
||||
const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
|
||||
|
||||
if (!env_opt.has_value()) {
|
||||
return numerics_cfg_;
|
||||
}
|
||||
return numerics_check_enable_;
|
||||
|
||||
const std::string& env = env_opt.value();
|
||||
|
||||
if (env == "0") {
|
||||
return NumericalCheckConfig(false, 1e-5, 1e-5);
|
||||
}
|
||||
|
||||
const size_t underscore = env.find('_');
|
||||
|
||||
TORCH_CHECK(
|
||||
underscore != std::string::npos,
|
||||
"Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
|
||||
"Expected 'atol_rtol', got: ",
|
||||
env);
|
||||
|
||||
double atol = 0.0;
|
||||
double rtol = 0.0;
|
||||
|
||||
try {
|
||||
atol = std::stod(env.substr(0, underscore));
|
||||
rtol = std::stod(env.substr(underscore + 1));
|
||||
} catch (const std::exception& e) {
|
||||
TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
|
||||
}
|
||||
|
||||
TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
|
||||
return NumericalCheckConfig(true, atol, rtol);
|
||||
}
|
||||
|
||||
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
|
||||
TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
|
||||
numerics_cfg_ = {enabled, atol, rtol};
|
||||
}
|
||||
|
||||
bool TuningContext::IsNumericsCheckEnabled() const {
|
||||
const auto cfg = GetNumericalCheckConfig();
|
||||
return cfg.enabled || numerics_check_enable_;
|
||||
}
|
||||
|
||||
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
|
||||
|
@ -148,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
|
||||
GetValidateFuncs validators_;
|
||||
};
|
||||
|
||||
struct NumericalCheckConfig {
|
||||
bool enabled{false};
|
||||
double atol{1e-5};
|
||||
double rtol{1e-5};
|
||||
|
||||
NumericalCheckConfig() = default;
|
||||
NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
|
||||
};
|
||||
|
||||
|
||||
class TORCH_CUDA_CPP_API TuningContext {
|
||||
public:
|
||||
TuningContext();
|
||||
@ -169,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
|
||||
void EnableNumericsCheck(bool value);
|
||||
bool IsNumericsCheckEnabled() const;
|
||||
void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
|
||||
NumericalCheckConfig GetNumericalCheckConfig() const;
|
||||
|
||||
void SetMaxTuningDurationMs(int max_duration_ms);
|
||||
int GetMaxTuningDurationMs() const;
|
||||
@ -232,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
|
||||
std::ofstream untuned_file_;
|
||||
size_t results_count_from_input_file_;
|
||||
bool is_shutting_down_;
|
||||
|
||||
NumericalCheckConfig numerics_cfg_{};
|
||||
};
|
||||
|
||||
TORCH_CUDA_CPP_API TuningContext* getTuningContext();
|
||||
|
@ -267,27 +267,10 @@ class TunableOp {
|
||||
for (size_t i = 0; i < op_names_.size(); i++) {
|
||||
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
|
||||
|
||||
if (do_numerics_check) {
|
||||
ParamsT* numerical_params = params->DeepCopy(false);
|
||||
auto status = candidate->Call(numerical_params);
|
||||
if (status != OK) {
|
||||
numerical_params->Delete();
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
status = reference_params->NumericalCheck(numerical_params);
|
||||
numerical_params->Delete();
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto status = candidate->Call(reusable_params[0]);
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
auto status = candidate->Call(reusable_params[0]);
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
|
||||
// collect a small profile
|
||||
@ -310,6 +293,22 @@ class TunableOp {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (do_numerics_check) {
|
||||
ParamsT* numerical_params = params->DeepCopy(false);
|
||||
auto status = candidate->Call(numerical_params);
|
||||
if (status != OK) {
|
||||
numerical_params->Delete();
|
||||
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
status = reference_params->NumericalCheck(numerical_params);
|
||||
numerical_params->Delete();
|
||||
if (status != OK) {
|
||||
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// for warmup does user set max duration, max iters, or both?
|
||||
// warmup is skipped by default, i.e. warmup_iter = 0
|
||||
// warmup will be set to the non-zero value of max_warmup_duration
|
||||
|
Reference in New Issue
Block a user