refine fp32 precision api (#125888)

Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop "highest, high, medium" as the way to represent fp32 internal computation data types. Instead, we will use the algorithm name directly.
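For illustration, here is a rough sketch of the shift in user-facing terms. The mapping of "high"/"medium" onto per-backend settings follows `setFloat32MatmulPrecision` in the diff below; the snippet assumes a CUDA build.

```
import torch

# Legacy API: coarse, relative names.
torch.set_float32_matmul_precision("high")    # cuda matmul -> tf32, mkldnn matmul -> ieee
torch.set_float32_matmul_precision("medium")  # cuda matmul -> tf32, mkldnn matmul -> bf16

# New API: name the algorithm directly, per backend and per operator.
torch.backends.cuda.matmul.fp32_precision = "tf32"
torch.backends.mkldnn.matmul.fp32_precision = "bf16"
```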

### Design Choice: Directly use algorithm names like "TF32" and "BF16".
#### Pros
 - The names are more informative: 'tf32' conveys more than a generic "high".
 - Easier to extend with new algorithms such as `tf32x3`.
#### Cons
 - "HIGHEST, HIGH, MEDIUM" indicated the relative precision between different algorithms. However, we can have more documents to discuss them.

### We provide a layered structure for backends/operators.
('f32' is short for 'fp32_precision')
![image](https://github.com/user-attachments/assets/f89143e5-d6a1-4865-9351-9a50439f5067)
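To make the layered structure concrete, the sketch below shows how the hierarchy in the diagram surfaces as Python attributes (read access only; the values printed depend on the build and the defaults discussed below).

```
import torch

print(torch.backends.fp32_precision)                # generic root
print(torch.backends.cuda.matmul.fp32_precision)    # cuBLAS matmul
print(torch.backends.cudnn.fp32_precision)          # cuDNN backend-wide
print(torch.backends.cudnn.conv.fp32_precision)     # cuDNN conv
print(torch.backends.cudnn.rnn.fp32_precision)      # cuDNN rnn
print(torch.backends.mkldnn.fp32_precision)         # oneDNN backend-wide
print(torch.backends.mkldnn.matmul.fp32_precision)  # oneDNN matmul
print(torch.backends.mkldnn.conv.fp32_precision)    # oneDNN conv
print(torch.backends.mkldnn.rnn.fp32_precision)     # oneDNN rnn
```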

### We provide the following fp32 compute precisions that can be set:
 - **"ieee"**: Not allowed to use any other internal computation data types .
 - **"tf32"**: Allowed to use tf32 as internal computation data types.
 - **"bf16"**: Allowed to use bf16 as internal computation data types.
 - **"none"**:  Precision's are not set. Can be override by its father node.

### Overriding Precision Settings
A child node is overridden by its parent node if it is left at the default ("none").
The current default settings are:
```
backend = generic, op = all, precision setting = none
    backend = cuda, op = all, precision setting = none
        backend = cuda, op = conv, precision setting = tf32
        backend = cuda, op = rnn, precision setting = tf32
        backend = cuda, op = matmul, precision setting = none
    backend = mkldnn, op = all, precision setting = none
        backend = mkldnn, op = conv, precision setting = none
        backend = mkldnn, op = rnn, precision setting = none
        backend = mkldnn, op = matmul, precision setting = none
```
 - If the user sets `torch.backends.mkldnn.fp32_precision="bf16"`, its child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` will also be overridden to "bf16".
 - If the user sets `torch.backends.fp32_precision="bf16"`, `torch.backends.mkldnn.fp32_precision` and its child nodes will also be overridden to "bf16", as sketched below.
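A minimal sketch of this cascade, mirroring `test_default_use_parent` from the diff (assumes a build with MKLDNN available):

```
import torch

# Child left at its default ("none"), parent set to bf16.
torch.backends.mkldnn.matmul.fp32_precision = "none"
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"):
    # The child reports the parent's value because its own setting is "none".
    print(torch.backends.mkldnn.matmul.fp32_precision)  # "bf16"

# The generic root cascades one level further when mkldnn is also "none".
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"):
    with torch.backends.flags(fp32_precision="bf16"):
        print(torch.backends.mkldnn.matmul.fp32_precision)  # "bf16"
```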

### Backward Compatibility
Since the new API gives the user more fine-grained control, there can be conflicts: for example, the previous `torch.backends.cudnn.allow_tf32` flag is not enough to represent the state `torch.backends.cudnn.rnn.fp32_precision="ieee"` combined with `torch.backends.cudnn.conv.fp32_precision="tf32"`. Therefore, our goals for backward compatibility are:
 - If the user only uses the previous APIs, they will work as before.
 - If the user uses the **new** API to reach a state that is **not representable** by the old API and then queries that state through the **old** API, we raise a RuntimeError and point the user to the documentation (see the example below).
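A short sketch of that error path, adapted from `test_invalid_status_for_legacy_api` in the diff (assumes a CUDA build):

```
import torch

# Put cuDNN conv and rnn into a state the old boolean flag cannot express.
torch.backends.cudnn.conv.fp32_precision = "none"
torch.backends.cudnn.rnn.fp32_precision = "tf32"

# Reading the legacy flag now raises instead of silently guessing.
try:
    print(torch.backends.cudnn.allow_tf32)
except RuntimeError as e:
    print(e)  # "... mix of the legacy and new APIs ..."
```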

### Test Plan
```
python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888
Approved by: https://github.com/jgong5, https://github.com/albanD

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
Author: haozhe.zhu
Date: 2025-06-26 02:59:26 +00:00
Committed by: PyTorch MergeBot
Parent: de45c5f673
Commit: 53e0b9c393
27 changed files with 607 additions and 40 deletions

View File

@ -19,9 +19,69 @@
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <cpuinfo.h>
#endif
namespace at {
namespace {
/*
These const variables define the fp32 precisions for the different backends.
We currently have the "generic", "cuda" and "mkldnn" backends, and the fp32
precision can be chosen from "ieee", "tf32", "bf16" and "none". The "ieee"
precision means the IEEE standard floating point format; "tf32" and "bf16"
mean we are allowed to use "tf32" or "bf16" as internal computation data
types for fp32 computations. "none" means the value is overridable by the
parent node:
generic->mkldnn->matmul
->conv
->rnn
->cuda ->matmul
->conv
->rnn
*/
const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
{"generic", {{"ieee", "tf32", "bf16", "none"}}},
{"mkldnn", {{"ieee", "bf16", "none"}}},
{"cuda", {{"ieee", "tf32", "none"}}}};
// Check whether the backend and op are legal
void check_fp32_prec_backend_and_op(
const std::string& backend,
const std::string& op) {
static std::vector<std::string> backends = {"generic", "mkldnn", "cuda"};
static std::vector<std::string> operators = {"conv", "matmul", "rnn", "all"};
TORCH_CHECK(
std::find(backends.begin(), backends.end(), backend) != backends.end(),
"Invalid backend: ",
backend);
TORCH_CHECK(
std::find(operators.begin(), operators.end(), op) != operators.end(),
"Invalid operator: ",
op);
if (backend == "generic") {
TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op);
}
}
// Return whether the precision is supported by backends
bool validate_fp32_prec(
const std::string& backend,
const std::string& precision) {
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
auto precisions = iterp->second;
bool valid = std::find(precisions.begin(), precisions.end(), precision) !=
precisions.end();
return valid;
}
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"This API is going to be deprecated, please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
Context::Context() = default;
// TODO: This could be bad juju if someone calls globalContext() in the
@ -115,12 +175,29 @@ void Context::setUserEnabledNNPACK(bool e) {
enabled_nnpack = e;
}
bool Context::allowTF32CuDNN() const {
bool Context::allowTF32CuDNN(const std::string& op) const {
if (op.size() == 0){
bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
TORCH_CHECK(
allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn,
"PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,",
"but the current flag(s) indicate that cuDNN conv and cuDNN RNN have different TF32 flags.",
"This combination indicates that you have used a mix of the legacy and new APIs to set the TF32 flags. ",
"We suggest only using the new API to set the TF32 flag(s). See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
} else {
return float32Precision("cuda", op) == "tf32";
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}
void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision("cuda", "rnn", b ? "tf32" : "none");
setFloat32Precision("cuda", "conv", b ? "tf32" : "none");
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@ -259,7 +336,16 @@ bool Context::allowTF32CuBLAS() const {
return false;
}
#endif
return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
TORCH_CHECK(
legacy_allow_tf32 == allow_tf32_new,
"PyTorch is checking whether allow_tf32_new is enabled for cuBlas matmul,",
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}
void Context::setAllowTF32CuBLAS(bool b) {
@ -272,27 +358,54 @@ void Context::setAllowTF32CuBLAS(bool b) {
}
#endif
float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
}
Float32MatmulPrecision Context::float32MatmulPrecision() const {
bool invalid = float32Precision("cuda", "matmul") == "tf32" &&
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST;
invalid = invalid ||
(float32Precision("mkldnn", "matmul") == "bf16" &&
float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
TORCH_CHECK(
!invalid,
"PyTorch is checking the matmul precision without a specific backend name,",
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}
void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
float32_matmul_precision = p;
std::string Context::float32Precision(const std::string& backend, const std::string& op) const {
check_fp32_prec_backend_and_op(backend, op);
auto precision = fp32_precision.find(backend)->second.find(op)->second;
if (precision == "none")
precision = fp32_precision.find(backend)->second.find("all")->second;
if (precision == "none")
precision = fp32_precision.find("generic")->second.find("all")->second;
bool valid_prec = validate_fp32_prec(backend, precision);
return valid_prec ? precision : "none";
}
void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", "ieee");
setFloat32Precision("mkldnn", "matmul", "ieee");
return true;
} else if (s_ == "high") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
setFloat32Precision("cuda", "matmul", "tf32");
setFloat32Precision("mkldnn", "matmul", "ieee");
return true;
} else if (s_ == "medium") {
float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
setFloat32Precision("cuda", "matmul", "tf32");
setFloat32Precision("mkldnn", "matmul", "bf16");
return true;
}
return false;
@ -306,6 +419,27 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
"setFloat32MatmulPrecision call has no effect.");
}
void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) {
check_fp32_prec_backend_and_op(backend, op);
if (validate_fp32_prec(backend, p)) {
fp32_precision[backend][op] = p;
} else {
std::string msg;
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
for (const auto& prec : iterp->second) {
msg += prec;
msg += " ";
}
TORCH_WARN(
"Invalid precision for backend: ",
backend,
". The setFloat32Precision call has no effect. ",
"Please choose a precision from: ",
msg);
}
}
at::LinalgBackend Context::linalgPreferredBackend() const {
return linalg_preferred_backend;
}

View File

@ -28,6 +28,7 @@
#include <c10/util/irange.h>
#include <cstdint>
#include <map>
#include <mutex>
namespace at {
@ -336,14 +337,20 @@ class TORCH_API Context {
void alertCuBLASConfigNotDeterministic() const;
void setFloat32MatmulPrecision(const std::string& s);
bool allowTF32CuDNN() const;
void setFloat32Precision(
const std::string& backend,
const std::string& op,
const std::string& s);
bool allowTF32CuDNN(const std::string& op = std::string()) const;
void setAllowTF32CuDNN(bool);
bool allowTF32OneDNN() const;
void setAllowTF32OneDNN(bool);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
Float32MatmulPrecision float32MatmulPrecision() const;
void setFloat32MatmulPrecision(Float32MatmulPrecision p);
std::string float32Precision(
const std::string& backend,
const std::string& op) const;
bool allowFP16ReductionCuBLAS() const;
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
@ -469,6 +476,23 @@ class TORCH_API Context {
bool enable_sparse_tensor_invariant_checks = false;
bool allow_fp16_reduction_cpu = false;
std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
{"generic", {{"all", "none"}}},
{"mkldnn",
{{"matmul", "none"},
{"conv", "none"},
{"rnn", "none"},
{"all", "none"}}},
{"cuda",
{{"matmul",
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
? "none"
: "tf32"},
{"conv", "tf32"},
{"rnn", "tf32"},
{"all", "none"}}},
};
Allocator* prev_allocator_ptr_{nullptr};
};

View File

@ -407,7 +407,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().allowTF32CuBLAS()) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
} else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
@ -1589,7 +1589,7 @@ bool gemm_and_bias(
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().allowTF32CuBLAS()) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
} else if constexpr (std::is_same_v<Dtype, at::Half>) {

View File

@ -218,7 +218,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) {
if (!NoTF32Guard::should_disable_tf32() &&
at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

View File

@ -160,7 +160,7 @@ inline std::string ComputeTypeFor() {
// ROCBLAS and hipBLASLt.
template <>
inline std::string ComputeTypeFor<float>() {
if (!at::globalContext().allowTF32CuBLAS()) {
if (at::globalContext().float32Precision("cuda", "matmul") != "tf32") {
return "f32_r";
} else {
return "xf32_r";

View File

@ -499,7 +499,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
}
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
if (at::globalContext().allowTF32CuBLAS()) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
}
HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);

View File

@ -141,7 +141,7 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {
TuningStatus Call(const GemmParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
return FAIL; // no support for TF32 in rocBLAS
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
@ -209,7 +209,7 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
return FAIL; // no support for TF32 in rocBLAS
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);

View File

@ -1174,7 +1174,7 @@ at::Tensor convolution(
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
return at::_convolution(input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups,
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv"));
}
at::Tensor convolution_overrideable(
@ -1319,7 +1319,7 @@ ConvBackend select_conv_backend(
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN("conv");
auto input = input_r;
auto weight = weight_r;
@ -1705,7 +1705,7 @@ at::Tensor _convolution(
c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN());
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv"));
}
std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
@ -2003,7 +2003,7 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward(
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN("conv");
// Validate inputs.
check_shape_backward(input, weight.sizes(), params);

View File

@ -169,7 +169,8 @@ std::string repro_from_args(const ConvolutionParams& params) {
ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
ss << "import torch\n";
ss << "torch.backends.cuda.matmul.allow_tf32 = "
<< pybool(at::globalContext().allowTF32CuBLAS()) << "\n";
<< pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32")
<< "\n";
ss << "torch.backends.cudnn.benchmark = "
<< pybool(at::globalContext().benchmarkCuDNN()) << "\n";
ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)
@ -725,7 +726,7 @@ Tensor cudnn_convolution_relu(
auto& ctx = at::globalContext();
bool benchmark = ctx.benchmarkCuDNN();
bool allow_tf32 = ctx.allowTF32CuDNN();
bool allow_tf32 = ctx.allowTF32CuDNN("conv");
auto _bias = bias_t.has_value()
? bias_t.value()
: at::zeros(
@ -783,7 +784,7 @@ Tensor cudnn_convolution_add_relu(
}
auto& ctx = at::globalContext();
bool allow_tf32 = ctx.allowTF32CuDNN();
bool allow_tf32 = ctx.allowTF32CuDNN("conv");
bool benchmark = ctx.benchmarkCuDNN();
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
auto _bias = bias_t.has_value()

View File

@ -245,7 +245,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
datatype,
input_datatype,
algo,
at::globalContext().allowTF32CuDNN());
at::globalContext().allowTF32CuDNN("rnn"));
#else
rnn_desc.set(
handle,
@ -261,7 +261,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
datatype,
input_datatype,
algo,
at::globalContext().allowTF32CuDNN());
at::globalContext().allowTF32CuDNN("rnn"));
#endif
return rnn_desc;
}

View File

@ -104,7 +104,7 @@ static bool use_mkldnn_fp16_matmul() {
}
static bool use_mkldnn_bf32_matmul() {
return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM;
return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
}
// returns an ideep::tensor

View File

@ -133,6 +133,44 @@ To toggle the TF32 flags off in C++, you can do
at::globalContext().setAllowTF32CuBLAS(false);
at::globalContext().setAllowTF32CuDNN(false);
Starting with PyTorch 2.7, we provide a new set of APIs to control the TF32 behavior in a more fine-grained way.
We can set the float32 precision per backend and per operator, and we can also override the global setting for a specific operator.
.. code:: python
torch.backends.fp32_precision = "ieee"
torch.backends.cuda.matmul.fp32_precision = "ieee"
torch.backends.cudnn.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cudnn.rnn.fp32_precision = "tf32"
The fp32_precision can be set to `ieee` or `tf32` for `cuda/cudnn`.
The `ieee` fp32_precision indicates that `FP32` is used as the internal computation precision.
The `tf32` fp32_precision indicates that `TF32` is allowed as the internal computation precision.
We can override the backend-wide setting for a specific operator, for example by setting that operator's fp32_precision to `ieee`.
.. code:: python
torch.backends.cudnn.fp32_precision = "tf32"
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.backends.cudnn.rnn.fp32_precision = "ieee"
We can also override the generic setting for a specific backend, for example by setting that backend's fp32_precision to `ieee`.
.. code:: python
torch.backends.fp32_precision = "tf32"
torch.backends.cudnn.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.backends.cudnn.rnn.fp32_precision = "ieee"
In both cases above, `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision`
are overridden to `ieee`.
The old settings are still supported, but we suggest using the new settings for better control. Mixing the
old and new settings is not supported.
For more information about TF32, see:
- `TensorFloat-32`_

View File

@ -0,0 +1,102 @@
.. meta::
:description: A guide to torch.backends.mkldnn, a PyTorch backend to run MKLDNN operations
:keywords: optimize PyTorch, MKLDNN
.. _mkldnn_backend:
MKLDNN backend
---------------------------------------------------
MKLDNN is an open-source cross-platform performance library of basic building blocks
for deep learning applications.
.. code:: python
# The flag below controls whether the MKLDNN backend is enabled in PyTorch.
torch.backends.mkldnn.enabled = True
Users can disable the MKLDNN backend by:
.. code:: python
torch.backends.mkldnn.enabled = False
.. _bf16_on_mkldnn:
Bfloat16 (BF16) on MKLDNN backend
---------------------------------------------------
Starting in PyTorch 2.4, there is a set of APIs to control the internal computation precision
for `float32` operators.
.. code:: python
# The flag below controls the internal computation precision for mkldnn matmul. The default, ieee, means plain float32.
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
# The flag below controls the internal computation precision for mkldnn conv. The default, ieee, means plain float32.
torch.backends.mkldnn.conv.fp32_precision = "ieee"
# The flag below controls the internal computation precision for mkldnn rnn. The default, ieee, means plain float32.
torch.backends.mkldnn.rnn.fp32_precision = "ieee"
Note that besides matmuls and convolutions themselves, functions and nn modules that internally use
matmuls or convolutions are also affected. These include :class:`torch.nn.Linear`, :class:`torch.nn._ConvNd`, :func:`torch.cdist`,
:func:`torch.tensordot`, :func:`torch.nn.functional.affine_grid` and :func:`torch.nn.functional.grid_sample`,
:class:`torch.nn.AdaptiveLogSoftmaxWithLoss`, :class:`torch.nn.GRU` and :class:`torch.nn.LSTM`.
To get an idea of the precision and speed, see the example code and benchmark data (on SPR) below:
.. code:: python
torch.manual_seed(0)
a_full = torch.randn(10240, 10240, dtype=torch.double)
b_full = torch.randn(10240, 10240, dtype=torch.double)
ab_full = a_full @ b_full
mean = ab_full.abs().mean() # 80.7451
a = a_full.float()
b = b_full.float()
# Do matmul in BF16 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'bf16'
ab_bf16 = a @ b # expected speedup with BF16 dot-product acceleration
error = (ab_bf16 - ab_full).abs().max() # 1.3704
relative_error = error / mean # 0.0170
print(error, relative_error)
# Do matmul in FP32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
ab_fp32 = a @ b
error = (ab_fp32 - ab_full).abs().max() # 0.0003
relative_error = error / mean # 0.00000317
print(error, relative_error)
From the above example, we can see that with BF16, the speed is ~7x faster on SPR, and that
relative error compared to double precision is approximately 2 orders of magnitude larger.
If full FP32 precision is needed, users can disable BF16 by:
.. code:: python
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
torch.backends.mkldnn.conv.fp32_precision = 'ieee'
torch.backends.mkldnn.rnn.fp32_precision = 'ieee'
To toggle the BF16 flags off in C++, you can do
.. code:: C++
at::globalContext().setFloat32Precision("ieee", "mkldnn", "matmul");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "conv");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "rnn");
We can override the generic setting for a specific backend or operator, for example by setting its fp32_precision to `ieee`.
.. code:: python
torch.backends.fp32_precision = "bf16"
torch.backends.mkldnn.fp32_precision = "ieee"
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
In this case, both `torch.backends.mkldnn.fp32_precision` and `torch.backends.mkldnn.matmul.fp32_precision`
are overridden to `ieee`.

View File

@ -30,7 +30,13 @@ from torch.testing._internal.common_device_type import (
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
torch.set_float32_matmul_precision("high")
# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
# In the current test, HIPBLASLT_ALLOW_TF32 is not set, so according to the
# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
if torch.version.hip:
torch.set_float32_matmul_precision("highest")
else:
torch.set_float32_matmul_precision("high")
index = torch.ops.aten.index
Tensor = torch.Tensor

View File

@ -97,7 +97,13 @@ class TestCaseBase(TestCase):
if HAS_GPU:
cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision()
cls.prior_default_device = torch.get_default_device()
torch.set_float32_matmul_precision("high")
# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
# In the current test, HIPBLASLT_ALLOW_TF32 is not set, so according to the
# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
if torch.version.hip:
torch.set_float32_matmul_precision("highest")
else:
torch.set_float32_matmul_precision("high")
torch.set_default_device(GPU_TYPE)
@classmethod

View File

@ -69,6 +69,7 @@ from torch.testing._internal.common_utils import (
load_tests,
MI300_ARCH,
parametrize,
recover_orig_fp32_precision,
run_tests,
serialTest,
setBlasBackendsToDefaultFinally,
@ -849,6 +850,55 @@ print(t.is_pinned())
):
self.assertTrue(torch.backends.cudnn.allow_tf32)
@recover_orig_fp32_precision
def test_fp32_precision_with_tf32(self):
with torch.backends.cudnn.flags(
enabled=None,
benchmark=None,
benchmark_limit=None,
deterministic=None,
allow_tf32=True,
fp32_precision="none",
):
self.assertEqual(torch.backends.cudnn.conv.fp32_precision, "tf32")
self.assertEqual(torch.backends.cudnn.rnn.fp32_precision, "tf32")
with torch.backends.cudnn.flags(
enabled=None,
benchmark=None,
benchmark_limit=None,
deterministic=None,
allow_tf32=False,
fp32_precision="none",
):
self.assertEqual(torch.backends.cudnn.conv.fp32_precision, "none")
self.assertEqual(torch.backends.cudnn.rnn.fp32_precision, "none")
@recover_orig_fp32_precision
def test_fp32_precision_with_float32_matmul_precision(self):
torch.set_float32_matmul_precision("highest")
self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "ieee")
torch.set_float32_matmul_precision("high")
self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "tf32")
torch.set_float32_matmul_precision("medium")
self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "tf32")
@recover_orig_fp32_precision
def test_invalid_status_for_legacy_api(self):
torch.backends.cudnn.conv.fp32_precision = "none"
torch.backends.cudnn.rnn.fp32_precision = "tf32"
with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
print(torch.backends.cudnn.allow_tf32)
torch.set_float32_matmul_precision("highest")
torch.backends.cuda.matmul.fp32_precision = "tf32"
with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
print(torch.get_float32_matmul_precision())
if not TEST_WITH_ROCM:
with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
print(torch.backends.cuda.matmul.allow_tf32)
def test_type_conversions(self):
x = torch.randn(5, 5)
self.assertIsInstance(x.float(), torch.FloatTensor)

View File

@ -22,7 +22,7 @@ import torch.backends.mkldnn
from torch.utils import mkldnn as mkldnn_utils
from torch.testing._internal.common_utils import TestCase, \
run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \
skipIfTorchDynamo, xfailIfTorchDynamo
skipIfTorchDynamo, xfailIfTorchDynamo, recover_orig_fp32_precision
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
dtypes,
@ -1659,6 +1659,53 @@ class TestMkldnn(TestCase):
self.assertEqual(out_emulated.float(), out.float(), atol=5e-2, rtol=5e-2)
@recover_orig_fp32_precision
def test_mlkdnn_get_set(self):
# get/set mkldnn ops
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"):
self.assertEqual(torch.backends.mkldnn.fp32_precision, "bf16")
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"):
self.assertEqual(torch.backends.mkldnn.fp32_precision, "none")
# get/set matmul
torch.backends.mkldnn.matmul.fp32_precision = "bf16"
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
torch.backends.mkldnn.matmul.fp32_precision = "none"
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")
# get/set conv
torch.backends.mkldnn.conv.fp32_precision = "bf16"
self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "bf16")
torch.backends.mkldnn.conv.fp32_precision = "none"
self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "none")
# get/set rnn
torch.backends.mkldnn.rnn.fp32_precision = "bf16"
self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "bf16")
torch.backends.mkldnn.rnn.fp32_precision = "none"
self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "none")
@recover_orig_fp32_precision
def test_generic_precision(self):
with torch.backends.flags(fp32_precision="none"):
self.assertEqual(torch.backends.fp32_precision, "none")
with torch.backends.flags(fp32_precision="tf32"):
self.assertEqual(torch.backends.fp32_precision, "tf32")
@recover_orig_fp32_precision
def test_default_use_parent(self):
torch.backends.mkldnn.matmul.fp32_precision = "none"
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"):
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"):
with torch.backends.flags(fp32_precision="bf16"):
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
with torch.backends.flags(fp32_precision="tf32"):
# when the parent is set to an unsupported precision, fall back to the default
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")
@recover_orig_fp32_precision
def test_invalid(self):
# fall back to the default if the user sets an unsupported precision
torch.backends.mkldnn.matmul.fp32_precision = "tf32"
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")
instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))

View File

@ -1360,6 +1360,8 @@ def _disabled_torch_dispatch_impl(
) -> Any: ... # THPModule_disable_dispatch_function
def _get_linalg_preferred_backend() -> _LinalgBackend: ...
def _set_linalg_preferred_backend(arg: _LinalgBackend): ...
def _get_fp32_precision_getter(backend: str, op: str) -> str: ...
def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ...
class _LinalgBackend:
Default: _LinalgBackend

View File

@ -254,7 +254,9 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
cuda_rng_state = None
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state()
allow_tf32 = torch._C._get_cublas_allow_tf32()
cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter(
"cuda", "matmul"
)
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()
@ -286,7 +288,9 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
torch._C._unset_default_mobile_cpu_allocator()
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
torch._C._set_cublas_allow_tf32(allow_tf32)
torch._C._set_fp32_precision_setter(
"cuda", "matmul", cuda_matmul_fp32_prec
)
torch.fx.graph_module._forward_from_src = prior_fwd_from_src
assert guards.check(), (
f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"

View File

@ -1,7 +1,10 @@
# mypy: allow-untyped-defs
import sys
import types
from contextlib import contextmanager
import torch
# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
@ -57,6 +60,70 @@ class PropModule(types.ModuleType):
return self.m.__getattribute__(attr)
class _FP32Precision:
def __init__(self, backend, op):
self.backend = backend
self.op = op
def __setattr__(self, name, value):
if name == "fp32_precision":
torch._C._set_fp32_precision_setter(self.backend, self.op, value)
elif name in ("backend", "op"):
super().__setattr__(name, value)
else:
raise AttributeError("Unknown attribute " + name)
def __getattr__(self, name):
if name == "fp32_precision":
return torch._C._get_fp32_precision_getter(self.backend, self.op)
else:
raise AttributeError("Unknown attribute " + name)
def set_flags(_fp32_precision="none"):
orig_flags = (torch._C._get_fp32_precision_getter("generic", "all"),)
if _fp32_precision is not None:
torch._C._set_fp32_precision_setter("generic", "all", _fp32_precision)
return orig_flags
@contextmanager
def flags(fp32_precision="none"):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(fp32_precision)
try:
yield
finally:
with __allow_nonbracketed_mutation():
set_flags(*orig_flags)
def _get_fp32_precision_getter(backend, op):
def inner():
return torch._C._get_fp32_precision_getter(backend, op)
return inner
def _set_fp32_precision_setter(backend, op):
def inner(precision):
return torch._C._set_fp32_precision_setter(backend, op, precision)
return inner
class GenericModule(PropModule):
def __init__(self, m, name):
super().__init__(m, name)
fp32_precision = ContextProp(
_get_fp32_precision_getter("generic", "all"),
_set_fp32_precision_setter("generic", "all"),
)
sys.modules[__name__] = GenericModule(sys.modules[__name__], __name__)
from torch.backends import (
cpu as cpu,
cuda as cuda,

View File

@ -135,6 +135,8 @@ class cuBLASModule:
return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
elif name == "allow_fp16_accumulation":
return torch._C._get_cublas_allow_fp16_accumulation()
elif name == "fp32_precision":
return torch._C._get_fp32_precision_getter("cuda", "matmul")
raise AttributeError("Unknown attribute " + name)
def __setattr__(self, name, value):
@ -146,6 +148,8 @@ class cuBLASModule:
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
elif name == "allow_fp16_accumulation":
return torch._C._set_cublas_allow_fp16_accumulation(value)
elif name == "fp32_precision":
return torch._C._set_fp32_precision_setter("cuda", "matmul", value)
raise AttributeError("Unknown attribute " + name)

View File

@ -6,7 +6,14 @@ from contextlib import contextmanager
from typing import Optional
import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
from torch.backends import (
__allow_nonbracketed_mutation,
_FP32Precision,
_get_fp32_precision_getter,
_set_fp32_precision_setter,
ContextProp,
PropModule,
)
try:
@ -128,6 +135,7 @@ def set_flags(
_benchmark_limit=None,
_deterministic=None,
_allow_tf32=None,
_fp32_precision="none",
):
orig_flags = (
torch._C._get_cudnn_enabled(),
@ -135,6 +143,7 @@ def set_flags(
None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
torch._C._get_cudnn_deterministic(),
torch._C._get_cudnn_allow_tf32(),
torch._C._get_fp32_precision_getter("cuda", "all"),
)
if _enabled is not None:
torch._C._set_cudnn_enabled(_enabled)
@ -146,6 +155,8 @@ def set_flags(
torch._C._set_cudnn_deterministic(_deterministic)
if _allow_tf32 is not None:
torch._C._set_cudnn_allow_tf32(_allow_tf32)
if _fp32_precision is not None:
torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision)
return orig_flags
@ -156,10 +167,16 @@ def flags(
benchmark_limit=10,
deterministic=False,
allow_tf32=True,
fp32_precision="none",
):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(
enabled, benchmark, benchmark_limit, deterministic, allow_tf32
enabled,
benchmark,
benchmark_limit,
deterministic,
allow_tf32,
fp32_precision,
)
try:
yield
@ -194,6 +211,12 @@ class CudnnModule(PropModule):
allow_tf32 = ContextProp(
torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32
)
conv = _FP32Precision("cuda", "conv")
rnn = _FP32Precision("cuda", "rnn")
fp32_precision = ContextProp(
_get_fp32_precision_getter("cuda", "all"),
_set_fp32_precision_setter("cuda", "all"),
)
# This is the sys.modules replacement trick, see

View File

@ -4,7 +4,14 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
from torch.backends import (
__allow_nonbracketed_mutation,
_FP32Precision,
_get_fp32_precision_getter,
_set_fp32_precision_setter,
ContextProp,
PropModule,
)
def is_available():
@ -64,11 +71,14 @@ class verbose:
return False
def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None):
def set_flags(
_enabled=None, _deterministic=None, _allow_tf32=None, _fp32_precision="none"
):
orig_flags = (
torch._C._get_mkldnn_enabled(),
torch._C._get_mkldnn_deterministic(),
torch._C._get_onednn_allow_tf32(),
torch._C._get_fp32_precision_getter("mkldnn", "all"),
)
if _enabled is not None:
torch._C._set_mkldnn_enabled(_enabled)
@ -76,13 +86,15 @@ def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None):
torch._C._set_mkldnn_deterministic(_deterministic)
if _allow_tf32 is not None:
torch._C._set_onednn_allow_tf32(_allow_tf32)
if _fp32_precision is not None:
torch._C._set_fp32_precision_setter("mkldnn", "all", _fp32_precision)
return orig_flags
@contextmanager
def flags(enabled=False, deterministic=False, allow_tf32=True):
def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="none"):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(enabled, deterministic, allow_tf32)
orig_flags = set_flags(enabled, deterministic, allow_tf32, fp32_precision)
try:
yield
finally:
@ -104,6 +116,13 @@ class MkldnnModule(PropModule):
allow_tf32 = ContextProp(
torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32
)
matmul = _FP32Precision("mkldnn", "matmul")
conv = _FP32Precision("mkldnn", "conv")
rnn = _FP32Precision("mkldnn", "rnn")
fp32_precision = ContextProp(
_get_fp32_precision_getter("mkldnn", "all"),
_set_fp32_precision_setter("generic", "all"),
)
if TYPE_CHECKING:

View File

@ -670,10 +670,12 @@ static PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
}
static PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
if (at::globalContext().allowTF32CuDNN())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_setFloat32MatmulPrecision(
@ -694,6 +696,7 @@ static PyObject* THPModule_setFloat32MatmulPrecision(
static PyObject* THPModule_float32MatmulPrecision(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
std::string s = "highest";
auto p = at::globalContext().float32MatmulPrecision();
if (p == at::Float32MatmulPrecision::HIGH) {
@ -702,6 +705,7 @@ static PyObject* THPModule_float32MatmulPrecision(
s = "medium";
}
return THPUtils_packString(s);
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_setSDPPriorityOrder(
PyObject* _unused,
@ -1116,10 +1120,12 @@ static PyObject* THPModule_setAllowTF32CuBLAS(
static PyObject* THPModule_allowTF32CuBLAS(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
if (at::globalContext().allowTF32CuBLAS()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
@ -2360,6 +2366,18 @@ Call this whenever a new thread is created in order to propagate values from
at::DataPtr(reinterpret_cast<void*>(data_ptr), device));
});
py_module.def(
"_get_fp32_precision_getter", [](std::string backend, std::string op) {
return at::globalContext().float32Precision(backend, op);
});
py_module.def(
"_set_fp32_precision_setter",
[](std::string backend, std::string op, std::string precision) {
at::globalContext().setFloat32Precision(backend, op, precision);
return precision;
});
py_module.def(
"_stash_obj_in_tls", [](const std::string& key, py::handle arg) {
at::impl::ThreadLocalPythonObjects::get_state().set(

View File

@ -590,7 +590,7 @@ struct GlobalStateGuard {
_torch_function_all_disabled = at::impl::torch_function_all_disabled();
_deterministic_algorithms = ctx.deterministicAlgorithms();
_deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
_allow_tf32 = ctx.allowTF32CuBLAS();
_allow_tf32 = ctx.float32Precision("cuda", "matmul") == "tf32";
_allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
_allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
_num_threads = at::get_num_threads();
@ -607,7 +607,7 @@ struct GlobalStateGuard {
_deterministic_algorithms == ctx.deterministicAlgorithms() &&
_deterministic_algorithms_warn_only ==
ctx.deterministicAlgorithmsWarnOnly() &&
_allow_tf32 == ctx.allowTF32CuBLAS() &&
_allow_tf32 == (ctx.float32Precision("cuda", "matmul") == "tf32") &&
_allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
_allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
_num_threads == at::get_num_threads()) &&
@ -628,7 +628,7 @@ struct GlobalStateGuard {
if (_deterministic_algorithms_warn_only !=
ctx.deterministicAlgorithmsWarnOnly())
os << "deterministic_algorithms_warn_only ";
if (_allow_tf32 != ctx.allowTF32CuBLAS())
if (_allow_tf32 != (ctx.float32Precision("cuda", "matmul") == "tf32"))
os << "allow_tf32 ";
if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS())
os << "allow_fp16_reduce ";

View File

@ -396,7 +396,8 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
}
event->start_time_ = c10::getApproximateTime();
event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS();
event->allow_tf32_cublas_ =
at::globalContext().float32Precision("cuda", "matmul") == "tf32";
if (!config_.experimental_config.performance_events.empty()) {
const size_t n = config_.experimental_config.performance_events.size();
event->counters_ = std::make_unique<perf_counters_t>(n, 0);

View File

@ -5728,5 +5728,25 @@ def scoped_load_inline(func):
return cpp_extension.load_inline(*args, **kwargs)
return func(*args, load_inline=load_inline, **kwargs)
return wrapper
def recover_orig_fp32_precision(fn):
@contextlib.contextmanager
def recover():
old_mkldnn_conv_p = torch.backends.mkldnn.conv.fp32_precision # type: ignore[attr-defined]
old_mkldnn_rnn_p = torch.backends.mkldnn.rnn.fp32_precision # type: ignore[attr-defined]
old_mkldnn_matmul_p = torch.backends.mkldnn.matmul.fp32_precision # type: ignore[attr-defined]
old_cudnn_conv_p = torch.backends.cudnn.conv.fp32_precision # type: ignore[attr-defined]
old_cudnn_rnn_p = torch.backends.cudnn.rnn.fp32_precision # type: ignore[attr-defined]
old_cuda_matmul_p = torch.backends.cuda.matmul.fp32_precision
try:
yield
finally:
torch.backends.mkldnn.conv.fp32_precision = old_mkldnn_conv_p # type: ignore[attr-defined]
torch.backends.mkldnn.rnn.fp32_precision = old_mkldnn_rnn_p # type: ignore[attr-defined]
torch.backends.mkldnn.matmul.fp32_precision = old_mkldnn_matmul_p # type: ignore[attr-defined]
torch.backends.cudnn.conv.fp32_precision = old_cudnn_conv_p # type: ignore[attr-defined]
torch.backends.cudnn.rnn.fp32_precision = old_cudnn_rnn_p # type: ignore[attr-defined]
torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p
return recover()(fn)