mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
refine fp32 precision api (#125888)
Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop the "highest, high, medium" to represent fp32 internal computation data types . Instead, we will directly use the algorithm to represent it. ### Design Choice: Directly use algorithms name like "TF32", "BF16". #### Pros - The names are more informative. 'tf32' is more informative than a simple "high". - Easier to extend new algorithm like `tf32x3` #### Cons - "HIGHEST, HIGH, MEDIUM" indicated the relative precision between different algorithms. However, we can have more documents to discuss them. ### We provide a layered structure for backends/operators. ('f32' is short for 'fp32_precision')  ### We provide 3 fp32 compute precision can be set: - **"ieee"**: Not allowed to use any other internal computation data types . - **"tf32"**: Allowed to use tf32 as internal computation data types. - **"bf16"**: Allowed to use bf16 as internal computation data types. - **"none"**: Precision's are not set. Can be override by its father node. ### Overriding Precision Settings Child node can be override by its father node if it is set to default. For current default settings: ``` backend = generic, op = all, precision setting = none backend = cuda, op = all, precision setting = none backend = cuda, op = conv, precision setting = tf32 backend = cuda, op = rnn, precision setting = tf32 backend = cuda, op = matmul, precision setting = none backend = matmul, op = all, precision setting = none backend = matmul, op = conv, precision setting = none backend = matmul, op = rnn, precision setting = none backend = matmul, op = matmul, precision setting = none ``` - If the user set `torch.backends.mkldnn.fp32_precision="bf16"`, his child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` will also be override to "bf16". - If the user set `torch.backends.fp32_precision="bf16"`, `torch.backends.mkldnn.fp32_precision` and his child nodes will also we override to "bf16". ### Backward Compatible Since new API allow user to have more fine-grained control. There will be some conflict. For example, previous `torch.backends.cudnn.allow_tf32` are not enough to represent the status for `torch.backends.cudnn.rnn.fp32_precision="ieee"` and `torch.backends.cudnn.conv.fp32_precision="tf32"`. Therefore, our goal for backward compatible is - If the user only uses previous APIs, it will work as previous expectations. - If the user use **new** API to change the status to an **un-representable** status for old API, and try to access the status by **old** API. We will raise Runtime Error and point the document for user. ### Test Plan ``` python test/test_cuda.py -k test_fp32_precision_with_tf32 python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision python test/test_cuda.py -k test_invalid_status_for_legacy_api python test/test_mkldnn.py -k test_mlkdnn_get_set python test/test_mkldnn.py -k test_generic_precision python test/test_mkldnn.py -k test_invalid python test/test_mkldnn.py -k test_default_use_parent ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888 Approved by: https://github.com/jgong5, https://github.com/albanD Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
b5f1345f72
commit
4c11b26158
@ -5635,5 +5635,25 @@ def scoped_load_inline(func):
|
||||
return cpp_extension.load_inline(*args, **kwargs)
|
||||
|
||||
return func(*args, load_inline=load_inline, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def recover_orig_fp32_precision(fn):
|
||||
@contextlib.contextmanager
|
||||
def recover():
|
||||
old_mkldnn_conv_p = torch.backends.mkldnn.conv.fp32_precision # type: ignore[attr-defined]
|
||||
old_mkldnn_rnn_p = torch.backends.mkldnn.rnn.fp32_precision # type: ignore[attr-defined]
|
||||
old_mkldnn_matmul_p = torch.backends.mkldnn.matmul.fp32_precision # type: ignore[attr-defined]
|
||||
old_cudnn_conv_p = torch.backends.cudnn.conv.fp32_precision # type: ignore[attr-defined]
|
||||
old_cudnn_rnn_p = torch.backends.cudnn.rnn.fp32_precision # type: ignore[attr-defined]
|
||||
old_cuda_matmul_p = torch.backends.cuda.matmul.fp32_precision
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.backends.mkldnn.conv.fp32_precision = old_mkldnn_conv_p # type: ignore[attr-defined]
|
||||
torch.backends.mkldnn.rnn.fp32_precision = old_mkldnn_rnn_p # type: ignore[attr-defined]
|
||||
torch.backends.mkldnn.matmul.fp32_precision = old_mkldnn_matmul_p # type: ignore[attr-defined]
|
||||
torch.backends.cudnn.conv.fp32_precision = old_cudnn_conv_p # type: ignore[attr-defined]
|
||||
torch.backends.cudnn.rnn.fp32_precision = old_cudnn_rnn_p # type: ignore[attr-defined]
|
||||
torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p
|
||||
|
||||
return recover()(fn)
|
||||
|
Reference in New Issue
Block a user