Add BF16 type to _autocast_to_full_precision (#67707)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67707

https://github.com/pytorch/pytorch/pull/63939/files added FP16 support to TorchScript autocast. This change adds the BF16 dtype to _autocast_to_full_precision, so bfloat16 tensors are also converted back to full precision.

Test Plan: Unit test. Also tested BF16 locally on an A100 using an MLP model.

Reviewed By: idning

Differential Revision: D32027152

fbshipit-source-id: b2a5ff2b22ea1e02306b0399f2b39b8493be4f45
Committed by: Facebook GitHub Bot
Parent: 05e17e7ff6
Commit: fddfb81dd0
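For context, below is a minimal sketch of the behavior the new test exercises; it is not code from this commit. A scripted function is run under autocast(dtype=torch.bfloat16): the matmul result stays in bfloat16, while the reduction is widened back to full precision (float32). It assumes a CUDA device with bfloat16 support, and the torch._C._jit_set_autocast_mode(True) call mirrors how the JIT autocast test suite enables scripted autocast; the tensor names and shapes are illustrative.

import torch
from torch.cuda.amp import autocast

# Enable autocast handling inside TorchScript (as the JIT autocast tests do).
torch._C._jit_set_autocast_mode(True)

@torch.jit.script
def fn(a, b):
    with autocast(dtype=torch.bfloat16):
        x = torch.mm(a, b)   # low-precision op: runs in bfloat16 under autocast
        y = torch.sum(x)     # converted back to full precision (float32)
        return x, y

a = torch.rand(8, 8, device="cuda")  # float32 inputs, like self.a_fp32 in the test
b = torch.rand(8, 8, device="cuda")
x, y = fn(a, b)
print(x.dtype, y.dtype)  # expected: torch.bfloat16 torch.float32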
@@ -32,9 +32,24 @@ class TestAutocast(JitTestCase):
         @torch.jit.script
         def fn(a, b):
             with autocast():
-                return torch.mm(a, b)
-        result = fn(self.a_fp32, self.b_fp32)
-        self.assertEqual(result.dtype, torch.float16)
+                x = torch.mm(a, b)
+                y = torch.sum(x)
+                return x, y
+        x, y = fn(self.a_fp32, self.b_fp32)
+        self.assertEqual(x.dtype, torch.float16)
+        self.assertEqual(y.dtype, torch.float32)
+
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_linear_bf16(self):
+        @torch.jit.script
+        def fn(a, b):
+            with autocast(dtype=torch.bfloat16):
+                x = torch.mm(a, b)
+                y = torch.sum(x)
+                return x, y
+        x, y = fn(self.a_fp32, self.b_fp32)
+        self.assertEqual(x.dtype, torch.bfloat16)
+        self.assertEqual(y.dtype, torch.float32)
 
     @unittest.skipIf(not TEST_CUDA, "No cuda")
     def test_minimal_cpu(self):