Add BF16 type to _autocast_to_full_precision (#67707)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67707

https://github.com/pytorch/pytorch/pull/63939/files added FP16 autocast support to TorchScript.

This adds the BF16 dtype to the check in _autocast_to_full_precision, so BF16 tensors are also cast back to FP32 when converting to full precision.
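
For illustration, a minimal sketch of the behavior this enables (it mirrors the new unit test below; the input shapes, device setup, and the explicit _jit_set_autocast_mode call are illustrative assumptions, not part of this change):

    import torch
    from torch.cuda.amp import autocast

    # The scripted autocast pass may need to be enabled explicitly,
    # as the JIT autocast test suite does in its setup.
    torch._C._jit_set_autocast_mode(True)

    @torch.jit.script
    def fn(a, b):
        with autocast(dtype=torch.bfloat16):
            x = torch.mm(a, b)   # matmul runs in reduced precision -> bfloat16
            y = torch.sum(x)     # sum is cast back to full precision -> float32
        return x, y

    # Requires a CUDA device; BF16 math is best exercised on Ampere or newer.
    a = torch.rand(4, 4, device="cuda")
    b = torch.rand(4, 4, device="cuda")
    x, y = fn(a, b)
    assert x.dtype == torch.bfloat16
    assert y.dtype == torch.float32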

Test Plan: Unit tests. Also tested BF16 locally on an A100 using an MLP model.

Reviewed By: idning

Differential Revision: D32027152

fbshipit-source-id: b2a5ff2b22ea1e02306b0399f2b39b8493be4f45
Author: Yusuo Hu
Date: 2021-11-03 14:05:23 -07:00
Committed by: Facebook GitHub Bot
parent 05e17e7ff6
commit fddfb81dd0
2 changed files with 19 additions and 4 deletions


@@ -142,7 +142,7 @@ Tensor _autocast_to_reduced_precision(const Tensor& self, bool cuda_enabled, boo
 // If input tensor is fp16, cast it to fp32, otherwise leave it alone.
 // (this is intended to be used internally by the JIT autocast implementation)
 Tensor _autocast_to_full_precision(const Tensor& self, bool cuda_enabled, bool cpu_enabled) {
-  if (self.dtype() == at::ScalarType::Half &&
+  if ((self.dtype() == at::ScalarType::Half || self.dtype() == at::ScalarType::BFloat16) &&
       ((self.device().is_cuda() && cuda_enabled) ||
        (self.device().is_cpu() && cpu_enabled))
      ) {
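
In rough Python terms, the updated check behaves like the sketch below (an illustrative paraphrase of the C++ above, not an API that exists under this name in Python):

    import torch

    def autocast_to_full_precision(tensor: torch.Tensor, cuda_enabled: bool, cpu_enabled: bool) -> torch.Tensor:
        # `tensor` mirrors the C++ `self` argument. After this change, both fp16
        # and bf16 tensors are widened back to fp32 when the corresponding
        # device's flag is enabled; anything else is returned unchanged.
        if tensor.dtype in (torch.float16, torch.bfloat16) and (
            (tensor.is_cuda and cuda_enabled) or (tensor.device.type == "cpu" and cpu_enabled)
        ):
            return tensor.to(torch.float32)
        return tensor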


@@ -32,9 +32,24 @@ class TestAutocast(JitTestCase):
         @torch.jit.script
         def fn(a, b):
             with autocast():
-                return torch.mm(a, b)
-        result = fn(self.a_fp32, self.b_fp32)
-        self.assertEqual(result.dtype, torch.float16)
+                x = torch.mm(a, b)
+                y = torch.sum(x)
+            return x, y
+        x, y = fn(self.a_fp32, self.b_fp32)
+        self.assertEqual(x.dtype, torch.float16)
+        self.assertEqual(y.dtype, torch.float32)
+
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_linear_bf16(self):
+        @torch.jit.script
+        def fn(a, b):
+            with autocast(dtype=torch.bfloat16):
+                x = torch.mm(a, b)
+                y = torch.sum(x)
+            return x, y
+        x, y = fn(self.a_fp32, self.b_fp32)
+        self.assertEqual(x.dtype, torch.bfloat16)
+        self.assertEqual(y.dtype, torch.float32)
 
     @unittest.skipIf(not TEST_CUDA, "No cuda")
     def test_minimal_cpu(self):