fix bfloat16 autocast skip (#67822)

Summary:
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67822

Reviewed By: mruberry

Differential Revision: D32162605

Pulled By: ngimel

fbshipit-source-id: eb5ccf6c441231e572ec93ac8c2638d028abecad
Author: Natalia Gimelshein
Date: 2021-11-03 21:01:05 -07:00
Committed by: Facebook GitHub Bot
Parent: 2486061c72
Commit: 99c7a9f09d


@@ -10,6 +10,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import run_tests
 from torch.testing import FileCheck
 
+TEST_BFLOAT16 = torch.cuda.is_bf16_supported()
 
 class TestAutocast(JitTestCase):
     def setUp(self):
@@ -41,7 +42,7 @@ class TestAutocast(JitTestCase):
         self.assertEqual(x.dtype, torch.float16)
         self.assertEqual(y.dtype, torch.float32)
 
-    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
     def test_linear_bf16(self):
         @torch.jit.script
         def fn(a, b):
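
For context, a minimal self-contained sketch of the skip pattern this diff introduces: gate the bfloat16 test on both CUDA availability and torch.cuda.is_bf16_supported(), so it is skipped rather than failing on GPUs without bf16 support. The class name, test name, and tensor shapes below are illustrative only, and this version runs in eager mode rather than through TorchScript as the real test does; it just shows the guard and the expected bf16 output dtype.

import unittest

import torch

TEST_CUDA = torch.cuda.is_available()
# Also require bf16 support: is_bf16_supported() returns False when the
# GPU or CUDA toolkit cannot run bfloat16 kernels.
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()


class TestAutocastSkip(unittest.TestCase):
    # Skipping on either condition mirrors the decorator added in the diff.
    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_mm_bf16(self):
        a = torch.rand(4, 4, device="cuda")
        b = torch.rand(4, 4, device="cuda")
        with torch.autocast("cuda", dtype=torch.bfloat16):
            out = torch.mm(a, b)
        # Under bf16 autocast, mm runs in (and returns) bfloat16.
        self.assertEqual(out.dtype, torch.bfloat16)


if __name__ == "__main__":
    unittest.main()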