Moving _run_autocast_outofplace to a base class named TestAutocast to reduce redundancy (#134460)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134460
Approved by: https://github.com/EikanWang, https://github.com/ezyang
FFFrog
2024-09-04 07:53:09 +00:00
committed by PyTorch MergeBot
parent c2ff9fe042
commit 80a6d60829
4 changed files with 697 additions and 750 deletions
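Only the CPU test file's portion of the diff is reproduced below; the shared TestAutocast base class that the helper moves into ships in torch/testing/_internal/autocast_test_lists.py and is not shown in this excerpt. As a rough sketch, inferred only from the new import and from the call sites that now pass device="cpu" (the exact signature and body may differ), the relocated helper is presumably the old _run_autocast_outofplace with every hard-coded "cpu" replaced by a device argument:

# Hedged sketch only - the real class ships in
# torch/testing/_internal/autocast_test_lists.py and may differ in detail.
import collections

import torch
from torch.testing._internal.common_utils import TestCase


class TestAutocast(TestCase):
    # Same helper as the removed TestAutocastCPU method, but parameterized
    # over `device` so device-specific test classes can share it.  Trimmed
    # here to the module.* code path for brevity.
    def _run_autocast_outofplace(
        self,
        op,
        args,
        run_as_type,
        device,
        out_type=None,
        module=torch,
        add_kwargs=None,
        amp_dtype=torch.bfloat16,
    ):
        def cast(val, to_type):
            # Cast floating-point tensors, recursing through iterables.
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            if isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            return val

        add_kwargs = add_kwargs or {}
        out_type = out_type if out_type is not None else run_as_type

        self.assertFalse(torch.is_autocast_enabled(device_type=device))
        with torch.amp.autocast(device_type=device, dtype=amp_dtype):
            self.assertTrue(torch.is_autocast_enabled(device_type=device))
            output = getattr(module, op)(*args, **add_kwargs)
            if isinstance(output, torch.Tensor):
                self.assertEqual(output.dtype, out_type)
            # Control: run the op on explicitly cast inputs with autocast off;
            # the two results are expected to match bitwise.
            with torch.amp.autocast(device_type=device, enabled=False):
                control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
            if isinstance(output, torch.Tensor) and isinstance(control, torch.Tensor):
                self.assertTrue(torch.equal(output, control))
        self.assertFalse(torch.is_autocast_enabled(device_type=device))

A device-specific class then only needs to pass its device string, as the updated TestAutocastCPU calls in the diff below do.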


@@ -1,10 +1,12 @@
# Owner(s): ["module: unknown"]
import collections
import unittest
import torch
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.testing._internal.autocast_test_lists import (
AutocastCPUTestLists,
TestAutocast,
)
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
@@ -14,7 +16,7 @@ from torch.testing._internal.common_utils import (
from torch.utils._python_dispatch import TorchDispatchMode
class TestAutocastCPU(TestCase):
class TestAutocastCPU(TestAutocast):
def setUp(self):
super().setUp()
self.autocast_lists = AutocastCPUTestLists(torch.device("cpu"))
@@ -23,100 +25,6 @@ class TestAutocastCPU(TestCase):
del self.autocast_lists
super().tearDown()
def _run_autocast_outofplace(
self,
op,
args,
run_as_type,
out_type=None,
module=torch,
add_kwargs=None,
amp_dtype=torch.bfloat16,
):
# helper to cast args
def cast(val, to_type):
if isinstance(val, torch.Tensor):
return val.to(to_type) if val.is_floating_point() else val
elif isinstance(val, collections.abc.Iterable):
return type(val)(cast(v, to_type) for v in val)
else:
return val
if add_kwargs is None:
add_kwargs = {}
self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
with torch.amp.autocast(device_type="cpu", dtype=amp_dtype):
self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
out_type = out_type if out_type is not None else run_as_type
output = output_method = None
# Try module.* variant, if requested:
if module is not None and hasattr(module, op):
output = getattr(module, op)(*args, **add_kwargs)
if isinstance(output, torch.Tensor):
self.assertTrue(
out_type == output.dtype,
f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
)
# Try Tensor.* variant:
if hasattr(torch.Tensor, op):
output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
if isinstance(output_method, torch.Tensor):
self.assertTrue(
out_type == output_method.dtype,
f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
)
self.assertTrue(
(output is not None) or (output_method is not None),
f"{op} not found as an attribute on either Tensor or the requested module {module}",
)
# Accounts for ops that return Tensors, iterables, and other non-Tensors.
# For example, lstm_cell returns a tuple and equal returns bool.
def compare(first, second):
if isinstance(first, torch.Tensor):
return torch.equal(first, second)
elif isinstance(first, collections.abc.Iterable):
return all(compare(f, s) for f, s in zip(first, second))
else:
return first == second
# If both torch.* and Tensor.* variants were found, check outputs are identical
if (output is not None) and (output_method is not None):
self.assertTrue(type(output) == type(output_method))
comparison = compare(output, output_method)
self.assertTrue(
comparison, f"torch.{op} result did not match Tensor.{op} result"
)
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
# as the C++-side autocasting, and should be bitwise accurate.
output_to_compare = output if output is not None else output_method
with torch.amp.autocast(device_type="cpu", enabled=False):
self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
if module is not None and hasattr(module, op):
control = getattr(module, op)(
*cast(args, run_as_type), **add_kwargs
)
else:
control = getattr(args[0].to(run_as_type), op)(
*cast(args[1:], run_as_type), **add_kwargs
)
self.assertTrue(type(output_to_compare) == type(control))
comparison = compare(output_to_compare, control)
self.assertTrue(comparison, f"torch.{op} result did not match control")
self.assertTrue(torch.is_autocast_enabled(device_type="cpu"))
self.assertFalse(torch.is_autocast_enabled(device_type="cpu"))
def args_maybe_kwargs(self, op_with_args):
if len(op_with_args) == 2:
return op_with_args[0], op_with_args[1], {}
else:
return op_with_args[0], op_with_args[1], op_with_args[2]
@skipIfTorchDynamo()
def test_autocast_torch_expect_builtin_promote(self):
for (
@@ -125,9 +33,16 @@ class TestAutocastCPU(TestCase):
args2,
out_type,
) in self.autocast_lists.torch_expect_builtin_promote:
self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
self._run_autocast_outofplace(
op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16
op, args1, torch.float32, device="cpu", out_type=out_type
)
self._run_autocast_outofplace(
op,
args2,
torch.float32,
device="cpu",
out_type=out_type,
amp_dtype=torch.float16,
)
@skipIfTorchDynamo()
@@ -139,12 +54,13 @@ class TestAutocastCPU(TestCase):
out_type,
) in self.autocast_lists.methods_expect_builtin_promote:
self._run_autocast_outofplace(
op, args1, torch.float32, module=None, out_type=out_type
op, args1, torch.float32, device="cpu", module=None, out_type=out_type
)
self._run_autocast_outofplace(
op,
args2,
torch.float32,
device="cpu",
module=None,
out_type=out_type,
amp_dtype=torch.float16,
@@ -155,12 +71,13 @@ class TestAutocastCPU(TestCase):
for op_with_args in self.autocast_lists.torch_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op, args, torch.bfloat16, add_kwargs=maybe_kwargs
op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float16,
device="cpu",
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@@ -170,12 +87,18 @@ class TestAutocastCPU(TestCase):
for op_with_args in self.autocast_lists.nn_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
op,
args,
torch.bfloat16,
device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
)
self._run_autocast_outofplace(
op,
args,
torch.float16,
device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
@@ -186,12 +109,13 @@ class TestAutocastCPU(TestCase):
for op_with_args in self.autocast_lists.torch_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op, args, torch.float32, add_kwargs=maybe_kwargs
op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs
)
self._run_autocast_outofplace(
op,
args,
torch.float32,
device="cpu",
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
)
@@ -201,12 +125,18 @@ class TestAutocastCPU(TestCase):
for op_with_args in self.autocast_lists.nn_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(
op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
op,
args,
torch.float32,
device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
)
self._run_autocast_outofplace(
op,
args,
torch.float32,
device="cpu",
module=torch._C._nn,
add_kwargs=maybe_kwargs,
amp_dtype=torch.float16,
@@ -215,9 +145,9 @@ class TestAutocastCPU(TestCase):
@skipIfTorchDynamo()
def test_autocast_torch_need_autocast_promote(self):
for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args1, torch.float32)
self._run_autocast_outofplace(op, args1, torch.float32, device="cpu")
self._run_autocast_outofplace(
op, args2, torch.float32, amp_dtype=torch.float16
op, args2, torch.float32, device="cpu", amp_dtype=torch.float16
)
@unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
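To illustrate the redundancy the title refers to: with the helper in a shared base class, a device-specific test class only inherits it and passes its device string instead of re-implementing the whole comparison routine. The class below is a hypothetical minimal example, not part of this PR; torch.mm is chosen because it sits on the CPU bfloat16 autocast list:

import torch
from torch.testing._internal.autocast_test_lists import TestAutocast
from torch.testing._internal.common_utils import run_tests


class MyAutocastSmokeTest(TestAutocast):  # hypothetical example class
    def test_mm_autocasts_to_bfloat16(self):
        a, b = torch.randn(8, 8), torch.randn(8, 8)
        # The shared helper runs torch.mm under torch.amp.autocast on CPU and
        # checks that the result comes back in bfloat16.
        self._run_autocast_outofplace("mm", (a, b), torch.bfloat16, device="cpu")


if __name__ == "__main__":
    run_tests()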