mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Half support for CPU autocast on eager mode (#112484)
Add Half support for CPU autocast on eager mode since common operators have Half support on CPU. https://github.com/pytorch/pytorch/issues/96093. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112484 Approved by: https://github.com/leslie-fang-intel, https://github.com/ezyang
This commit is contained in:
@ -17,7 +17,16 @@ class TestAutocastCPU(TestCase):
|
||||
del self.autocast_lists
|
||||
super().tearDown()
|
||||
|
||||
def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
|
||||
def _run_autocast_outofplace(
|
||||
self,
|
||||
op,
|
||||
args,
|
||||
run_as_type,
|
||||
out_type=None,
|
||||
module=torch,
|
||||
add_kwargs=None,
|
||||
amp_dtype=torch.bfloat16,
|
||||
):
|
||||
# helper to cast args
|
||||
def cast(val, to_type):
|
||||
if isinstance(val, torch.Tensor):
|
||||
@ -31,7 +40,7 @@ class TestAutocastCPU(TestCase):
|
||||
add_kwargs = {}
|
||||
|
||||
self.assertFalse(torch.is_autocast_cpu_enabled())
|
||||
with torch.cpu.amp.autocast():
|
||||
with torch.cpu.amp.autocast(dtype=amp_dtype):
|
||||
self.assertTrue(torch.is_autocast_cpu_enabled())
|
||||
out_type = out_type if out_type is not None else run_as_type
|
||||
output = output_method = None
|
||||
@ -92,36 +101,61 @@ class TestAutocastCPU(TestCase):
|
||||
return op_with_args[0], op_with_args[1], op_with_args[2]
|
||||
|
||||
def test_autocast_torch_expect_builtin_promote(self):
|
||||
for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
|
||||
self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
|
||||
for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
|
||||
self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
|
||||
self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)
|
||||
|
||||
def test_autocast_methods_expect_builtin_promote(self):
|
||||
for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
|
||||
self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
|
||||
for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
|
||||
self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
|
||||
self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)
|
||||
|
||||
def test_autocast_torch_bf16(self):
|
||||
for op_with_args in self.autocast_lists.torch_bf16:
|
||||
def test_autocast_torch_16(self):
|
||||
for op_with_args in self.autocast_lists.torch_16:
|
||||
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
|
||||
self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
|
||||
self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
|
||||
|
||||
def test_autocast_nn_bf16(self):
|
||||
for op_with_args in self.autocast_lists.nn_bf16:
|
||||
def test_autocast_nn_16(self):
|
||||
for op_with_args in self.autocast_lists.nn_16:
|
||||
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
|
||||
self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
|
||||
self._run_autocast_outofplace(
|
||||
op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
|
||||
)
|
||||
self._run_autocast_outofplace(
|
||||
op,
|
||||
args,
|
||||
torch.float16,
|
||||
module=torch._C._nn,
|
||||
add_kwargs=maybe_kwargs,
|
||||
amp_dtype=torch.float16,
|
||||
)
|
||||
|
||||
def test_autocast_torch_fp32(self):
|
||||
for op_with_args in self.autocast_lists.torch_fp32:
|
||||
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
|
||||
self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
|
||||
self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
|
||||
|
||||
def test_autocast_nn_fp32(self):
|
||||
for op_with_args in self.autocast_lists.nn_fp32:
|
||||
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
|
||||
self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
|
||||
self._run_autocast_outofplace(
|
||||
op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
|
||||
)
|
||||
self._run_autocast_outofplace(
|
||||
op,
|
||||
args,
|
||||
torch.float32,
|
||||
module=torch._C._nn,
|
||||
add_kwargs=maybe_kwargs,
|
||||
amp_dtype=torch.float16,
|
||||
)
|
||||
|
||||
def test_autocast_torch_need_autocast_promote(self):
|
||||
for op, args in self.autocast_lists.torch_need_autocast_promote:
|
||||
self._run_autocast_outofplace(op, args, torch.float32)
|
||||
for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
|
||||
self._run_autocast_outofplace(op, args1, torch.float32)
|
||||
self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
|
||||
def test_autocast_rnn(self):
|
||||
|
Reference in New Issue
Block a user