Files
pytorch/torch/testing/_internal/common_mkldnn.py
Jiang, Yanbing f4d8bc46c7 Enable TF32 as fp32 internal precision for matmul/linear/conv (#157520)
### Description

This PR enables TF32 as the fp32 internal precision for matmul/linear/conv in the `mkldnn` backend. Since the fp32 precision API was refined in https://github.com/pytorch/pytorch/pull/125888, it can easily be extended to support TF32 for the `mkldnn` backend.

```python
torch.backends.mkldnn.matmul.fp32_precision = "tf32"
torch.backends.mkldnn.conv.fp32_precision = "tf32"
```
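
For example, a minimal sketch of what the new setting controls on the conv side (the shapes and tolerance are arbitrary, and whether TF32 actually kicks in depends on the CPU's ISA support; on unsupported hardware the results simply match strict fp32):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)
x = torch.randn(8, 16, 56, 56)

old = torch.backends.mkldnn.conv.fp32_precision
try:
    torch.backends.mkldnn.conv.fp32_precision = "ieee"  # strict fp32 reference
    ref = conv(x)

    torch.backends.mkldnn.conv.fp32_precision = "tf32"  # allow TF32 internally
    out = conv(x)
finally:
    torch.backends.mkldnn.conv.fp32_precision = old

# TF32 keeps the fp32 range but truncates the mantissa, so compare with a
# loose tolerance rather than expecting bitwise equality.
print((ref - out).abs().max())
```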

The related kernel updates and unit-test updates are done, and the wrapper `bf32_on_and_off` is renamed to `reduced_f32_on_and_off`. It runs each test three times: once with reduced_f32 OFF and twice with reduced_f32 ON (`bf32 ON` and `tf32 ON`).
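
As a rough sketch of how a test picks this up (the test class, tolerances, and reference computation below are illustrative, not taken from the PR):

```python
import torch
from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off
from torch.testing._internal.common_utils import TestCase, run_tests


class TestMkldnnReducedF32(TestCase):
    # Runs up to three times: reduced_f32 OFF ("ieee"), bf32 ON ("bf16"),
    # and tf32 ON ("tf32"); self.precision is loosened for the ON runs.
    @reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5)
    def test_linear(self, device="cpu", dtype=torch.float):
        m = torch.nn.Linear(64, 64).to(device=device, dtype=dtype)
        x = torch.randn(32, 64, device=device, dtype=dtype)
        # fp64 reference is unaffected by the reduced-precision settings.
        ref = torch.nn.functional.linear(
            x.double(), m.weight.double(), m.bias.double()
        ).to(dtype)
        self.assertEqual(m(x), ref)


if __name__ == "__main__":
    run_tests()
```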

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157520
Approved by: https://github.com/mingfeima, https://github.com/jansel
2025-07-17 08:57:34 +00:00

114 lines · 3.8 KiB · Python

```python
# mypy: ignore-errors
import contextlib
import functools
import inspect

import torch


def bf32_is_not_fp32():
    if not torch.backends.mkldnn.is_available():
        return False
    if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
        return False
    return True


def tf32_is_not_fp32():
    if not torch.backends.mkldnn.is_available():
        return False
    if not torch._C._cpu._is_amx_fp16_supported():
        return False
    return True


@contextlib.contextmanager
def reduced_f32_off():
    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "ieee"
        torch.backends.mkldnn.conv.fp32_precision = "ieee"
        yield
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision


@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-2):
    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
    old_precision = self.precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "bf16"
        torch.backends.mkldnn.conv.fp32_precision = "bf16"
        self.precision = bf32_precision
        yield
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
        self.precision = old_precision


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
    old_precision = self.precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "tf32"
        torch.backends.mkldnn.conv.fp32_precision = "tf32"
        self.precision = tf32_precision
        yield
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
        self.precision = old_precision


# This is a wrapper that wraps a test to run this test three times, one with
# reduced_f32 OFF, the others with reduced_f32 ON (including bf32 ON and tf32
# ON). When running with reduced_f32 ON, it will use reduced precision (bf16/
# tf32) as specified by the argument.
def reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5):
    def with_reduced_f32_disabled(self, function_call):
        with reduced_f32_off():
            function_call()

    def with_bf32_enabled(self, function_call):
        with bf32_on(self, bf32_precision):
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            kwargs.update(zip(arg_names, args))
            cond = True
            if "device" in kwargs:
                cond = cond and (torch.device(kwargs["device"]).type == "cpu")
            if "dtype" in kwargs:
                cond = cond and (kwargs["dtype"] == torch.float)
            bf32_cond = cond and bf32_is_not_fp32()
            tf32_cond = cond and tf32_is_not_fp32()
            if bf32_cond or tf32_cond:
                with_reduced_f32_disabled(kwargs["self"], lambda: f(**kwargs))
                if bf32_cond:
                    with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
                if tf32_cond:
                    with_tf32_enabled(kwargs["self"], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped

    return wrapper
```
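
For completeness, the context managers can also be used directly outside the decorator; the only requirement of `bf32_on`/`tf32_on` is an object exposing a `precision` attribute (the stand-in class below is illustrative, not part of the PR):

```python
import torch
from torch.testing._internal.common_mkldnn import reduced_f32_off, tf32_on


class _Tol:
    # Minimal stand-in for a TestCase: tf32_on/bf32_on read and restore
    # self.precision around the block.
    precision = 1e-5


tol = _Tol()

with tf32_on(tol, tf32_precision=1e-4):
    # Inside the block both matmul and conv fp32_precision are "tf32" and
    # tol.precision is loosened to 1e-4.
    assert torch.backends.mkldnn.matmul.fp32_precision == "tf32"

with reduced_f32_off():
    # Inside the block both settings are forced back to strict fp32 ("ieee").
    assert torch.backends.mkldnn.conv.fp32_precision == "ieee"
```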