mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix cpu autocast check in rnn (#100621)
https://github.com/pytorch/pytorch/pull/100100 added Typechecking while `torch.is_autocast_enabled()` always return `False` on cpu. This PR fixes the autocast check for cpu. Pull Request resolved: https://github.com/pytorch/pytorch/pull/100621 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
26cd958718
commit
7012600abe
@ -4,7 +4,7 @@ import collections
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
|
||||
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
@ -125,6 +125,23 @@ class TestAutocastCPU(TestCase):
|
||||
for op, args in self.autocast_lists.torch_need_autocast_promote:
|
||||
self._run_autocast_outofplace(op, args, torch.float32)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
|
||||
def test_autocast_rnn(self):
|
||||
if torch._C.has_mkldnn and torch.ops.mkldnn._is_mkldnn_bf16_supported():
|
||||
x = torch.randn(1, 2, 1)
|
||||
hx = torch.randn(2, 2, 1)
|
||||
cx = torch.randn(2, 2, 1)
|
||||
|
||||
m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)
|
||||
|
||||
# Raise ValueError when autocast is not enabled
|
||||
with self.assertRaisesRegex(ValueError, "input must have the type"):
|
||||
m(x, (hx, cx))
|
||||
|
||||
# Should be able to run the below case with autocast
|
||||
with torch.cpu.amp.autocast():
|
||||
m(x, (hx, cx))
|
||||
|
||||
|
||||
class CustomLinear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
||||
@ -1134,6 +1134,7 @@ def is_autocast_enabled() -> _bool: ...
|
||||
def clear_autocast_cache() -> None: ...
|
||||
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
|
||||
def is_autocast_cpu_enabled() -> _bool: ...
|
||||
def _is_any_autocast_enabled() -> _bool: ...
|
||||
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
|
||||
def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
|
||||
def get_autocast_cpu_dtype() -> _dtype: ...
|
||||
|
||||
@ -209,9 +209,10 @@ class RNNBase(Module):
|
||||
init.uniform_(weight, -stdv, stdv)
|
||||
|
||||
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
|
||||
if input.dtype != self._flat_weights[0].dtype and not torch.is_autocast_enabled():
|
||||
raise ValueError('input must have the type {}, got type {}'.format(
|
||||
self._flat_weights[0].dtype, input.dtype))
|
||||
if not torch.jit.is_scripting():
|
||||
if input.dtype != self._flat_weights[0].dtype and not torch._C._is_any_autocast_enabled():
|
||||
raise ValueError('input must have the type {}, got type {}'.format(
|
||||
self._flat_weights[0].dtype, input.dtype))
|
||||
expected_input_dim = 2 if batch_sizes is not None else 3
|
||||
if input.dim() != expected_input_dim:
|
||||
raise RuntimeError(
|
||||
|
||||
Reference in New Issue
Block a user