Update on "[HOP][print][dynamo]Add dynamo for hop print"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
Xiao Fu
2025-11-13 14:57:42 -08:00
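For context, a minimal sketch of the hop print pattern this PR's tests exercise, assembled from the test code in the diff below (the function body and shapes are illustrative; only `torch._higher_order_ops.print` and the `torch.compile` call pattern come from the diff itself):

import torch

def f(x):
    x1 = x * 2
    # Hop print takes a format string plus keyword tensor arguments and is
    # traceable by dynamo, so it should not force a graph break.
    torch._higher_order_ops.print("moo {x}", x=x1)
    return x1

# fullgraph=True raises if the print still caused a graph break.
opt_f = torch.compile(backend="eager", fullgraph=True)(f)
out = opt_f(torch.ones(3, 3))  # per the test assertions, prints the formatted tensor at run time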


@@ -4,7 +4,6 @@ import unittest
 from unittest.mock import patch
 
 import torch
-from torch._dynamo.testing import same
 from torch._functorch.aot_autograd import aot_export_module
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import run_tests, TestCase
@@ -203,18 +202,19 @@ class TestHopPrintInDynamo(TestCase):
             return (x1, x3)
 
         x = torch.ones(3, 3)
-        # Eager backend for dynamo tracing testing
-        opt_f = torch.compile(backend="eager", fullgraph=True)(f)
-        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
-            opt_out = opt_f(x)
-            printed_output = mock_stdout.getvalue().strip()
-            orig_out = f(x)
+        # Eager and aot_eager backend for dynamo tracing testing
+        for be in ["eager", "aot_eager"]:
+            opt_f = torch.compile(backend=be, fullgraph=True)(f)
+            with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
+                opt_out = opt_f(x)
+                printed_output = mock_stdout.getvalue().strip()
+                orig_out = f(x)
 
-        self.assertEqual(
-            printed_output,
-            f"moo {torch.ones(3, 3) * 2}\nmoo {torch.ones(3, 3) * 2 * torch.ones(3, 3) * 2}",
-        )
-        self.assertTrue(same(orig_out, opt_out))
+            self.assertEqual(
+                printed_output,
+                f"moo {torch.ones(3, 3) * 2}\nmoo {torch.ones(3, 3) * 2 * torch.ones(3, 3) * 2}",
+            )
+            self.assertEqual(orig_out, opt_out)
 
     def test_constant_mutation(self):
         def f(x):
@@ -228,37 +228,15 @@ class TestHopPrintInDynamo(TestCase):
             return res
 
         inputs = (torch.tensor([1]),)
-        opt_f = torch.compile(backend="eager", fullgraph=True)(f)
-        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
-            opt_out = opt_f(*inputs)
-            printed_output = mock_stdout.getvalue().strip()
-            orig_out = f(*inputs)
+        for be in ["eager", "aot_eager"]:
+            opt_f = torch.compile(backend=be, fullgraph=True)(f)
+            with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
+                opt_out = opt_f(*inputs)
+                printed_output = mock_stdout.getvalue().strip()
+                orig_out = f(*inputs)
 
-        self.assertEqual(printed_output, "moo tensor([2])\nmoo tensor([1])")
-        self.assertTrue(same(orig_out, opt_out))
-
-    def test_print_full_graph(self):
-        def fn(a, b):
-            torch._higher_order_ops.print("print hop {x} {y}", x=a, y=b)
-            return torch.sin(a, out=b)
-
-        inp = [torch.randn(3, 3), torch.ones(3, 3)]
-        ref_out = fn(*inp)
-        # Validate the hop print can reduce the graph break in dynamo tracing
-        out = torch.compile(fn, fullgraph=True)(*inp)
-        self.assertEqual(ref_out, out)
-
-        # aot_eager backend for dynamo tracing testing with hop functionalization impl
-        aote_f = torch.compile(backend="aot_eager")(fn)
-        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
-            opt_out = aote_f(*inp)
-            printed_output = mock_stdout.getvalue().strip()
-        self.assertTrue(same(ref_out, opt_out))
-
-        self.assertEqual(
-            printed_output,
-            f"print hop {inp[0]} {inp[1]}",
-        )
+            self.assertEqual(printed_output, "moo tensor([2])\nmoo tensor([1])")
+            self.assertEqual(orig_out, opt_out)
 
 if __name__ == "__main__":
     run_tests()
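A standalone sketch of the updated test pattern introduced by this commit: run the same function under each backend, then compare both the captured stdout and the return value against eager. The scaffolding below is assumed; the loop, stdout capture, and assertions mirror the diff above:

import io
from unittest.mock import patch

import torch

def f(x):
    x = x * 2
    torch._higher_order_ops.print("moo {x}", x=x)
    return x

x = torch.ones(3, 3)
for be in ["eager", "aot_eager"]:
    opt_f = torch.compile(backend=be, fullgraph=True)(f)
    with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
        opt_out = opt_f(x)
        printed = mock_stdout.getvalue().strip()  # only the compiled run so far
        orig_out = f(x)  # eager reference; its print lands in the mock and is discarded
    assert printed == f"moo {torch.ones(3, 3) * 2}"
    assert torch.equal(opt_out, orig_out)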