# Owner(s): ["module: unknown"]

import collections
import unittest

import torch
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode

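
# Background sketch (illustrative, not executed): ops on autocast's
# lower-precision list run in the autocast dtype even when fed fp32 inputs:
#
#     with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
#         y = torch.mm(torch.randn(2, 2), torch.randn(2, 2))
#     # y.dtype == torch.bfloat16, assuming torch.mm is on the bf16 op list
#
# The tests below verify this dispatch for the op lists in AutocastCPUTestLists.
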
class TestAutocastCPU(TestCase):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

|
def _run_autocast_outofplace(
|
|
self,
|
|
op,
|
|
args,
|
|
run_as_type,
|
|
out_type=None,
|
|
module=torch,
|
|
add_kwargs=None,
|
|
amp_dtype=torch.bfloat16,
|
|
):
|
|
# helper to cast args
|
|
def cast(val, to_type):
|
|
if isinstance(val, torch.Tensor):
|
|
return val.to(to_type) if val.is_floating_point() else val
|
|
elif isinstance(val, collections.abc.Iterable):
|
|
return type(val)(cast(v, to_type) for v in val)
|
|
else:
|
|
return val
|
|
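
        # For example (illustrative): cast((fp32_tensor, [int_tensor]), torch.bfloat16)
        # returns (bf16_tensor, [int_tensor]): the structure is preserved and
        # only floating-point tensors are converted.
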
        if add_kwargs is None:
            add_kwargs = {}

        self.assertFalse(torch.is_autocast_cpu_enabled())
        with torch.cpu.amp.autocast(dtype=amp_dtype):
            self.assertTrue(torch.is_autocast_cpu_enabled())
            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
                    )
            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        f"autocast for Tensor.{op} produced {output_method.dtype}, should produce {out_type}",
                    )

            self.assertTrue(
                (output is not None) or (output_method is not None),
                f"{op} not found as an attribute on either Tensor or the requested module {module}",
            )

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second
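
            # Illustrative: for lstm_cell's (h, c) tuple, compare reduces to
            # torch.equal on each element; for torch.equal's bool result it
            # falls back to ==.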

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison, f"torch.{op} result did not match Tensor.{op} result"
                )

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.cpu.amp.autocast(enabled=False):
                self.assertFalse(torch.is_autocast_cpu_enabled())

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(
                        *cast(args, run_as_type), **add_kwargs
                    )
                else:
                    control = getattr(args[0].to(run_as_type), op)(
                        *cast(args[1:], run_as_type), **add_kwargs
                    )
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            # Back in the outer autocast region after the nested
            # enabled=False block exits, then outside it entirely:
            self.assertTrue(torch.is_autocast_cpu_enabled())
        self.assertFalse(torch.is_autocast_cpu_enabled())
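
    # Usage sketch (illustrative, not executed): a typical invocation for a
    # bf16-listed op looks like
    #
    #     self._run_autocast_outofplace("mm", (a, b), torch.bfloat16)
    #
    # where `op` must exist on `module` (torch by default) and/or on Tensor.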

    def args_maybe_kwargs(self, op_with_args):
        if len(op_with_args) == 2:
            return op_with_args[0], op_with_args[1], {}
        else:
            return op_with_args[0], op_with_args[1], op_with_args[2]
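
    # Illustrative: ("mm", (a, b)) unpacks to ("mm", (a, b), {}), while a
    # 3-tuple like ("addmm", (a, b, c), {"beta": 0.5}) already carries kwargs.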

    @skipIfTorchDynamo()
    def test_autocast_torch_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
            self._run_autocast_outofplace(
                op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16
            )
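
    # Note (illustrative): "builtin promote" ops follow ordinary type
    # promotion instead of an autocast cast, e.g. a comparison such as
    # torch.eq(bf16_tensor, fp32_tensor) returns a bool tensor either way.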

    @skipIfTorchDynamo()
    def test_autocast_methods_expect_builtin_promote(self):
        for (
            op,
            args1,
            args2,
            out_type,
        ) in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(
                op, args1, torch.float32, module=None, out_type=out_type
            )
            self._run_autocast_outofplace(
                op,
                args2,
                torch.float32,
                module=None,
                out_type=out_type,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_16(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.bfloat16, add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_16(self):
        for op_with_args in self.autocast_lists.nn_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_nn_fp32(self):
        for op_with_args in self.autocast_lists.nn_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_need_autocast_promote(self):
        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args1, torch.float32)
            self._run_autocast_outofplace(
                op, args2, torch.float32, amp_dtype=torch.float16
            )

    @unittest.skipIf(IS_WINDOWS, "Limited support for bf16 path")
    def test_autocast_rnn(self):
        if (
            torch.backends.mkldnn.is_available()
            and torch.ops.mkldnn._is_mkldnn_bf16_supported()
        ):
            x = torch.randn(1, 2, 1)
            hx = torch.randn(2, 2, 1)
            cx = torch.randn(2, 2, 1)

            m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)

            # Raises ValueError when autocast is not enabled
            with self.assertRaisesRegex(ValueError, "input must have the type"):
                m(x, (hx, cx))

            # The same call should succeed with autocast enabled
            with torch.cpu.amp.autocast():
                m(x, (hx, cx))
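
    # Note (illustrative): without autocast, the bf16 LSTM rejects fp32
    # inputs; under torch.cpu.amp.autocast() the inputs are cast to bf16
    # automatically, so the same call succeeds.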

    def test_autocast_disabled_with_fp32_dtype(self):
        with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
            _ = torch.ones(10)

    def test_generic_autocast(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            with torch.amp.autocast(device_type="cpu"):
                generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            with torch.cpu.amp.autocast():
                cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
            self.assertEqual(generic_autocast_output, cpu_autocast_output)
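
    # Illustrative: torch.amp.autocast(device_type="cpu") and the legacy
    # torch.cpu.amp.autocast() drive the same autocast state, so each listed
    # op should produce identical outputs under either context manager.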

    def test_cpu_autocast_deprecated_warning(self):
        with self.assertWarnsRegex(
            DeprecationWarning,
            r"torch.cpu.amp.autocast\(args...\) is deprecated. Please use torch.amp.autocast\('cpu', args...\) instead.",
        ):
            with torch.cpu.amp.autocast():
                _ = torch.ones(10)


class CustomLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        x, w_t = ctx.saved_tensors
        with torch.autocast(device_type="cuda"):
            dL_dX = torch.matmul(grad_output, w_t)
            dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
        return dL_dX, dL_dW
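
# Sketch of what CustomLinear exercises (not executed here): backward
# re-enters torch.autocast, so its matmuls also run under autocast. With a
# global autocast cache, the fp16 copy of w_t made in forward can be reused
# in backward instead of being cast a second time:
#
#     with torch.autocast(device_type="cuda"):
#         out = CustomLinear.apply(x, w_t)  # casts w_t to fp16 once
#     out.sum().backward()                  # may reuse the cached fp16 w_t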


class WeightDTypeCastCounterMode(TorchDispatchMode):
    def __init__(self, weight):
        super().__init__()
        self.dtype_cast_counter = 0
        self.weight = weight

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if (
            func is torch.ops.aten._to_copy.default
            and args[0] is self.weight
            and kwargs["dtype"] is torch.float16
        ):
            self.dtype_cast_counter += 1
        return func(*args, **kwargs)

    def __enter__(self):
        self.old_clear_cache = torch.clear_autocast_cache
        torch.clear_autocast_cache = lambda: None
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.clear_autocast_cache = self.old_clear_cache
        return super().__exit__(exc_type, exc_val, exc_tb)
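
# Sketch (illustrative): the mode above counts how many times autocast
# actually converts `weight` to fp16 (via aten._to_copy), and stubs out
# torch.clear_autocast_cache so the cast cache survives for inspection:
#
#     with WeightDTypeCastCounterMode(weight) as mode:
#         ...  # autocasted forward/backward
#     print(mode.dtype_cast_counter)  # number of fp16 casts of `weight`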


@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
        """

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        with WeightDTypeCastCounterMode(weight) as mode:
            with torch.autocast(device_type="cuda"):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()

        self.assertEqual(mode.dtype_cast_counter, 1)

    def test_cache_disabled(self):
        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        try:
            torch._C._set_cached_tensors_enabled(True)
            torch._C._add_cached_tensor(weight)

            with WeightDTypeCastCounterMode(weight) as mode:
                with torch.autocast(device_type="cuda"):
                    output = CustomLinear.apply(data, weight)
                    s = output.sum()
                s.backward()

            # we should not have cached the conversion of the weight
            self.assertEqual(mode.dtype_cast_counter, 2)

        finally:
            torch._C._set_cached_tensors_enabled(False)
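
    # Note (illustrative): marking the weight as a "cached tensor" via the
    # private torch._C._add_cached_tensor opts it out of the autocast weight
    # cache, so forward and backward each perform their own fp32 -> fp16
    # cast, hence the expected count of 2.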


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_gpu_dtype()
        cpu_fast_dtype = torch.get_autocast_cpu_dtype()
        self.assertEqual(gpu_fast_dtype, torch.half)
        self.assertEqual(cpu_fast_dtype, torch.bfloat16)

    def test_invalid_device(self):
        dev = "not a real device"
        msg = f"Invalid device string: '{dev}'"
        with self.assertRaisesRegex(RuntimeError, msg):
            with torch.autocast(device_type=dev):
                _ = torch.tensor(1)
        with self.assertRaisesRegex(RuntimeError, msg):
            assert torch.amp.is_autocast_available(device_type=dev)

    def test_non_string_device(self):
        """Test that `autocast` raises a ValueError when given a `torch.device` object for `device_type` instead of a string"""
        dev = torch.device("cpu")
        msg = f"Expected `device_type` of type `str`, got: `{type(dev)}`"
        with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
            torch.autocast(device_type=dev)


if __name__ == "__main__":
    run_tests()