Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
fix bfloat16 autocast skip (#67822)
Summary: Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67822

Reviewed By: mruberry

Differential Revision: D32162605

Pulled By: ngimel

fbshipit-source-id: eb5ccf6c441231e572ec93ac8c2638d028abecad
committed by: Facebook GitHub Bot
parent: 2486061c72
commit: 99c7a9f09d
@@ -10,6 +10,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import run_tests
 from torch.testing import FileCheck
 
+TEST_BFLOAT16 = torch.cuda.is_bf16_supported()
 
 class TestAutocast(JitTestCase):
     def setUp(self):
@@ -41,7 +42,7 @@ class TestAutocast(JitTestCase):
         self.assertEqual(x.dtype, torch.float16)
         self.assertEqual(y.dtype, torch.float32)
 
-    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
     def test_linear_bf16(self):
         @torch.jit.script
         def fn(a, b):
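For context, here is a minimal standalone sketch of the skip pattern this commit applies: a module-level bf16 capability check plus a combined skipIf condition on the test. The test class name, test body, and the extra TEST_CUDA guard on the probe are assumptions for illustration only; they are not part of the commit.

import unittest

import torch
from torch.testing._internal.common_cuda import TEST_CUDA

# Probe bf16 support only when CUDA is available; the extra TEST_CUDA guard is an
# assumption added here to keep the sketch safe on CPU-only builds.
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()


class ExampleAutocastTest(unittest.TestCase):
    # Same skip condition the commit adds: require both CUDA and bf16 support.
    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        with torch.autocast("cuda", dtype=torch.bfloat16):
            x = torch.randn(4, 8, device="cuda")
            w = torch.randn(16, 8, device="cuda")
            y = torch.nn.functional.linear(x, w)
        # Under bf16 autocast, linear should produce a bfloat16 result.
        self.assertEqual(y.dtype, torch.bfloat16)


if __name__ == "__main__":
    unittest.main()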