Compare commits

...

46 Commits

Author SHA1 Message Date
360fca9683 Merge branch 'gh/fxdawnn/10/head' into HOPrintFunc 2025-11-07 13:48:32 -08:00
74442c584d [HOP][print] Add witheffect for print by adding schema for it to pass
ghstack-source-id: 227bf11498a24d0ccd0dfed80e43307006d67324
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167016
2025-11-07 13:27:26 -08:00
c97873141d Update on "[HOP][print] Add functionalization (make sure ordering) for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-07 10:10:22 -08:00
b13e39d672 [HOP][print] Add witheffect for print
ghstack-source-id: 227bf11498a24d0ccd0dfed80e43307006d67324
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167016
2025-11-07 10:10:22 -08:00
a0fa174c27 Update base for Update on "[HOP][print] Add functionalization (make sure ordering) for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-07 10:10:22 -08:00
69ee86ed16 Update on "[HOP][print] Add functionalization (make sure ordering) for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-07 10:08:51 -08:00
bff0bba0a0 Update base for Update on "[HOP][print] Add functionalization (make sure ordering) for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-07 10:08:51 -08:00
3c4fd0550a Update on "[HOP][print] Add functionalization (make sure ordering) for print"
cc albanD bdhirsh ezyang

[ghstack-poisoned]
2025-11-06 15:28:49 -08:00
7e1a27e9a1 Update base for Update on "[HOP][print] Add functionalization (make sure ordering) for print"
cc albanD bdhirsh ezyang

[ghstack-poisoned]
2025-11-06 15:28:49 -08:00
21b28c13c2 Update on "[HOP][print] Add functionalization (make sure ordering) for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-06 15:11:28 -08:00
136ca8444c Update base for Update on "[HOP][print] Add functionalization (make sure ordering) for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-06 15:11:28 -08:00
1270b30495 Fix test
ghstack-source-id: 1c3371c6cc1e32f86964b7cce8f8a0504592f06a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166920
2025-11-06 15:03:34 -08:00
bfa10fb473 Update on "[HOP][print] Add functionalization for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-05 13:45:04 -08:00
a230cb687e Update base for Update on "[HOP][print] Add functionalization for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-05 13:45:04 -08:00
b0c339e028 Update on "[HOP][print] Add functionalization for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-05 12:34:15 -08:00
76917a5b51 Update base for Update on "[HOP][print] Add functionalization for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-05 12:34:15 -08:00
0e90b43872 Update on "[HOP][print] Add functionalization for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-05 11:51:33 -08:00
dc859bf5fd Update base for Update on "[HOP][print] Add functionalization for print"
cc bdhirsh ezyang

[ghstack-poisoned]
2025-11-05 11:51:33 -08:00
594906792f [HOP][print] Add functionalization for print
[ghstack-poisoned]
2025-11-04 14:00:37 -08:00
d18cc9f713 Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
cc ydwu4 penguinwu

[ghstack-poisoned]
2025-11-04 13:28:45 -08:00
b85bae22e6 Update base for Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
cc ydwu4 penguinwu

[ghstack-poisoned]
2025-11-04 13:28:45 -08:00
1d2e446651 Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 13:10:59 -08:00
4762db5485 Update base for Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 13:10:59 -08:00
4d00338b36 Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 12:59:35 -08:00
3f6addc6c0 Update base for Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 12:59:35 -08:00
95d914abfe Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 11:57:19 -08:00
2fe0d12064 Update base for Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 11:57:19 -08:00
233c64a091 Fix linter more
ghstack-source-id: f30b793c6e5df7209e485d800d6c6885388b1835
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166660
2025-11-04 10:58:04 -08:00
92fc049908 Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 10:58:04 -08:00
557496f5d2 Update base for Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 10:58:04 -08:00
27a74e6f87 Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 10:43:39 -08:00
19527d9c14 Update base for Update on "[HOP][print]Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 10:43:39 -08:00
72a1e57917 Update on "Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 10:33:55 -08:00
226ad15635 Update base for Update on "Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 10:33:55 -08:00
e72a7adba1 Update on "Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 09:49:08 -08:00
f8e3b9931c Update base for Update on "Add make_fx test for the proxy with graph module print"
[ghstack-poisoned]
2025-11-04 09:49:08 -08:00
782aef6221 Add make_fx test for the proxy with graph module print
[ghstack-poisoned]
2025-11-03 17:04:01 -08:00
48a3c9edbf Update on "[HOP][print] Add HOP subclass for printing"
[ghstack-poisoned]
2025-11-03 17:04:01 -08:00
23f02bfde5 Update torch/_higher_order_ops/print.py
Co-authored-by: Angela Yi <yiangela7@gmail.com>
2025-11-03 16:44:54 -08:00
fc96f0769a Update on "[HOP][print] Add HOP subclass for printing"
[ghstack-poisoned]
2025-11-03 14:13:45 -08:00
36717f7179 Update on "[HOP][print] Add HOP subclass for printing"
[ghstack-poisoned]
2025-11-03 13:48:52 -08:00
1b33614b49 Update on "[HOP][print] Add HOP subclass for printing"
[ghstack-poisoned]
2025-11-03 12:58:00 -08:00
acbfa3faaf Update on "[HOP][print] Add HOP subclass for printing"
[ghstack-poisoned]
2025-11-03 11:14:09 -08:00
1213cd3fd0 Update on "[HOP][print] Add HOP subclass for printing"
[ghstack-poisoned]
2025-11-03 11:00:17 -08:00
51554434dc Update on "[HOP][print] Add HOP subclass for printing"
[ghstack-poisoned]
2025-11-03 10:54:46 -08:00
24c64a064c Add HOP subclass for printing
[ghstack-poisoned]
2025-10-30 13:00:00 -07:00
5 changed files with 270 additions and 7 deletions

View File

@@ -0,0 +1,145 @@
# Owner(s): ["module: higher order operators"]
import io
from unittest.mock import patch

import torch
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase


class TestHopPrint(TestCase):
    def test_base_print(self):
        def f(x):
            x = x + x
            torch._higher_order_ops.print("moo")
            x = x * x
            torch._higher_order_ops.print("moo")
            return x

        counters.clear()
        x = torch.randn(3, 3)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            f(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(printed_output, "moo\nmoo")

    def test_para_print(self):
        def f(x):
            x = x + x
            torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
            x = x * x
            return x

        counters.clear()
        x = torch.randn(3, 3)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            f(x)
            printed_output = mock_stdout.getvalue().strip()

        self.assertEqual(printed_output, "moo 1 2")

        fx_f = make_fx(f)(x)
        new_inp = torch.randn(3, 3)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            fx_f(new_inp)
            ori_printed_output = mock_stdout.getvalue().strip()
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            f(new_inp)
            fx_printed_output = mock_stdout.getvalue().strip()
        self.assertEqual(ori_printed_output, fx_printed_output)

    def test_print_with_proxy_graph(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
                torch._higher_order_ops.print("moo {x}", x=x)
                res = x + x
                torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
                torch._higher_order_ops.print("yeehop {x}", x=x.shape[0])
                return (res,)

        inputs = (torch.randn(3),)

        # Without functionalization, print should just appear in the graph directly
        gm = make_fx(M(), tracing_mode="symbolic")(*inputs)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1):
    print_1 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_1 = None
    print_2 = torch.ops.higher_order.print('moo {x}', x = arg0_1); print_2 = None
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1)
    print_3 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_3 = None
    sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0); arg0_1 = None
    print_4 = torch.ops.higher_order.print('yeehop {x}', x = sym_size_int); sym_size_int = print_4 = None
    return (add,)""",
        )

    def test_print_with_side_effect(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
                res = x + x
                torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
                return (res,)

        inputs = (torch.randn(3),)

        # Without functionalization, print should just appear in the graph directly
        gm = make_fx(M())(*inputs)
        FileCheck().check_count("torch.ops.higher_order.print", 2, exactly=True).run(
            gm.code
        )

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.higher_order.print, 'moo {x} {y}', x = 1, y = 2); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, 'moo {x} {y}', x = 1, y = 2); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    def test_print_with_input_mutations(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
                res = x + x
                x.add_(res)
                res = x + x
                torch._higher_order_ops.print("moo {x} {y}", x=x, y=res)
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.user_inputs_to_mutate), 1)


if __name__ == "__main__":
    run_tests()

View File

@@ -24,6 +24,7 @@ from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
from torch._higher_order_ops.local_map import local_map_hop
from torch._higher_order_ops.map import map
from torch._higher_order_ops.out_dtype import out_dtype
from torch._higher_order_ops.print import print
from torch._higher_order_ops.run_const_graph import run_const_graph
from torch._higher_order_ops.scan import scan
from torch._higher_order_ops.strict_mode import strict_mode
@@ -75,4 +76,5 @@ __all__ = [
    "map",
    "while_loop_stack_output",
    "local_map_hop",
    "print",
]

View File

@@ -6,7 +6,9 @@ from weakref import WeakKeyDictionary
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.print import print
from torch._higher_order_ops.torchbind import call_torchbind
from torch._library.fake_class_registry import FakeScriptObject
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
@@ -28,6 +30,7 @@ SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
    [
        (torch.ops.aten._print.default, _EffectType.ORDERED),
        (call_torchbind, _EffectType.ORDERED),
        (print, _EffectType.ORDERED),
    ]
)
@@ -210,6 +213,8 @@ def _get_schema(op, args) -> torch.FunctionSchema:
        return op._schema
    elif op == call_torchbind:
        return getattr(args[0], args[1]).schema
    elif op == print:
        return print.schema()
    else:
        raise RuntimeError(f"Unable to get schema for op {op}")

View File

@@ -0,0 +1,100 @@
import builtins
from typing import Any

import torch
import torch.utils._pytree as pytree
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.utils import _unpack_kwargs
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode


class Print(HigherOrderOperator):
    """
    print(format_str, **kwargs) -> None

    This Higher Order Operator (HOP) provides a functional version of print for use in PyTorch graphs.
    It enables format printing with named arguments, e.g., torch._higher_order_ops.print("moo {x} {y}", x=1, y=2).
    This HOP enables printing without causing a graph break.
    """

    def __init__(self) -> None:
        super().__init__("print")

    def __call__(self, format_str: str, **kwargs: object) -> object:
        assert isinstance(format_str, str)
        return super().__call__(format_str, **kwargs)

    @staticmethod
    def schema() -> torch.FunctionSchema:
        """
        Returns the schema of ``Print.__call__``.
        """
        # print(str format_str, ...) -> ()
        schema_str = "print(str format_str, ...) -> ()"
        return torch._C.parse_schema(schema_str)


print = Print()


@print.py_impl(ProxyTorchDispatchMode)
# pyre-ignore
def print_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode, format_str: str, **kwargs: object
) -> None:
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)  # type: ignore[union-attr]
    mode.tracer.create_proxy("call_function", print, (format_str,), proxy_kwargs)


@print.py_impl(FakeTensorMode)
# pyre-ignore
def print_fake_tensor_mode(mode, format_str: str, **kwargs: object):
    with mode:
        return None


@print.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
# pyre-ignore
def print_impl(format_str: str, **kwargs: object) -> None:
    # Ensure all immutable_dict/list in kwargs are converted to regular dict/list
    map_types: dict[type, type] = {
        torch.fx.immutable_collections.immutable_dict: dict,
        torch.fx.immutable_collections.immutable_list: list,
    }
    new_kwargs = pytree.tree_map_only(
        tuple(map_types.keys()),
        lambda a: map_types[type(a)](a),
        kwargs,
        lambda a: isinstance(a, tuple(map_types.keys())),
    )
    # Use built-in print to avoid recursion with the HOP print
    builtins.print(format_str.format(**new_kwargs))


@print.py_autograd_impl
# pyre-ignore
def print_autograd(format_str: str, **kwargs: object):
    # with torch._C._ExcludeDispatchKeyGuard(
    #     torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
    # ):
    return None


print.fallthrough(torch._C.DispatchKey.AutogradCPU)
print.fallthrough(torch._C.DispatchKey.AutogradCUDA)


@print.py_functionalize_impl
def print_func(ctx, format_str: str, **kwargs: object):
    from torch._higher_order_ops.effects import handle_effects

    return handle_effects(
        ctx.mode._allow_token_discovery,
        ctx.mode._tokens,
        print,
        (format_str,),
        kwargs,  # type: ignore[arg-type]
    )
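
For orientation, here is a minimal usage sketch of the operator defined above, mirroring what test_para_print in the new test file exercises. It is illustrative only and not part of the diff; it assumes this stack is applied. The print call is kept in the traced graph as a torch.ops.higher_order.print node rather than forcing a graph break, and running the traced module reproduces the same output as eager execution.

# Illustrative sketch, not part of the diff -- mirrors test_para_print above.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    x = x + x
    torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)  # the HOP added in this stack
    return x * x

x = torch.randn(3, 3)
f(x)                    # eager: prints "moo 1 2"
gm = make_fx(f)(x)      # gm.code contains a torch.ops.higher_order.print(...) node
gm(torch.randn(3, 3))   # running the traced graph prints "moo 1 2" as well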

View File

@@ -103,6 +103,7 @@ FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
    "dynamo_bypassing_wrapper",  # TODO(soulitzer)
    "foreach_map",
    "aoti_call_delegate",
    "print",
]

torch.library.define(
@@ -153,6 +154,7 @@ def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs
def fn_for_invoke_subgraph(x):
    return torch.sin(x)


def simple_invoke_subgraph(x):
    return fn_for_invoke_subgraph(x)
@@ -202,6 +204,7 @@ def simple_while_loop(iter_t, x):
    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))


def simple_while_loop_stack_output(iter_t, x):
    def cond_fn(iter_t, x):
        return iter_t > 0
@@ -209,7 +212,9 @@ def simple_while_loop_stack_output(iter_t, x):
    def body_fn(iter_t, x):
        return iter_t - 1, x.cos()

    return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple())
    return torch._higher_order_ops.while_loop_stack_output(
        cond_fn, body_fn, (iter_t, x), tuple()
    )


def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs):
@@ -226,18 +231,21 @@ def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs):
def simple_local_map_hop(inp1, inp2):
    def body_gm(inp1, inp2):
        return inp1.cos() + inp2.sin()

    gm = torch.fx.symbolic_trace(body_gm)

    assert torch.distributed.is_available()
    from torch.distributed.tensor.placement_types import Replicate

    gm.meta["local_map_kwargs"] = {
        "in_placements": (Replicate(), Replicate(), Replicate()),
        "out_placements": ((Replicate(), Replicate(), Replicate()),)
        "out_placements": ((Replicate(), Replicate(), Replicate()),),
    }

    # TODO: Dynamo would rewrite this op differently
    return torch._higher_order_ops.local_map_hop(gm, inp1, inp2)


def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
@@ -249,7 +257,6 @@ def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs):
def simple_scan(init, xs):
    def combine_fn(carry, x):
        result = carry @ x + x
        return result, carry.clone()
@@ -264,15 +271,14 @@ def simple_invoke_quant(x):
    def fn(x, y):
        return (torch.sin(x) * y,)

    return quant_tracer(fn, x, x)[0] * 2.
    return quant_tracer(fn, x, x)[0] * 2.0


def simple_invoke_quant_packed(x):
    def fn(x):
        return (torch.sin(x),)

    return invoke_quant_packed(fn, x)[0] * 2.
    return invoke_quant_packed(fn, x)[0] * 2.0


hop_db = [
@@ -496,6 +502,11 @@ hop_db = [
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
        decorators=[onlyCUDA, unittest.skipIf(not torch.distributed.is_available(), "requires distributed build")],
        decorators=[
            onlyCUDA,
            unittest.skipIf(
                not torch.distributed.is_available(), "requires distributed build"
            ),
        ],
    ),
]