# mypy: ignore-errors

import contextlib
import functools
import inspect

import torch


def bf32_is_not_fp32():
    # True only if oneDNN (mkldnn) is available and reports bf16 support,
    # i.e. the reduced-precision "bf32" path for fp32 compute exists.
    if not torch.backends.mkldnn.is_available():
        return False
    if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
        return False
    return True


def tf32_is_not_fp32():
    # True only if oneDNN (mkldnn) is available and the CPU reports AMX fp16
    # support, which gates the reduced-precision "tf32" path for fp32 compute.
    if not torch.backends.mkldnn.is_available():
        return False
    if not torch._C._cpu._is_amx_fp16_supported():
        return False
    return True


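# Illustrative sketch (not part of the original module): these checks can be
# used to skip tests that only make sense when a reduced-precision path
# exists. The test name below is hypothetical.
#
#     @unittest.skipIf(not bf32_is_not_fp32(), "bf32 path not available")
#     def test_matmul_bf32(self):
#         ...

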
@contextlib.contextmanager
def reduced_f32_off():
    # Temporarily force strict fp32 ("ieee") compute for mkldnn matmul and conv.
    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "ieee"
        torch.backends.mkldnn.conv.fp32_precision = "ieee"
        yield
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision


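# Illustrative sketch (not part of the original module): a test can wrap a
# reference computation in reduced_f32_off() so it runs in strict fp32,
# e.g. to compare against a reduced-precision result:
#
#     with reduced_f32_off():
#         ref = torch.nn.functional.linear(x, w)

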
@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-2):
    # Switch mkldnn matmul/conv fp32 compute to the bf16-based path and relax
    # the test's comparison tolerance while the context is active.
    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
    old_precision = self.precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "bf16"
        torch.backends.mkldnn.conv.fp32_precision = "bf16"
        self.precision = bf32_precision
        yield
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
        self.precision = old_precision


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    # Switch mkldnn matmul/conv fp32 compute to the tf32 path and relax the
    # test's comparison tolerance while the context is active.
    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
    old_precision = self.precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "tf32"
        torch.backends.mkldnn.conv.fp32_precision = "tf32"
        self.precision = tf32_precision
        yield
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
        self.precision = old_precision


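# Illustrative sketch (not part of the original module): inside a TestCase
# that defines self.precision, tf32_on() temporarily enables the tf32 path
# and loosens the tolerance used by self.assertEqual:
#
#     with tf32_on(self, tf32_precision=1e-5):
#         out = torch.mm(a, b)
#         self.assertEqual(out, expected)

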
# A decorator that wraps a test so it runs up to three times: once with
# reduced_f32 OFF and once each with bf32 ON and tf32 ON, when the
# corresponding reduced-precision path is supported. When reduced_f32 is ON,
# the test's comparison tolerance is relaxed to the bf32/tf32 precision given
# by the arguments.
def reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5):
    def with_reduced_f32_disabled(self, function_call):
        with reduced_f32_off():
            function_call()

    def with_bf32_enabled(self, function_call):
        with bf32_on(self, bf32_precision):
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Fold positional args into kwargs so "self", "device" and "dtype"
            # can be looked up by name below.
            kwargs.update(zip(arg_names, args, strict=False))
            cond = True
            if "device" in kwargs:
                cond = cond and (torch.device(kwargs["device"]).type == "cpu")
            if "dtype" in kwargs:
                cond = cond and (kwargs["dtype"] == torch.float)
            bf32_cond = cond and bf32_is_not_fp32()
            tf32_cond = cond and tf32_is_not_fp32()
            if bf32_cond or tf32_cond:
                with_reduced_f32_disabled(kwargs["self"], lambda: f(**kwargs))
                if bf32_cond:
                    with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
                if tf32_cond:
                    with_tf32_enabled(kwargs["self"], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped

    return wrapper
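

# Illustrative sketch (not part of the original module): applied to a
# device/dtype-parametrized test method, the decorator reruns the test with
# each supported reduced-precision mode on CPU float32 inputs. The class and
# test names below are hypothetical.
#
#     class TestMkldnnConv(TestCase):
#         @reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5)
#         def test_conv2d(self, device, dtype):
#             ...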