Add Half support for CPU autocast on eager mode (#112484)

Add Half support for CPU autocast in eager mode, since common operators have Half support on CPU.
See https://github.com/pytorch/pytorch/issues/96093.
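A minimal usage sketch of what this enables, assuming a build that includes this change (the shapes and the choice of torch.mm are illustrative only): eager-mode CPU autocast can now target torch.float16 in addition to torch.bfloat16.

import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)

# With this change, dtype=torch.float16 is accepted for CPU autocast in eager mode.
with torch.autocast(device_type="cpu", dtype=torch.float16):
    out = torch.mm(a, b)  # mm falls under the CPU autocast lower-precision policy

print(out.dtype)  # expected: torch.float16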

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112484
Approved by: https://github.com/leslie-fang-intel, https://github.com/ezyang
Author: CaoE
Committed by: PyTorch MergeBot
Date: 2023-11-21 20:08:28 +00:00
Parent: 4e4a6ad6ec
Commit: c47d2b8035

3 changed files with 77 additions and 38 deletions


@@ -17,7 +17,16 @@ class TestAutocastCPU(TestCase):
         del self.autocast_lists
         super().tearDown()
 
-    def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
         # helper to cast args
         def cast(val, to_type):
             if isinstance(val, torch.Tensor):
@@ -31,7 +40,7 @@ class TestAutocastCPU(TestCase):
             add_kwargs = {}
 
         self.assertFalse(torch.is_autocast_cpu_enabled())
-        with torch.cpu.amp.autocast():
+        with torch.cpu.amp.autocast(dtype=amp_dtype):
             self.assertTrue(torch.is_autocast_cpu_enabled())
             out_type = out_type if out_type is not None else run_as_type
             output = output_method = None
@@ -92,36 +101,61 @@ class TestAutocastCPU(TestCase):
         return op_with_args[0], op_with_args[1], op_with_args[2]
 
     def test_autocast_torch_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)
 
     def test_autocast_methods_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)
 
-    def test_autocast_torch_bf16(self):
-        for op_with_args in self.autocast_lists.torch_bf16:
+    def test_autocast_torch_16(self):
+        for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
-    def test_autocast_nn_bf16(self):
-        for op_with_args in self.autocast_lists.nn_bf16:
+    def test_autocast_nn_16(self):
+        for op_with_args in self.autocast_lists.nn_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float16,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_torch_fp32(self):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
     def test_autocast_nn_fp32(self):
         for op_with_args in self.autocast_lists.nn_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_torch_need_autocast_promote(self):
-        for op, args in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args, torch.float32)
+        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32)
+            self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_autocast_rnn(self):
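The updated tests above thread a new amp_dtype argument through to torch.cpu.amp.autocast so every op list can be exercised under both bfloat16 and float16. A condensed, hypothetical sketch of that pattern (check_out_dtype is not part of the test suite; it assumes the op has Half support on CPU):

import torch

def check_out_dtype(op, args, expected_dtype, amp_dtype=torch.bfloat16):
    # Run `op` under CPU autocast with the given amp_dtype and verify the
    # dtype of the result, mirroring what _run_autocast_outofplace asserts.
    with torch.cpu.amp.autocast(dtype=amp_dtype):
        out = getattr(torch, op)(*args)
    assert out.dtype == expected_dtype, (op, out.dtype, expected_dtype)

args = (torch.randn(4, 4), torch.randn(4, 4))
check_out_dtype("mm", args, torch.bfloat16)                          # existing bf16 path
check_out_dtype("mm", args, torch.float16, amp_dtype=torch.float16)  # new fp16 path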


@@ -257,11 +257,12 @@ class autocast:
         self._cache_enabled = cache_enabled
 
         if self.device == "cpu":
-            supported_dtype = [torch.bfloat16]
+            supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype and enabled:
                 error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "CPU Autocast only supports dtype of "
                 error_message += (
-                    "CPU Autocast only supports dtype of torch.bfloat16 currently."
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                 )
                 warnings.warn(error_message)
                 enabled = False
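A hedged sketch of the behavior this check governs, per the code above (torch.float64 is just an example of an unsupported dtype): a supported dtype keeps autocast enabled, while an unsupported one emits the warning and disables autocast rather than raising.

import warnings
import torch

# Supported dtype: autocast stays enabled.
with torch.autocast(device_type="cpu", dtype=torch.float16):
    print(torch.is_autocast_cpu_enabled())  # expected: True

# Unsupported dtype: the warning above fires and autocast is disabled.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="cpu", dtype=torch.float64):
        print(torch.is_autocast_cpu_enabled())  # expected: False

print(any("not supported" in str(w.message) for w in caught))  # expected: True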


@@ -244,6 +244,9 @@ class AutocastCPUTestLists:
         mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
         mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
         dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
         dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
@@ -275,29 +278,30 @@ class AutocastCPUTestLists:
         # Some ops implement built-in type promotion. These don't need autocasting,
         # but autocasting relies on their promotion, so we include tests to double-check.
         self.torch_expect_builtin_promote = [
-            ("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("le", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("add", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("div", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32),
+            ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
         ]
         self.methods_expect_builtin_promote = [
-            ("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
+            ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
         ]
         # The remaining lists organize ops that autocast treats explicitly.
-        self.torch_bf16 = [
+        self.torch_16 = [
             ("conv1d", conv_args_fp32[0]),
             ("conv2d", conv_args_fp32[1]),
             ("conv3d", conv_args_fp32[2]),
@@ -337,7 +341,7 @@ class AutocastCPUTestLists:
             ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
             ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
         ]
-        self.nn_bf16 = [
+        self.nn_16 = [
             ("linear", mat0_fp32 + mat1_fp32, {}),
         ]
         self.nn_fp32 = [
@@ -358,6 +362,6 @@ class AutocastCPUTestLists:
             ("huber_loss", mat0_bf16 + mat1_bf16),
         ]
         self.torch_need_autocast_promote = [
-            ("cat", (pointwise0_bf16 + pointwise1_fp32,)),
-            ("stack", (pointwise0_bf16 + pointwise1_fp32,)),
+            ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
+            ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
         ]
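A hypothetical illustration of the torch_need_autocast_promote entries above: under CPU autocast, cat and stack get autocast's promote treatment, so mixed float16/float32 inputs are widened and the result comes back as float32.

import torch

x_fp16 = torch.randn(4, dtype=torch.float16)
x_fp32 = torch.randn(4, dtype=torch.float32)

with torch.autocast(device_type="cpu", dtype=torch.float16):
    # cat normally requires matching dtypes; autocast's promote policy casts
    # both inputs to the widest floating type (float32) before running the op.
    out = torch.cat((x_fp16, x_fp32))

print(out.dtype)  # expected: torch.float32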