Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
refine fp32 precision api (#125888)
Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop "highest, high, medium" as the way to describe the fp32 internal computation data types. Instead, we will name the algorithm directly.

### Design Choice: Directly use algorithm names like "TF32", "BF16"

#### Pros
- The names are more informative: "tf32" says more than a generic "high".
- It is easier to extend to new algorithms such as `tf32x3`.

#### Cons
- "HIGHEST, HIGH, MEDIUM" indicated the relative precision of the different algorithms. However, we can add documentation to discuss this.

### We provide a layered structure for backends/operators ('f32' is short for 'fp32_precision')

### We provide the following fp32 compute precision settings:
- **"ieee"**: no other internal computation data type may be used.
- **"tf32"**: tf32 may be used as the internal computation data type.
- **"bf16"**: bf16 may be used as the internal computation data type.
- **"none"**: the precision is not set and can be overridden by the parent node.

### Overriding Precision Settings
A child node is overridden by its parent node if the child is left at the default. The current default settings are:
```
backend = generic, op = all, precision setting = none
backend = cuda, op = all, precision setting = none
backend = cuda, op = conv, precision setting = tf32
backend = cuda, op = rnn, precision setting = tf32
backend = cuda, op = matmul, precision setting = none
backend = mkldnn, op = all, precision setting = none
backend = mkldnn, op = conv, precision setting = none
backend = mkldnn, op = rnn, precision setting = none
backend = mkldnn, op = matmul, precision setting = none
```
- If the user sets `torch.backends.mkldnn.fp32_precision="bf16"`, its child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` are also overridden to "bf16".
- If the user sets `torch.backends.fp32_precision="bf16"`, `torch.backends.mkldnn.fp32_precision` and its child nodes are also overridden to "bf16".

### Backward Compatibility
Since the new API allows more fine-grained control, there can be conflicts. For example, the previous `torch.backends.cudnn.allow_tf32` flag cannot represent the state `torch.backends.cudnn.rnn.fp32_precision="ieee"` combined with `torch.backends.cudnn.conv.fp32_precision="tf32"`. Therefore, our goals for backward compatibility are:
- If the user only uses the previous APIs, they behave as before.
- If the user uses the **new** API to move into a state that is **unrepresentable** for the old API and then queries that state through the **old** API, we raise a RuntimeError and point the user to the documentation.

### Test Plan
```
python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888
Approved by: https://github.com/jgong5, https://github.com/albanD

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
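A minimal sketch of the intended usage, based on the settings described above (all attribute names come from this PR; the inherited value assumes the per-operator setting is still at its default "none"):

```python
import torch

# Per-backend / per-operator control with the new API.
torch.backends.cuda.matmul.fp32_precision = "tf32"   # allow TF32 for cuBLAS matmul
torch.backends.cudnn.conv.fp32_precision = "tf32"    # allow TF32 for cuDNN conv
torch.backends.cudnn.rnn.fp32_precision = "ieee"     # keep cuDNN RNNs in plain fp32

# A parent setting is inherited by children that are still at the default "none".
torch.backends.mkldnn.fp32_precision = "bf16"
print(torch.backends.mkldnn.matmul.fp32_precision)   # "bf16", inherited from the parent
```

The legacy flags (`torch.backends.cuda.matmul.allow_tf32`, `torch.set_float32_matmul_precision`) keep working on their own, but mixing them with the new API can produce the unrepresentable states described in the Backward Compatibility section above.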
Committed by: PyTorch MergeBot
Parent: de45c5f673
Commit: 53e0b9c393
@@ -19,9 +19,69 @@
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <cpuinfo.h>
#endif

namespace at {

namespace {

/*
  These const variables define the fp32 precisions for the different backends.
  We currently have the "generic", "cuda" and "mkldnn" backends, and the fp32
  precision can be chosen from "ieee", "tf32", "bf16" and "none". The "ieee"
  precision means the IEEE standard floating point format; "tf32" and "bf16"
  mean we are allowed to use tf32 or bf16 as internal computation data types
  for fp32 computations. And "none" means the setting is overridable by the
  parent node.

  generic->mkldnn->matmul
                 ->conv
                 ->rnn
         ->cuda  ->matmul
                 ->conv
                 ->rnn
*/
const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
    {"generic", {{"ieee", "tf32", "bf16", "none"}}},
    {"mkldnn", {{"ieee", "bf16", "none"}}},
    {"cuda", {{"ieee", "tf32", "none"}}}};

// Check whether the backend and op are legal
void check_fp32_prec_backend_and_op(
    const std::string& backend,
    const std::string& op) {
  static std::vector<std::string> backends = {"generic", "mkldnn", "cuda"};
  static std::vector<std::string> operators = {"conv", "matmul", "rnn", "all"};
  TORCH_CHECK(
      std::find(backends.begin(), backends.end(), backend) != backends.end(),
      "Invalid backend: ",
      backend);
  TORCH_CHECK(
      std::find(operators.begin(), operators.end(), op) != operators.end(),
      "Invalid operator: ",
      op);
  if (backend == "generic") {
    TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op);
  }
}

// Return whether the precision is supported by the backend
bool validate_fp32_prec(
    const std::string& backend,
    const std::string& precision) {
  auto iterp = _fp32_precisions.find(backend);
  TORCH_CHECK(iterp != _fp32_precisions.end());
  auto precisions = iterp->second;
  bool valid = std::find(precisions.begin(), precisions.end(), precision) !=
      precisions.end();
  return valid;
}

C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
  TORCH_WARN_ONCE(
      "This API is going to be deprecated, please see "
      "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
  );
}
} // namespace

Context::Context() = default;

// TODO: This could be bad juju if someone calls globalContext() in the
@@ -115,12 +175,29 @@ void Context::setUserEnabledNNPACK(bool e) {
  enabled_nnpack = e;
}

bool Context::allowTF32CuDNN() const {
bool Context::allowTF32CuDNN(const std::string& op) const {
  if (op.size() == 0){
    bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
    bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
    TORCH_CHECK(
        allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn,
        "PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,",
        "but the current flag(s) indicate that cuDNN conv and cuDNN RNN have different TF32 flags.",
        "This combination indicates that you have used a mix of the legacy and new APIs to set the TF32 flags. ",
        "We suggest only using the new API to set the TF32 flag(s). See also: ",
        "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
  } else {
    return float32Precision("cuda", op) == "tf32";
  }
  warn_deprecated_fp32_precision_api();
  return allow_tf32_cudnn;
}

void Context::setAllowTF32CuDNN(bool b) {
  setFloat32Precision("cuda", "rnn", b ? "tf32" : "none");
  setFloat32Precision("cuda", "conv", b ? "tf32" : "none");
  allow_tf32_cudnn = b;
  warn_deprecated_fp32_precision_api();
}

void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {

@@ -259,7 +336,16 @@ bool Context::allowTF32CuBLAS() const {
    return false;
  }
#endif
  return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
  bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
  bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
  TORCH_CHECK(
      legacy_allow_tf32 == allow_tf32_new,
      "PyTorch is checking whether allow_tf32_new is enabled for cuBlas matmul,",
      "Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
      "We suggest only using the new API to set the TF32 flag. See also: ",
      "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
  warn_deprecated_fp32_precision_api();
  return allow_tf32_new;
}

void Context::setAllowTF32CuBLAS(bool b) {

@@ -272,27 +358,54 @@ void Context::setAllowTF32CuBLAS(bool b) {
  }
#endif
  float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
  setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
}

Float32MatmulPrecision Context::float32MatmulPrecision() const {
  bool invalid = float32Precision("cuda", "matmul") == "tf32" &&
      float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST;
  invalid = invalid ||
      (float32Precision("mkldnn", "matmul") == "bf16" &&
       float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
  TORCH_CHECK(
      !invalid,
      "PyTorch is checking the matmul precision without a specific backend name,",
      "Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
      "We suggest only using the new API for matmul precision. See also: ",
      "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
  warn_deprecated_fp32_precision_api();
  return float32_matmul_precision;
}

void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
  float32_matmul_precision = p;
std::string Context::float32Precision(const std::string& backend, const std::string& op) const {
  check_fp32_prec_backend_and_op(backend, op);
  auto precision = fp32_precision.find(backend)->second.find(op)->second;
  if (precision == "none")
    precision = fp32_precision.find(backend)->second.find("all")->second;
  if (precision == "none")
    precision = fp32_precision.find("generic")->second.find("all")->second;
  bool valid_prec = validate_fp32_prec(backend, precision);
  return valid_prec ? precision : "none";
}

void Context::setFloat32MatmulPrecision(const std::string &s) {
  auto match = [this](const std::string & s_) {
    warn_deprecated_fp32_precision_api();
    // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
    if (s_ == "highest") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
      setFloat32Precision("cuda", "matmul", "ieee");
      setFloat32Precision("mkldnn", "matmul", "ieee");
      return true;
    } else if (s_ == "high") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
      setFloat32Precision("cuda", "matmul", "tf32");
      setFloat32Precision("mkldnn", "matmul", "ieee");
      return true;
    } else if (s_ == "medium") {
      float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
      setFloat32Precision("cuda", "matmul", "tf32");
      setFloat32Precision("mkldnn", "matmul", "bf16");
      return true;
    }
    return false;

@@ -306,6 +419,27 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
      "setFloat32MatmulPrecision call has no effect.");
}

void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) {
  check_fp32_prec_backend_and_op(backend, op);
  if (validate_fp32_prec(backend, p)) {
    fp32_precision[backend][op] = p;
  } else {
    std::string msg;
    auto iterp = _fp32_precisions.find(backend);
    TORCH_CHECK(iterp != _fp32_precisions.end());
    for (auto p : iterp->second) {
      msg += p;
      msg += " ";
    }
    TORCH_WARN(
        "you have set wrong precision for backend:",
        backend,
        " setFloat32Precision call has no effect.",
        "Please choose precision from: ",
        msg);
  }
}

at::LinalgBackend Context::linalgPreferredBackend() const {
  return linalg_preferred_backend;
}
@@ -28,6 +28,7 @@
#include <c10/util/irange.h>

#include <cstdint>
#include <map>
#include <mutex>

namespace at {

@@ -336,14 +337,20 @@ class TORCH_API Context {
  void alertCuBLASConfigNotDeterministic() const;

  void setFloat32MatmulPrecision(const std::string& s);
  bool allowTF32CuDNN() const;
  void setFloat32Precision(
      const std::string& backend,
      const std::string& op,
      const std::string& s);
  bool allowTF32CuDNN(const std::string& op = std::string()) const;
  void setAllowTF32CuDNN(bool);
  bool allowTF32OneDNN() const;
  void setAllowTF32OneDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  Float32MatmulPrecision float32MatmulPrecision() const;
  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  std::string float32Precision(
      const std::string& backend,
      const std::string& op) const;
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  bool allowBF16ReductionCuBLAS() const;

@@ -469,6 +476,23 @@ class TORCH_API Context {
  bool enable_sparse_tensor_invariant_checks = false;
  bool allow_fp16_reduction_cpu = false;

  std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
      {"generic", {{"all", "none"}}},
      {"mkldnn",
       {{"matmul", "none"},
        {"conv", "none"},
        {"rnn", "none"},
        {"all", "none"}}},
      {"cuda",
       {{"matmul",
         float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
             ? "none"
             : "tf32"},
        {"conv", "tf32"},
        {"rnn", "tf32"},
        {"all", "none"}}},
  };

  Allocator* prev_allocator_ptr_{nullptr};
};
@@ -407,7 +407,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
    if (at::globalContext().allowTF32CuBLAS()) {
    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
  } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {

@@ -1589,7 +1589,7 @@ bool gemm_and_bias(
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
    if (at::globalContext().allowTF32CuBLAS()) {
    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {

@@ -218,7 +218,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
  // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
  // FP32 data type calculations based on the value of the allow_tf32 flag.
  // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
  if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) {
  if (!NoTF32Guard::should_disable_tf32() &&
      at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
  } else {
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

@@ -160,7 +160,7 @@ inline std::string ComputeTypeFor() {
// ROCBLAS and hipBLASLt.
template <>
inline std::string ComputeTypeFor<float>() {
  if (!at::globalContext().allowTF32CuBLAS()) {
  if (at::globalContext().float32Precision("cuda", "matmul") != "tf32") {
    return "f32_r";
  } else {
    return "xf32_r";

@@ -499,7 +499,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
    }

    hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
    if (at::globalContext().allowTF32CuBLAS()) {
    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
      computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
    }
    HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);

@@ -141,7 +141,7 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {

  TuningStatus Call(const GemmParams<T>* params) override {
    auto input_output_type = RocBlasDataTypeFor<T>();
    if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
      return FAIL; // no support for TF32 in rocBLAS
    auto compute_type = RocBlasComputeTypeFor<T>();
    auto h_a = DoCastForHalfOrBfloat16(params->alpha);

@@ -209,7 +209,7 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>

  TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
    auto input_output_type = RocBlasDataTypeFor<T>();
    if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
      return FAIL; // no support for TF32 in rocBLAS
    auto compute_type = RocBlasComputeTypeFor<T>();
    auto h_a = DoCastForHalfOrBfloat16(params->alpha);

@@ -1174,7 +1174,7 @@ at::Tensor convolution(
  bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
  return at::_convolution(input, weight, bias, stride, padding, dilation,
                          transposed, output_padding, groups,
                          ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
                          ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv"));
}

at::Tensor convolution_overrideable(

@@ -1319,7 +1319,7 @@ ConvBackend select_conv_backend(
  params.benchmark = ctx.benchmarkCuDNN();
  params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
  params.cudnn_enabled = ctx.userEnabledCuDNN();
  params.allow_tf32 = ctx.allowTF32CuDNN();
  params.allow_tf32 = ctx.allowTF32CuDNN("conv");

  auto input = input_r;
  auto weight = weight_r;

@@ -1705,7 +1705,7 @@ at::Tensor _convolution(
  c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
  const Tensor& bias_r = *bias_r_maybe_owned;

  return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN());
  return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv"));
}

std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(

@@ -2003,7 +2003,7 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward(
  params.benchmark = ctx.benchmarkCuDNN();
  params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
  params.cudnn_enabled = ctx.userEnabledCuDNN();
  params.allow_tf32 = ctx.allowTF32CuDNN();
  params.allow_tf32 = ctx.allowTF32CuDNN("conv");

  // Validate inputs.
  check_shape_backward(input, weight.sizes(), params);

@@ -169,7 +169,8 @@ std::string repro_from_args(const ConvolutionParams& params) {
  ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
  ss << "import torch\n";
  ss << "torch.backends.cuda.matmul.allow_tf32 = "
     << pybool(at::globalContext().allowTF32CuBLAS()) << "\n";
     << pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32")
     << "\n";
  ss << "torch.backends.cudnn.benchmark = "
     << pybool(at::globalContext().benchmarkCuDNN()) << "\n";
  ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)

@@ -725,7 +726,7 @@ Tensor cudnn_convolution_relu(

  auto& ctx = at::globalContext();
  bool benchmark = ctx.benchmarkCuDNN();
  bool allow_tf32 = ctx.allowTF32CuDNN();
  bool allow_tf32 = ctx.allowTF32CuDNN("conv");
  auto _bias = bias_t.has_value()
      ? bias_t.value()
      : at::zeros(

@@ -783,7 +784,7 @@ Tensor cudnn_convolution_add_relu(
  }

  auto& ctx = at::globalContext();
  bool allow_tf32 = ctx.allowTF32CuDNN();
  bool allow_tf32 = ctx.allowTF32CuDNN("conv");
  bool benchmark = ctx.benchmarkCuDNN();
  auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
  auto _bias = bias_t.has_value()

@@ -245,7 +245,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
        datatype,
        input_datatype,
        algo,
        at::globalContext().allowTF32CuDNN());
        at::globalContext().allowTF32CuDNN("rnn"));
#else
    rnn_desc.set(
        handle,

@@ -261,7 +261,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
        datatype,
        input_datatype,
        algo,
        at::globalContext().allowTF32CuDNN());
        at::globalContext().allowTF32CuDNN("rnn"));
#endif
    return rnn_desc;
  }

@@ -104,7 +104,7 @@ static bool use_mkldnn_fp16_matmul() {
}

static bool use_mkldnn_bf32_matmul() {
  return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM;
  return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
}

// returns an ideep::tensor
@@ -133,6 +133,44 @@ To toggle the TF32 flags off in C++, you can do
    at::globalContext().setAllowTF32CuBLAS(false);
    at::globalContext().setAllowTF32CuDNN(false);

After PyTorch 2.7, we provide a new set of APIs to control the TF32 behavior in a more fine-grained way.
We can set the float32 precision per backend and per operator, and override the global setting for a specific operator.

.. code:: python

    torch.backends.fp32_precision = "ieee"
    torch.backends.cuda.matmul.fp32_precision = "ieee"
    torch.backends.cudnn.fp32_precision = "ieee"
    torch.backends.cudnn.conv.fp32_precision = "tf32"
    torch.backends.cudnn.rnn.fp32_precision = "tf32"

For `cuda`/`cudnn`, `fp32_precision` can be set to `ieee` or `tf32`.
`ieee` means plain `FP32` is used as the internal computation precision.
`tf32` means `TF32` is allowed as the internal computation precision.

We can override a generic setting for a specific operator by setting that operator's `fp32_precision` to `ieee`.

.. code:: python

    torch.backends.cudnn.fp32_precision = "tf32"
    torch.backends.cudnn.conv.fp32_precision = "ieee"
    torch.backends.cudnn.rnn.fp32_precision = "ieee"

We can also override a generic setting for a specific backend by setting that backend's `fp32_precision` to `ieee`.

.. code:: python

    torch.backends.fp32_precision = "tf32"
    torch.backends.cudnn.fp32_precision = "ieee"
    torch.backends.cudnn.conv.fp32_precision = "ieee"
    torch.backends.cudnn.rnn.fp32_precision = "ieee"

In both cases above, `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision`
are overridden to `ieee`.

The old settings are still supported, but we suggest using the new settings for better control; mixing the
old and new settings is not supported.
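As an illustration (a minimal sketch, assuming default settings elsewhere): once the new per-operator API puts cuDNN into a state that the single legacy boolean cannot represent, reading the legacy flag raises a `RuntimeError`.

.. code:: python

    torch.backends.cudnn.conv.fp32_precision = "ieee"
    torch.backends.cudnn.rnn.fp32_precision = "tf32"

    # The legacy flag cannot describe "conv: ieee, rnn: tf32",
    # so querying it is rejected with a RuntimeError.
    torch.backends.cudnn.allow_tf32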

For more information about TF32, see:

- `TensorFloat-32`_

docs/source/notes/mkldnn.rst (new file, 102 lines)
@@ -0,0 +1,102 @@
.. meta::
  :description: A guide to torch.backends.mkldnn, a PyTorch backend to run MKLDNN operations
  :keywords: optimize PyTorch, MKLDNN

.. _mkldnn_backend:

MKLDNN backend
---------------------------------------------------

MKLDNN is an open-source cross-platform performance library of basic building blocks
for deep learning applications.

.. code:: python

    # The flag below controls whether the MKLDNN backend is enabled in PyTorch.
    torch.backends.mkldnn.enabled = True

Users can disable the MKLDNN backend by:

.. code:: python

    torch.backends.mkldnn.enabled = False

.. _bf16_on_mkldnn:

Bfloat16 (BF16) on MKLDNN backend
---------------------------------------------------

Starting in PyTorch 2.4, there is a set of APIs to control the internal computation precision
for `float32` operators.

.. code:: python

    # The flag below controls the internal computation precision for mkldnn matmul. The default "ieee" keeps plain float32.
    torch.backends.mkldnn.matmul.fp32_precision = "ieee"

    # The flag below controls the internal computation precision for mkldnn conv. The default "ieee" keeps plain float32.
    torch.backends.mkldnn.conv.fp32_precision = "ieee"

    # The flag below controls the internal computation precision for mkldnn rnn. The default "ieee" keeps plain float32.
    torch.backends.mkldnn.rnn.fp32_precision = "ieee"

Note that besides matmuls and convolutions themselves, functions and nn modules that internally use
matmuls or convolutions are also affected. These include :class:`torch.nn.Linear`, :class:`torch.nn._ConvNd`, :func:`torch.cdist`,
:func:`torch.tensordot`, :func:`torch.nn.functional.affine_grid` and :func:`torch.nn.functional.grid_sample`,
:class:`torch.nn.AdaptiveLogSoftmaxWithLoss`, :class:`torch.nn.GRU` and :class:`torch.nn.LSTM`.

To get an idea of the precision and speed, see the example code and benchmark data (on SPR) below:

.. code:: python

    torch.manual_seed(0)
    a_full = torch.randn(10240, 10240, dtype=torch.double)
    b_full = torch.randn(10240, 10240, dtype=torch.double)
    ab_full = a_full @ b_full
    mean = ab_full.abs().mean()  # 80.7451

    a = a_full.float()
    b = b_full.float()

    # Do matmul at BF16 mode.
    torch.backends.mkldnn.matmul.fp32_precision = 'bf16'
    ab_bf16 = a @ b  # expected speedup with BF16 dot-product acceleration
    error = (ab_bf16 - ab_full).abs().max()  # 1.3704
    relative_error = error / mean  # 0.0170
    print(error, relative_error)

    # Do matmul at FP32 mode.
    torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
    ab_fp32 = a @ b
    error = (ab_fp32 - ab_full).abs().max()  # 0.0003
    relative_error = error / mean  # 0.00000317
    print(error, relative_error)

From the above example, we can see that with BF16, the speed is ~7x faster on SPR, and that
the relative error compared to double precision is approximately 2 orders of magnitude larger.
If full FP32 precision is needed, users can disable BF16 by:

.. code:: python

    torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
    torch.backends.mkldnn.conv.fp32_precision = 'ieee'
    torch.backends.mkldnn.rnn.fp32_precision = 'ieee'

To toggle the BF16 flags off in C++, you can do

.. code:: C++

    at::globalContext().setFloat32Precision("mkldnn", "matmul", "ieee");
    at::globalContext().setFloat32Precision("mkldnn", "conv", "ieee");
    at::globalContext().setFloat32Precision("mkldnn", "rnn", "ieee");

We can override a generic setting for a specific operator or backend if the fp32_precision is set to `ieee`.

.. code:: python

    torch.backends.fp32_precision = "bf16"
    torch.backends.mkldnn.fp32_precision = "ieee"
    torch.backends.mkldnn.matmul.fp32_precision = "ieee"

In such a case, both `torch.backends.mkldnn.fp32_precision` and `torch.backends.mkldnn.matmul.fp32_precision`
are overridden to bf16.
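As a further sketch of how the hierarchy behaves (mirroring the tests added in this change, and assuming you start from the default settings): an operator left at the default falls back to its parent backend, and a precision the oneDNN backend does not support is ignored with a warning.

.. code:: python

    # "tf32" is not a supported fp32 precision for the mkldnn backend, so the
    # assignment only warns and the effective value stays at the default.
    torch.backends.mkldnn.matmul.fp32_precision = "tf32"
    print(torch.backends.mkldnn.matmul.fp32_precision)  # "none"

    # An operator whose own setting is "none" inherits from its parent backend.
    torch.backends.mkldnn.fp32_precision = "bf16"
    print(torch.backends.mkldnn.matmul.fp32_precision)  # "bf16"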
@@ -30,7 +30,13 @@ from torch.testing._internal.common_device_type import (


Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")
# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
if torch.version.hip:
    torch.set_float32_matmul_precision("highest")
else:
    torch.set_float32_matmul_precision("high")

index = torch.ops.aten.index
Tensor = torch.Tensor

@@ -97,7 +97,13 @@ class TestCaseBase(TestCase):
        if HAS_GPU:
            cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision()
            cls.prior_default_device = torch.get_default_device()
            torch.set_float32_matmul_precision("high")
            # In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
            # In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
            # logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
            if torch.version.hip:
                torch.set_float32_matmul_precision("highest")
            else:
                torch.set_float32_matmul_precision("high")
            torch.set_default_device(GPU_TYPE)

    @classmethod

@@ -69,6 +69,7 @@ from torch.testing._internal.common_utils import (
    load_tests,
    MI300_ARCH,
    parametrize,
    recover_orig_fp32_precision,
    run_tests,
    serialTest,
    setBlasBackendsToDefaultFinally,

@@ -849,6 +850,55 @@ print(t.is_pinned())
        ):
            self.assertTrue(torch.backends.cudnn.allow_tf32)

    @recover_orig_fp32_precision
    def test_fp32_precision_with_tf32(self):
        with torch.backends.cudnn.flags(
            enabled=None,
            benchmark=None,
            benchmark_limit=None,
            deterministic=None,
            allow_tf32=True,
            fp32_precision="none",
        ):
            self.assertEqual(torch.backends.cudnn.conv.fp32_precision, "tf32")
            self.assertEqual(torch.backends.cudnn.rnn.fp32_precision, "tf32")

        with torch.backends.cudnn.flags(
            enabled=None,
            benchmark=None,
            benchmark_limit=None,
            deterministic=None,
            allow_tf32=False,
            fp32_precision="none",
        ):
            self.assertEqual(torch.backends.cudnn.conv.fp32_precision, "none")
            self.assertEqual(torch.backends.cudnn.rnn.fp32_precision, "none")

    @recover_orig_fp32_precision
    def test_fp32_precision_with_float32_matmul_precision(self):
        torch.set_float32_matmul_precision("highest")
        self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "ieee")
        torch.set_float32_matmul_precision("high")
        self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "tf32")
        torch.set_float32_matmul_precision("medium")
        self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "tf32")

    @recover_orig_fp32_precision
    def test_invalid_status_for_legacy_api(self):
        torch.backends.cudnn.conv.fp32_precision = "none"
        torch.backends.cudnn.rnn.fp32_precision = "tf32"
        with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
            print(torch.backends.cudnn.allow_tf32)

        torch.set_float32_matmul_precision("highest")
        torch.backends.cuda.matmul.fp32_precision = "tf32"
        with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
            print(torch.get_float32_matmul_precision())

        if not TEST_WITH_ROCM:
            with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
                print(torch.backends.cuda.matmul.allow_tf32)

    def test_type_conversions(self):
        x = torch.randn(5, 5)
        self.assertIsInstance(x.float(), torch.FloatTensor)
@@ -22,7 +22,7 @@ import torch.backends.mkldnn
from torch.utils import mkldnn as mkldnn_utils
from torch.testing._internal.common_utils import TestCase, \
    run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \
    skipIfTorchDynamo, xfailIfTorchDynamo
    skipIfTorchDynamo, xfailIfTorchDynamo, recover_orig_fp32_precision
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,

@@ -1659,6 +1659,53 @@ class TestMkldnn(TestCase):
            self.assertEqual(out_emulated.float(), out.float(), atol=5e-2, rtol=5e-2)


    @recover_orig_fp32_precision
    def test_mlkdnn_get_set(self):
        # get/set mkldnn ops
        with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"):
            self.assertEqual(torch.backends.mkldnn.fp32_precision, "bf16")
        with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"):
            self.assertEqual(torch.backends.mkldnn.fp32_precision, "none")
        # get/set matmul
        torch.backends.mkldnn.matmul.fp32_precision = "bf16"
        self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
        torch.backends.mkldnn.matmul.fp32_precision = "none"
        self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")
        # get/set conv
        torch.backends.mkldnn.conv.fp32_precision = "bf16"
        self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "bf16")
        torch.backends.mkldnn.conv.fp32_precision = "none"
        self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "none")
        # get/set rnn
        torch.backends.mkldnn.rnn.fp32_precision = "bf16"
        self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "bf16")
        torch.backends.mkldnn.rnn.fp32_precision = "none"
        self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "none")

    @recover_orig_fp32_precision
    def test_generic_precision(self):
        with torch.backends.flags(fp32_precision="none"):
            self.assertEqual(torch.backends.fp32_precision, "none")
        with torch.backends.flags(fp32_precision="tf32"):
            self.assertEqual(torch.backends.fp32_precision, "tf32")

    @recover_orig_fp32_precision
    def test_default_use_parent(self):
        torch.backends.mkldnn.matmul.fp32_precision = "none"
        with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"):
            self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
        with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"):
            with torch.backends.flags(fp32_precision="bf16"):
                self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
            with torch.backends.flags(fp32_precision="tf32"):
                # when parent is a not supported precision, use default
                self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")

    @recover_orig_fp32_precision
    def test_invalid(self):
        # use default if user set a not supported precision
        torch.backends.mkldnn.matmul.fp32_precision = "tf32"
        self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")

instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))
@@ -1360,6 +1360,8 @@ def _disabled_torch_dispatch_impl(
) -> Any: ...  # THPModule_disable_dispatch_function
def _get_linalg_preferred_backend() -> _LinalgBackend: ...
def _set_linalg_preferred_backend(arg: _LinalgBackend): ...
def _get_fp32_precision_getter(backend: str, op: str) -> str: ...
def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ...

class _LinalgBackend:
    Default: _LinalgBackend

@@ -254,7 +254,9 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        cuda_rng_state = None
        if torch.cuda.is_available():
            cuda_rng_state = torch.cuda.get_rng_state()
        allow_tf32 = torch._C._get_cublas_allow_tf32()
        cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter(
            "cuda", "matmul"
        )
        prior_fwd_from_src = torch.fx.graph_module._forward_from_src
        torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
        cleanup = setup_compile_debug()

@@ -286,7 +288,9 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
            torch._C._unset_default_mobile_cpu_allocator()
            if cuda_rng_state is not None:
                torch.cuda.set_rng_state(cuda_rng_state)
            torch._C._set_cublas_allow_tf32(allow_tf32)
            torch._C._set_fp32_precision_setter(
                "cuda", "matmul", cuda_matmul_fp32_prec
            )
            torch.fx.graph_module._forward_from_src = prior_fwd_from_src
            assert guards.check(), (
                f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"

@@ -1,7 +1,10 @@
# mypy: allow-untyped-defs
import sys
import types
from contextlib import contextmanager

import torch


# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our

@@ -57,6 +60,70 @@ class PropModule(types.ModuleType):
        return self.m.__getattribute__(attr)


class _FP32Precision:
    def __init__(self, backend, op):
        self.backend = backend
        self.op = op

    def __setattr__(self, name, value):
        if name == "fp32_precision":
            torch._C._set_fp32_precision_setter(self.backend, self.op, value)
        elif name in ("backend", "op"):
            super().__setattr__(name, value)
        else:
            raise AttributeError("Unknown attribute " + name)

    def __getattr__(self, name):
        if name == "fp32_precision":
            return torch._C._get_fp32_precision_getter(self.backend, self.op)
        else:
            raise AttributeError("Unknown attribute " + name)


def set_flags(_fp32_precision="none"):
    orig_flags = (torch._C._get_fp32_precision_getter("generic", "all"),)
    if _fp32_precision is not None:
        torch._C._set_fp32_precision_setter("generic", "all", _fp32_precision)
    return orig_flags


@contextmanager
def flags(fp32_precision="none"):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(fp32_precision)
    try:
        yield
    finally:
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)


def _get_fp32_precision_getter(backend, op):
    def inner():
        return torch._C._get_fp32_precision_getter(backend, op)

    return inner


def _set_fp32_precision_setter(backend, op):
    def inner(precision):
        return torch._C._set_fp32_precision_setter(backend, op, precision)

    return inner


class GenericModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    fp32_precision = ContextProp(
        _get_fp32_precision_getter("generic", "all"),
        _set_fp32_precision_setter("generic", "all"),
    )


sys.modules[__name__] = GenericModule(sys.modules[__name__], __name__)

from torch.backends import (
    cpu as cpu,
    cuda as cuda,
@@ -135,6 +135,8 @@ class cuBLASModule:
            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
        elif name == "allow_fp16_accumulation":
            return torch._C._get_cublas_allow_fp16_accumulation()
        elif name == "fp32_precision":
            return torch._C._get_fp32_precision_getter("cuda", "matmul")
        raise AttributeError("Unknown attribute " + name)

    def __setattr__(self, name, value):

@@ -146,6 +148,8 @@ class cuBLASModule:
            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
        elif name == "allow_fp16_accumulation":
            return torch._C._set_cublas_allow_fp16_accumulation(value)
        elif name == "fp32_precision":
            return torch._C._set_fp32_precision_setter("cuda", "matmul", value)
        raise AttributeError("Unknown attribute " + name)


@@ -6,7 +6,14 @@ from contextlib import contextmanager
from typing import Optional

import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
from torch.backends import (
    __allow_nonbracketed_mutation,
    _FP32Precision,
    _get_fp32_precision_getter,
    _set_fp32_precision_setter,
    ContextProp,
    PropModule,
)


try:

@@ -128,6 +135,7 @@ def set_flags(
    _benchmark_limit=None,
    _deterministic=None,
    _allow_tf32=None,
    _fp32_precision="none",
):
    orig_flags = (
        torch._C._get_cudnn_enabled(),

@@ -135,6 +143,7 @@ def set_flags(
        None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
        torch._C._get_cudnn_deterministic(),
        torch._C._get_cudnn_allow_tf32(),
        torch._C._get_fp32_precision_getter("cuda", "all"),
    )
    if _enabled is not None:
        torch._C._set_cudnn_enabled(_enabled)

@@ -146,6 +155,8 @@ def set_flags(
        torch._C._set_cudnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        torch._C._set_cudnn_allow_tf32(_allow_tf32)
    if _fp32_precision is not None:
        torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision)
    return orig_flags


@@ -156,10 +167,16 @@ def flags(
    benchmark_limit=10,
    deterministic=False,
    allow_tf32=True,
    fp32_precision="none",
):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(
            enabled, benchmark, benchmark_limit, deterministic, allow_tf32
            enabled,
            benchmark,
            benchmark_limit,
            deterministic,
            allow_tf32,
            fp32_precision,
        )
    try:
        yield

@@ -194,6 +211,12 @@ class CudnnModule(PropModule):
    allow_tf32 = ContextProp(
        torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32
    )
    conv = _FP32Precision("cuda", "conv")
    rnn = _FP32Precision("cuda", "rnn")
    fp32_precision = ContextProp(
        _get_fp32_precision_getter("cuda", "all"),
        _set_fp32_precision_setter("cuda", "all"),
    )


# This is the sys.modules replacement trick, see

@@ -4,7 +4,14 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING

import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
from torch.backends import (
    __allow_nonbracketed_mutation,
    _FP32Precision,
    _get_fp32_precision_getter,
    _set_fp32_precision_setter,
    ContextProp,
    PropModule,
)


def is_available():

@@ -64,11 +71,14 @@ class verbose:
        return False


def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None):
def set_flags(
    _enabled=None, _deterministic=None, _allow_tf32=None, _fp32_precision="none"
):
    orig_flags = (
        torch._C._get_mkldnn_enabled(),
        torch._C._get_mkldnn_deterministic(),
        torch._C._get_onednn_allow_tf32(),
        torch._C._get_fp32_precision_getter("mkldnn", "all"),
    )
    if _enabled is not None:
        torch._C._set_mkldnn_enabled(_enabled)

@@ -76,13 +86,15 @@ def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None):
        torch._C._set_mkldnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        torch._C._set_onednn_allow_tf32(_allow_tf32)
    if _fp32_precision is not None:
        torch._C._set_fp32_precision_setter("mkldnn", "all", _fp32_precision)
    return orig_flags


@contextmanager
def flags(enabled=False, deterministic=False, allow_tf32=True):
def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="none"):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, deterministic, allow_tf32)
        orig_flags = set_flags(enabled, deterministic, allow_tf32, fp32_precision)
    try:
        yield
    finally:

@@ -104,6 +116,13 @@ class MkldnnModule(PropModule):
    allow_tf32 = ContextProp(
        torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32
    )
    matmul = _FP32Precision("mkldnn", "matmul")
    conv = _FP32Precision("mkldnn", "conv")
    rnn = _FP32Precision("mkldnn", "rnn")
    fp32_precision = ContextProp(
        _get_fp32_precision_getter("mkldnn", "all"),
        _set_fp32_precision_setter("generic", "all"),
    )


if TYPE_CHECKING:
@@ -670,10 +670,12 @@ static PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
}

static PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (at::globalContext().allowTF32CuDNN())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_setFloat32MatmulPrecision(

@@ -694,6 +696,7 @@ static PyObject* THPModule_setFloat32MatmulPrecision(
static PyObject* THPModule_float32MatmulPrecision(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::string s = "highest";
  auto p = at::globalContext().float32MatmulPrecision();
  if (p == at::Float32MatmulPrecision::HIGH) {

@@ -702,6 +705,7 @@ static PyObject* THPModule_float32MatmulPrecision(
    s = "medium";
  }
  return THPUtils_packString(s);
  END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_setSDPPriorityOrder(
    PyObject* _unused,

@@ -1116,10 +1120,12 @@ static PyObject* THPModule_setAllowTF32CuBLAS(
static PyObject* THPModule_allowTF32CuBLAS(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (at::globalContext().allowTF32CuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_setAllowFP16ReductionCuBLAS(

@@ -2360,6 +2366,18 @@ Call this whenever a new thread is created in order to propagate values from
        at::DataPtr(reinterpret_cast<void*>(data_ptr), device));
  });

  py_module.def(
      "_get_fp32_precision_getter", [](std::string backend, std::string op) {
        return at::globalContext().float32Precision(backend, op);
      });

  py_module.def(
      "_set_fp32_precision_setter",
      [](std::string backend, std::string op, std::string precision) {
        at::globalContext().setFloat32Precision(backend, op, precision);
        return precision;
      });

  py_module.def(
      "_stash_obj_in_tls", [](const std::string& key, py::handle arg) {
        at::impl::ThreadLocalPythonObjects::get_state().set(
@@ -590,7 +590,7 @@ struct GlobalStateGuard {
    _torch_function_all_disabled = at::impl::torch_function_all_disabled();
    _deterministic_algorithms = ctx.deterministicAlgorithms();
    _deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
    _allow_tf32 = ctx.allowTF32CuBLAS();
    _allow_tf32 = ctx.float32Precision("cuda", "matmul") == "tf32";
    _allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
    _allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
    _num_threads = at::get_num_threads();

@@ -607,7 +607,7 @@ struct GlobalStateGuard {
        _deterministic_algorithms == ctx.deterministicAlgorithms() &&
        _deterministic_algorithms_warn_only ==
            ctx.deterministicAlgorithmsWarnOnly() &&
        _allow_tf32 == ctx.allowTF32CuBLAS() &&
        _allow_tf32 == (ctx.float32Precision("cuda", "matmul") == "tf32") &&
        _allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
        _allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
        _num_threads == at::get_num_threads()) &&

@@ -628,7 +628,7 @@ struct GlobalStateGuard {
    if (_deterministic_algorithms_warn_only !=
        ctx.deterministicAlgorithmsWarnOnly())
      os << "deterministic_algorithms_warn_only ";
    if (_allow_tf32 != ctx.allowTF32CuBLAS())
    if (_allow_tf32 != (ctx.float32Precision("cuda", "matmul") == "tf32"))
      os << "allow_tf32 ";
    if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS())
      os << "allow_fp16_reduce ";

@@ -396,7 +396,8 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
  }

  event->start_time_ = c10::getApproximateTime();
  event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS();
  event->allow_tf32_cublas_ =
      at::globalContext().float32Precision("cuda", "matmul") == "tf32";
  if (!config_.experimental_config.performance_events.empty()) {
    const size_t n = config_.experimental_config.performance_events.size();
    event->counters_ = std::make_unique<perf_counters_t>(n, 0);
@@ -5728,5 +5728,25 @@ def scoped_load_inline(func):
            return cpp_extension.load_inline(*args, **kwargs)

        return func(*args, load_inline=load_inline, **kwargs)

    return wrapper

def recover_orig_fp32_precision(fn):
    @contextlib.contextmanager
    def recover():
        old_mkldnn_conv_p = torch.backends.mkldnn.conv.fp32_precision  # type: ignore[attr-defined]
        old_mkldnn_rnn_p = torch.backends.mkldnn.rnn.fp32_precision  # type: ignore[attr-defined]
        old_mkldnn_matmul_p = torch.backends.mkldnn.matmul.fp32_precision  # type: ignore[attr-defined]
        old_cudnn_conv_p = torch.backends.cudnn.conv.fp32_precision  # type: ignore[attr-defined]
        old_cudnn_rnn_p = torch.backends.cudnn.rnn.fp32_precision  # type: ignore[attr-defined]
        old_cuda_matmul_p = torch.backends.cuda.matmul.fp32_precision
        try:
            yield
        finally:
            torch.backends.mkldnn.conv.fp32_precision = old_mkldnn_conv_p  # type: ignore[attr-defined]
            torch.backends.mkldnn.rnn.fp32_precision = old_mkldnn_rnn_p  # type: ignore[attr-defined]
            torch.backends.mkldnn.matmul.fp32_precision = old_mkldnn_matmul_p  # type: ignore[attr-defined]
            torch.backends.cudnn.conv.fp32_precision = old_cudnn_conv_p  # type: ignore[attr-defined]
            torch.backends.cudnn.rnn.fp32_precision = old_cudnn_rnn_p  # type: ignore[attr-defined]
            torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p

    return recover()(fn)