Add BF16 type to _autocast_to_full_precision (#67707)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67707

https://github.com/pytorch/pytorch/pull/63939/files has added FP16 support to torchscript.

This is to add BF16 device type when doing full conversion.

Test Plan: Unit test. Also tested BF16 locally on A100 using MLP model.

Reviewed By: idning

Differential Revision: D32027152

fbshipit-source-id: b2a5ff2b22ea1e02306b0399f2b39b8493be4f45
This commit is contained in:
Yusuo Hu
2021-11-03 14:05:23 -07:00
committed by Facebook GitHub Bot
parent 05e17e7ff6
commit fddfb81dd0
2 changed files with 19 additions and 4 deletions

View File

@ -32,9 +32,24 @@ class TestAutocast(JitTestCase):
@torch.jit.script
def fn(a, b):
with autocast():
return torch.mm(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
x = torch.mm(a, b)
y = torch.sum(x)
return x, y
x, y = fn(self.a_fp32, self.b_fp32)
self.assertEqual(x.dtype, torch.float16)
self.assertEqual(y.dtype, torch.float32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_linear_bf16(self):
@torch.jit.script
def fn(a, b):
with autocast(dtype=torch.bfloat16):
x = torch.mm(a, b)
y = torch.sum(x)
return x, y
x, y = fn(self.a_fp32, self.b_fp32)
self.assertEqual(x.dtype, torch.bfloat16)
self.assertEqual(y.dtype, torch.float32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_minimal_cpu(self):