mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "refine fp32 precision api (#125888)"
This reverts commit 4c11b26158691cfd9ad48338ddebd1ca9bded788. Reverted https://github.com/pytorch/pytorch/pull/125888 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to cause some failures on ROCm ([comment](https://github.com/pytorch/pytorch/pull/125888#issuecomment-2869274791))
This commit is contained in:
@ -5635,25 +5635,5 @@ def scoped_load_inline(func):
|
||||
return cpp_extension.load_inline(*args, **kwargs)
|
||||
|
||||
return func(*args, load_inline=load_inline, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
def recover_orig_fp32_precision(fn):
|
||||
@contextlib.contextmanager
|
||||
def recover():
|
||||
old_mkldnn_conv_p = torch.backends.mkldnn.conv.fp32_precision # type: ignore[attr-defined]
|
||||
old_mkldnn_rnn_p = torch.backends.mkldnn.rnn.fp32_precision # type: ignore[attr-defined]
|
||||
old_mkldnn_matmul_p = torch.backends.mkldnn.matmul.fp32_precision # type: ignore[attr-defined]
|
||||
old_cudnn_conv_p = torch.backends.cudnn.conv.fp32_precision # type: ignore[attr-defined]
|
||||
old_cudnn_rnn_p = torch.backends.cudnn.rnn.fp32_precision # type: ignore[attr-defined]
|
||||
old_cuda_matmul_p = torch.backends.cuda.matmul.fp32_precision
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.backends.mkldnn.conv.fp32_precision = old_mkldnn_conv_p # type: ignore[attr-defined]
|
||||
torch.backends.mkldnn.rnn.fp32_precision = old_mkldnn_rnn_p # type: ignore[attr-defined]
|
||||
torch.backends.mkldnn.matmul.fp32_precision = old_mkldnn_matmul_p # type: ignore[attr-defined]
|
||||
torch.backends.cudnn.conv.fp32_precision = old_cudnn_conv_p # type: ignore[attr-defined]
|
||||
torch.backends.cudnn.rnn.fp32_precision = old_cudnn_rnn_p # type: ignore[attr-defined]
|
||||
torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p
|
||||
|
||||
return recover()(fn)
|
||||
|
Reference in New Issue
Block a user