fix cpu autocast check in rnn (#100621)

https://github.com/pytorch/pytorch/pull/100100 added Typechecking while `torch.is_autocast_enabled()` always return `False` on cpu. This PR fixes the autocast check for cpu.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100621
Approved by: https://github.com/albanD
This commit is contained in:
chunyuan
2023-05-08 15:48:29 +00:00
committed by PyTorch MergeBot
parent 26cd958718
commit 7012600abe
3 changed files with 23 additions and 4 deletions

View File

@ -4,7 +4,7 @@ import collections
import unittest
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.utils._python_dispatch import TorchDispatchMode
@ -125,6 +125,23 @@ class TestAutocastCPU(TestCase):
for op, args in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args, torch.float32)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
def test_autocast_rnn(self):
if torch._C.has_mkldnn and torch.ops.mkldnn._is_mkldnn_bf16_supported():
x = torch.randn(1, 2, 1)
hx = torch.randn(2, 2, 1)
cx = torch.randn(2, 2, 1)
m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)
# Raise ValueError when autocast is not enabled
with self.assertRaisesRegex(ValueError, "input must have the type"):
m(x, (hx, cx))
# Should be able to run the below case with autocast
with torch.cpu.amp.autocast():
m(x, (hx, cx))
class CustomLinear(torch.autograd.Function):
@staticmethod

View File

@ -1134,6 +1134,7 @@ def is_autocast_enabled() -> _bool: ...
def clear_autocast_cache() -> None: ...
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
def is_autocast_cpu_enabled() -> _bool: ...
def _is_any_autocast_enabled() -> _bool: ...
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
def get_autocast_cpu_dtype() -> _dtype: ...

View File

@ -209,9 +209,10 @@ class RNNBase(Module):
init.uniform_(weight, -stdv, stdv)
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
if input.dtype != self._flat_weights[0].dtype and not torch.is_autocast_enabled():
raise ValueError('input must have the type {}, got type {}'.format(
self._flat_weights[0].dtype, input.dtype))
if not torch.jit.is_scripting():
if input.dtype != self._flat_weights[0].dtype and not torch._C._is_any_autocast_enabled():
raise ValueError('input must have the type {}, got type {}'.format(
self._flat_weights[0].dtype, input.dtype))
expected_input_dim = 2 if batch_sizes is not None else 3
if input.dim() != expected_input_dim:
raise RuntimeError(