Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
refine fp32 precision api (#125888)
Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop "highest, high, medium" as the way to describe fp32 internal computation data types. Instead, we will name the algorithm directly.

### Design Choice: Directly use algorithm names like "TF32", "BF16".

#### Pros
- The names are more informative: 'tf32' says more than a plain "high".
- Easier to extend with new algorithms such as `tf32x3`.

#### Cons
- "HIGHEST, HIGH, MEDIUM" indicated the relative precision between algorithms. We can cover that relationship in the documentation instead.

### We provide a layered structure for backends/operators. ('f32' is short for 'fp32_precision')

### We provide 3 fp32 compute precisions that can be set (plus "none" for "not set"):
- **"ieee"**: no other internal computation data type may be used.
- **"tf32"**: tf32 may be used as the internal computation data type.
- **"bf16"**: bf16 may be used as the internal computation data type.
- **"none"**: the precision is not set and can be overridden by its parent node.

### Overriding Precision Settings
A child node is overridden by its parent node if the child is left at the default. The current default settings are:
```
backend = generic, op = all, precision setting = none
backend = cuda, op = all, precision setting = none
backend = cuda, op = conv, precision setting = tf32
backend = cuda, op = rnn, precision setting = tf32
backend = cuda, op = matmul, precision setting = none
backend = mkldnn, op = all, precision setting = none
backend = mkldnn, op = conv, precision setting = none
backend = mkldnn, op = rnn, precision setting = none
backend = mkldnn, op = matmul, precision setting = none
```
- If the user sets `torch.backends.mkldnn.fp32_precision="bf16"`, its child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` are also overridden to "bf16".
- If the user sets `torch.backends.fp32_precision="bf16"`, `torch.backends.mkldnn.fp32_precision` and its child nodes are also overridden to "bf16".

### Backward Compatibility
The new API allows more fine-grained control, so conflicts with the old API are possible. For example, the previous `torch.backends.cudnn.allow_tf32` cannot represent a state such as `torch.backends.cudnn.rnn.fp32_precision="ieee"` together with `torch.backends.cudnn.conv.fp32_precision="tf32"`. Our backward-compatibility goals are therefore:
- If the user only uses the previous APIs, they work as before.
- If the user uses the **new** API to reach a state that is **un-representable** by the old API and then reads that state through the **old** API, we raise a RuntimeError and point the user to the documentation.

### Test Plan
```
python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888
Approved by: https://github.com/jgong5, https://github.com/albanD

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
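For illustration, a minimal usage sketch of the knobs described above, written against the attribute names this description mentions (`torch.backends.fp32_precision`, `torch.backends.mkldnn.*.fp32_precision`, `torch.backends.cudnn.*.fp32_precision`); the exact module layout and error behavior are assumptions based on this PR description, not a definitive API reference:

```python
import torch

# Per-backend / per-op override, following the layered structure above.
# Setting the mkldnn backend to "bf16" also overrides its child ops
# (matmul / conv / rnn) that are still at the default ("none").
torch.backends.mkldnn.fp32_precision = "bf16"
print(torch.backends.mkldnn.matmul.fp32_precision)  # expected: "bf16"

# Child ops can also be pinned individually, independent of their parent.
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cudnn.rnn.fp32_precision = "ieee"

# Per the backward-compatibility section, reading the legacy flag in a
# state the old API cannot represent (conv=tf32, rnn=ieee) is expected
# to raise a RuntimeError pointing at the documentation.
try:
    _ = torch.backends.cudnn.allow_tf32
except RuntimeError as err:
    print("legacy API cannot represent this state:", err)
```

This mirrors what the test plan exercises (e.g. `test_default_use_parent`, `test_invalid_status_for_legacy_api`).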
Committed by: PyTorch MergeBot
Parent: b5f1345f72
Commit: 4c11b26158
```diff
@@ -667,10 +667,12 @@ static PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
 }
 
 static PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) {
   HANDLE_TH_ERRORS
   if (at::globalContext().allowTF32CuDNN())
     Py_RETURN_TRUE;
   else
     Py_RETURN_FALSE;
   END_HANDLE_TH_ERRORS
 }
 
 static PyObject* THPModule_setFloat32MatmulPrecision(
@@ -691,6 +693,7 @@ static PyObject* THPModule_setFloat32MatmulPrecision(
 static PyObject* THPModule_float32MatmulPrecision(
     PyObject* _unused,
     PyObject* noargs) {
   HANDLE_TH_ERRORS
   std::string s = "highest";
   auto p = at::globalContext().float32MatmulPrecision();
   if (p == at::Float32MatmulPrecision::HIGH) {
@@ -699,6 +702,7 @@ static PyObject* THPModule_float32MatmulPrecision(
     s = "medium";
   }
   return THPUtils_packString(s);
   END_HANDLE_TH_ERRORS
 }
 static PyObject* THPModule_setSDPPriorityOrder(
     PyObject* _unused,
@@ -1113,10 +1117,12 @@ static PyObject* THPModule_setAllowTF32CuBLAS(
 static PyObject* THPModule_allowTF32CuBLAS(
     PyObject* _unused,
     PyObject* noargs) {
   HANDLE_TH_ERRORS
   if (at::globalContext().allowTF32CuBLAS()) {
     Py_RETURN_TRUE;
   }
   Py_RETURN_FALSE;
   END_HANDLE_TH_ERRORS
 }
 
 static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
@@ -2287,6 +2293,18 @@ Call this whenever a new thread is created in order to propagate values from
         at::DataPtr(reinterpret_cast<void*>(data_ptr), device));
       });
 
+  py_module.def(
+      "_get_fp32_precision_getter", [](std::string backend, std::string op) {
+        return at::globalContext().float32Precision(backend, op);
+      });
+
+  py_module.def(
+      "_set_fp32_precision_setter",
+      [](std::string backend, std::string op, std::string precision) {
+        at::globalContext().setFloat32Precision(backend, op, precision);
+        return precision;
+      });
+
   py_module.def(
       "_stash_obj_in_tls", [](const std::string& key, py::handle arg) {
         at::impl::ThreadLocalPythonObjects::get_state().set(
```
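The last hunk adds a thin private getter/setter pair keyed by `(backend, op)`. A hedged sketch of exercising them directly, assuming `py_module` here is the `torch._C` extension module (user code would normally go through `torch.backends.*.fp32_precision` instead):

```python
import torch

# Direct use of the private bindings added above; the torch._C exposure
# and the backend/op names ("cuda", "matmul") are assumptions for illustration.
torch._C._set_fp32_precision_setter("cuda", "matmul", "tf32")
print(torch._C._get_fp32_precision_getter("cuda", "matmul"))  # expected: "tf32"
```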