diff --git a/test/higher_order_ops/test_print.py b/test/higher_order_ops/test_print.py index bb08cd6d3583..bbbd2a1317c4 100644 --- a/test/higher_order_ops/test_print.py +++ b/test/higher_order_ops/test_print.py @@ -4,7 +4,6 @@ import unittest from unittest.mock import patch import torch -from torch._dynamo.testing import same from torch._functorch.aot_autograd import aot_export_module from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import run_tests, TestCase @@ -203,18 +202,19 @@ class TestHopPrintInDynamo(TestCase): return (x1, x3) x = torch.ones(3, 3) - # Eager backend for dynamo tracing testing - opt_f = torch.compile(backend="eager", fullgraph=True)(f) - with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: - opt_out = opt_f(x) - printed_output = mock_stdout.getvalue().strip() - orig_out = f(x) + # Eager and aot_eager backend for dynamo tracing testing + for be in ["eager", "aot_eager"]: + opt_f = torch.compile(backend=be, fullgraph=True)(f) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + opt_out = opt_f(x) + printed_output = mock_stdout.getvalue().strip() + orig_out = f(x) - self.assertEqual( - printed_output, - f"moo {torch.ones(3, 3) * 2}\nmoo {torch.ones(3, 3) * 2 * torch.ones(3, 3) * 2}", - ) - self.assertTrue(same(orig_out, opt_out)) + self.assertEqual( + printed_output, + f"moo {torch.ones(3, 3) * 2}\nmoo {torch.ones(3, 3) * 2 * torch.ones(3, 3) * 2}", + ) + self.assertEqual(orig_out, opt_out) def test_constant_mutation(self): def f(x): @@ -228,37 +228,15 @@ class TestHopPrintInDynamo(TestCase): return res inputs = (torch.tensor([1]),) - opt_f = torch.compile(backend="eager", fullgraph=True)(f) - with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: - opt_out = opt_f(*inputs) - printed_output = mock_stdout.getvalue().strip() - orig_out = f(*inputs) + for be in ["eager", "aot_eager"]: + opt_f = torch.compile(backend=be, fullgraph=True)(f) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + opt_out = opt_f(*inputs) + printed_output = mock_stdout.getvalue().strip() + orig_out = f(*inputs) - self.assertEqual(printed_output, "moo tensor([2])\nmoo tensor([1])") - self.assertTrue(same(orig_out, opt_out)) - - def test_print_full_graph(self): - def fn(a, b): - torch._higher_order_ops.print("print hop {x} {y}", x=a, y=b) - return torch.sin(a, out=b) - - inp = [torch.randn(3, 3), torch.ones(3, 3)] - ref_out = fn(*inp) - # Validate the hop print can reduce the graph break in dynamo tracing - out = torch.compile(fn, fullgraph=True)(*inp) - self.assertEqual(ref_out, out) - - # aot_eager backend for dynamo tracing testing with hop functionalization impl - aote_f = torch.compile(backend="aot_eager")(fn) - with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: - opt_out = aote_f(*inp) - printed_output = mock_stdout.getvalue().strip() - - self.assertTrue(same(ref_out, opt_out)) - self.assertEqual( - printed_output, - f"print hop {inp[0]} {inp[1]}", - ) + self.assertEqual(printed_output, "moo tensor([2])\nmoo tensor([1])") + self.assertEqual(orig_out, opt_out) if __name__ == "__main__":