Add Half support for CPU autocast on eager mode (#112484)

Add Half (float16) support for CPU autocast in eager mode, since common operators already have Half support on CPU. See https://github.com/pytorch/pytorch/issues/96093 for context; a short usage sketch follows the commit metadata below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112484
Approved by: https://github.com/leslie-fang-intel, https://github.com/ezyang
Author: CaoE
Date: 2023-11-21 20:08:28 +00:00
Committed by: PyTorch MergeBot
Commit: c47d2b8035
Parent: 4e4a6ad6ec
3 changed files with 77 additions and 38 deletions
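
For illustration, a minimal sketch of what this change enables (the operator and tensor shapes are arbitrary examples, not taken from the PR; it assumes torch.mm is on the CPU autocast lower-precision cast list, as it is for bf16):

import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)

# CPU autocast previously accepted only torch.bfloat16 in eager mode;
# with this change, torch.float16 works as well.
with torch.cpu.amp.autocast(dtype=torch.float16):
    c = torch.mm(a, b)  # runs in Half under autocast

assert c.dtype == torch.float16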

@@ -17,7 +17,16 @@ class TestAutocastCPU(TestCase):
         del self.autocast_lists
         super().tearDown()

-    def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
         # helper to cast args
         def cast(val, to_type):
             if isinstance(val, torch.Tensor):
@@ -31,7 +40,7 @@ class TestAutocastCPU(TestCase):
             add_kwargs = {}

         self.assertFalse(torch.is_autocast_cpu_enabled())
-        with torch.cpu.amp.autocast():
+        with torch.cpu.amp.autocast(dtype=amp_dtype):
             self.assertTrue(torch.is_autocast_cpu_enabled())
             out_type = out_type if out_type is not None else run_as_type
             output = output_method = None
@@ -92,36 +101,61 @@ class TestAutocastCPU(TestCase):
             return op_with_args[0], op_with_args[1], op_with_args[2]

     def test_autocast_torch_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)

     def test_autocast_methods_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)

-    def test_autocast_torch_bf16(self):
-        for op_with_args in self.autocast_lists.torch_bf16:
+    def test_autocast_torch_16(self):
+        for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)

-    def test_autocast_nn_bf16(self):
-        for op_with_args in self.autocast_lists.nn_bf16:
+    def test_autocast_nn_16(self):
+        for op_with_args in self.autocast_lists.nn_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float16,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )

     def test_autocast_torch_fp32(self):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)

     def test_autocast_nn_fp32(self):
         for op_with_args in self.autocast_lists.nn_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )

     def test_autocast_torch_need_autocast_promote(self):
-        for op, args in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args, torch.float32)
+        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32)
+            self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)

     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_autocast_rnn(self):