Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Add Half support for CPU autocast on eager mode (#112484)
Add Half (float16) support for CPU autocast in eager mode, since common operators already have Half support on CPU. See https://github.com/pytorch/pytorch/issues/96093.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112484
Approved by: https://github.com/leslie-fang-intel, https://github.com/ezyang
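With this change, torch.float16 joins torch.bfloat16 as an accepted autocast dtype on CPU in eager mode. Below is a minimal usage sketch of what this enables; the module and shapes are illustrative, not taken from this commit:

import torch

model = torch.nn.Linear(8, 8)  # plain float32 module
x = torch.randn(4, 8)

# Eager-mode CPU autocast with Half; previously only bfloat16 was accepted here.
with torch.autocast(device_type="cpu", dtype=torch.float16):
    y = model(x)

print(y.dtype)  # torch.float16: linear runs in the lower-precision dtype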
@@ -17,7 +17,16 @@ class TestAutocastCPU(TestCase):
         del self.autocast_lists
         super().tearDown()
 
-    def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
+    def _run_autocast_outofplace(
+        self,
+        op,
+        args,
+        run_as_type,
+        out_type=None,
+        module=torch,
+        add_kwargs=None,
+        amp_dtype=torch.bfloat16,
+    ):
         # helper to cast args
         def cast(val, to_type):
             if isinstance(val, torch.Tensor):
@@ -31,7 +40,7 @@ class TestAutocastCPU(TestCase):
             add_kwargs = {}
 
         self.assertFalse(torch.is_autocast_cpu_enabled())
-        with torch.cpu.amp.autocast():
+        with torch.cpu.amp.autocast(dtype=amp_dtype):
             self.assertTrue(torch.is_autocast_cpu_enabled())
             out_type = out_type if out_type is not None else run_as_type
             output = output_method = None
@@ -92,36 +101,61 @@ class TestAutocastCPU(TestCase):
         return op_with_args[0], op_with_args[1], op_with_args[2]
 
     def test_autocast_torch_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)
 
     def test_autocast_methods_expect_builtin_promote(self):
-        for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
-            self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
+        for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
+            self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)
 
-    def test_autocast_torch_bf16(self):
-        for op_with_args in self.autocast_lists.torch_bf16:
+    def test_autocast_torch_16(self):
+        for op_with_args in self.autocast_lists.torch_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
-    def test_autocast_nn_bf16(self):
-        for op_with_args in self.autocast_lists.nn_bf16:
+    def test_autocast_nn_16(self):
+        for op_with_args in self.autocast_lists.nn_16:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float16,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_torch_fp32(self):
         for op_with_args in self.autocast_lists.torch_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
             self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)
 
     def test_autocast_nn_fp32(self):
         for op_with_args in self.autocast_lists.nn_fp32:
             op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
-            self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
+            self._run_autocast_outofplace(
+                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
+            )
+            self._run_autocast_outofplace(
+                op,
+                args,
+                torch.float32,
+                module=torch._C._nn,
+                add_kwargs=maybe_kwargs,
+                amp_dtype=torch.float16,
+            )
 
     def test_autocast_torch_need_autocast_promote(self):
-        for op, args in self.autocast_lists.torch_need_autocast_promote:
-            self._run_autocast_outofplace(op, args, torch.float32)
+        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
+            self._run_autocast_outofplace(op, args1, torch.float32)
+            self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)
 
     @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
     def test_autocast_rnn(self):
@@ -257,11 +257,12 @@ class autocast:
         self._cache_enabled = cache_enabled
 
         if self.device == "cpu":
-            supported_dtype = [torch.bfloat16]
+            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype and enabled:
                 error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "CPU Autocast only supports dtype of "
                 error_message += (
-                    "CPU Autocast only supports dtype of torch.bfloat16 currently."
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                 )
                 warnings.warn(error_message)
                 enabled = False
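As the hunk above shows, requesting an unsupported dtype for CPU autocast does not raise; autocast warns and disables itself for the region. A small illustrative sketch of that behavior (the dtype choice and shapes below are assumptions for demonstration):

import torch

# torch.float64 is not in the supported CPU list (torch.bfloat16, torch.float16),
# so a warning is emitted and the region effectively runs with autocast off.
with torch.autocast(device_type="cpu", dtype=torch.float64):
    out = torch.mm(torch.randn(2, 2), torch.randn(2, 2))

print(out.dtype)  # torch.float32 -- no autocasting was applied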
@@ -244,6 +244,9 @@ class AutocastCPUTestLists:
         mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
         mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
 
+        pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+        pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
+
         dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
 
         dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
@@ -275,29 +278,30 @@ class AutocastCPUTestLists:
         # Some ops implement built-in type promotion. These don't need autocasting,
         # but autocasting relies on their promotion, so we include tests to double-check.
         self.torch_expect_builtin_promote = [
-            ("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("le", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("add", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("div", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32),
+            ("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
         ]
+
         self.methods_expect_builtin_promote = [
-            ("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
-            ("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
-            ("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
+            ("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
+            ("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
+            ("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
         ]
         # The remaining lists organize ops that autocast treats explicitly.
-        self.torch_bf16 = [
+        self.torch_16 = [
             ("conv1d", conv_args_fp32[0]),
             ("conv2d", conv_args_fp32[1]),
             ("conv3d", conv_args_fp32[2]),
@@ -337,7 +341,7 @@ class AutocastCPUTestLists:
             ("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
             ("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
         ]
-        self.nn_bf16 = [
+        self.nn_16 = [
             ("linear", mat0_fp32 + mat1_fp32, {}),
         ]
         self.nn_fp32 = [
@ -358,6 +362,6 @@ class AutocastCPUTestLists:
|
|||||||
("huber_loss", mat0_bf16 + mat1_bf16),
|
("huber_loss", mat0_bf16 + mat1_bf16),
|
||||||
]
|
]
|
||||||
self.torch_need_autocast_promote = [
|
self.torch_need_autocast_promote = [
|
||||||
("cat", (pointwise0_bf16 + pointwise1_fp32,)),
|
("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
|
||||||
("stack", (pointwise0_bf16 + pointwise1_fp32,)),
|
("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
|
||||||
]
|
]
|
||||||
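For reference, a stripped-down sketch of how these widened test-list tuples are consumed, mirroring the test changes above but using a made-up toy list rather than the real AutocastCPUTestLists:

import torch

toy_need_autocast_promote = [
    ("cat",
     ((torch.randn(3, dtype=torch.bfloat16), torch.randn(3)),),
     ((torch.randn(3, dtype=torch.float16), torch.randn(3)),)),
]

for op, args_bf16, args_fp16 in toy_need_autocast_promote:
    for args, amp_dtype in ((args_bf16, torch.bfloat16), (args_fp16, torch.float16)):
        with torch.autocast(device_type="cpu", dtype=amp_dtype):
            out = getattr(torch, op)(*args)
        # promote ops widen mixed-precision inputs to float32 under autocast
        assert out.dtype == torch.float32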