Compare commits

...

4 Commits

SHA1        Message                                     Date
11ec651a7e  Update [ghstack-poisoned]                   2025-11-09 22:13:08 -08:00
96f19e8a9e  Update [ghstack-poisoned]                   2025-11-09 20:33:49 -08:00
877d854d8c  Update [ghstack-poisoned]                   2025-11-09 20:08:15 -08:00
f7b3ffbd22  Update (base update) [ghstack-poisoned]     2025-11-09 20:08:15 -08:00
6 changed files with 604 additions and 144 deletions

View File

@@ -1681,14 +1681,13 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
return (getitem, getitem_1)
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (y, y)
return (y,)
""",
)
@@ -1798,9 +1797,9 @@ class GraphModule(torch.nn.Module):
out: "f32[4, 4]" = l_x_.sin()
sin_1: "f32[4, 4]" = torch.sin(o)
child: "f32[4, 4]" = torch.cos(sin_1)
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (child, child_1, matmul, o, out, sin_1)
cos: "f32[4, 4]" = torch.cos(sin_1)
sin_2: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (cos, sin_2, matmul, o, out, sin_1)
""",
)
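
A minimal repro sketch (illustration only, not part of this diff) of the kind of program the expected graphs above are traced from; the helper names `body` and `fn` are made up. Under the new wrap semantics, a checkpointed body that returns the same tensor twice is captured with a single deduplicated subgraph output, and intermediates are named after their ops (cos, sin_2) rather than child/child_1.

import torch
from torch.utils.checkpoint import checkpoint

def body(x):
    y = torch.sin(x)
    return y, y  # the same tensor returned twice

def fn(x):
    # traced by Dynamo as higher_order.tag_activation_checkpoint(wrap_body_0, x, use_reentrant=True)
    return checkpoint(body, x, use_reentrant=True)

out1, out2 = torch.compile(fn, backend="eager")(torch.randn(4, 4))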

View File

@@ -222,13 +222,13 @@ class GraphModule(torch.nn.Module):
matmul: "f32[3, 3]" = l_x_ @ l_y_
sin: "f32[3, 3]" = matmul.sin(); matmul = None
child: "f32[3, 3]" = sin.cos(); sin = None
cos: "f32[3, 3]" = sin.cos(); sin = None
child_1: "f32[3, 3]" = l_x_ + l_y_
child_2: "f32[3, 3]" = l_x_ - l_y_
add: "f32[3, 3]" = l_x_ + l_y_
sub: "f32[3, 3]" = l_x_ - l_y_
child_3: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (child, child_1, child_2, child_3)
matmul_1: "f32[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None
return (cos, add, sub, matmul_1)
""", # noqa: B950
)
self.assertExpectedInline(

View File

@@ -249,7 +249,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# when testing with dynamic shape, symbols are lifted as input
arg_count = ifdynstaticdefault(2, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
def test_return_captured_vars(self):
freevar1 = torch.randn(3)
@@ -267,7 +267,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
# be the input.
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 1)
def test_return_captured_var_used_multiple_times(self):
freevar = torch.randn(3)
@@ -282,7 +282,7 @@ class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
x = torch.randn(3)
# when testing with dynamic shape, a symbol is lifted as input
arg_count = ifdynstaticdefault(3, 4)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 3)
self._test_wrap_simple(fn, default_args_generator((x,)), arg_count, 2)
def test_capture_untracked_global(self):
def f(x):
@@ -762,15 +762,15 @@ class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, s77, l_x_, u0, c); wrap_body_0 = s77 = l_x_ = u0 = c = None
child: "f32[s77]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[s77]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, s77: "Sym(s77)", l_x_: "f32[s77]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[s77]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
sin: "f32[s77]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
""",
)
else:
@@ -801,15 +801,15 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, u0, c); wrap_body_0 = l_x_ = u0 = c = None
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", u0: "Sym(u0)", c: "i64[u0, 1]"):
child: "f32[3]" = l_x_.sin(); l_x_ = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
sin: "f32[3]" = l_x_.sin(); l_x_ = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (sin, sin_1)
""",
)
@@ -922,16 +922,16 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_, size, c); wrap_body_0 = l_x_ = size = c = None
child: "f32[3]" = wrap[0]
child_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (child, child_1)
getitem: "f32[3]" = wrap[0]
getitem_1: "f32[u0, 1]" = wrap[1]; wrap = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]", size: "Sym(u0)", c: "i64[u0, 1]"):
sin: "f32[3]" = l_x_.sin(); l_x_ = None
child: "f32[3]" = sin + size; sin = size = None
child_1: "f32[u0, 1]" = c.sin(); c = None
return (child, child_1)
add: "f32[3]" = sin + size; sin = size = None
sin_1: "f32[u0, 1]" = c.sin(); c = None
return (add, sin_1)
""",
)
@@ -2458,10 +2458,10 @@ class GraphModule(torch.nn.Module):
class wrap_body_0(torch.nn.Module):
def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"):
child: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None
child_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (child, child_1)
add_1: "f32[3]" = l_arg2_0_ + 1; l_arg2_0_ = None
return (add, add_1)
""",
)
@@ -2655,9 +2655,9 @@ class GraphModule(torch.nn.Module):
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[2, 3]"):
child: "f32[2, 3]" = l_x_.sin()
child_1: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (child, child_1)
sin: "f32[2, 3]" = l_x_.sin()
cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None
return (sin, cos)
""",
)
@@ -2687,13 +2687,13 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
value: "f32[3]" = wrap[0]; wrap = None
return (value,)
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]"):
child: "f32[3]" = -l_x_; l_x_ = None
return (child,)
neg: "f32[3]" = -l_x_; l_x_ = None
return (neg,)
""",
)
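
A hedged sketch (not from this diff) of the wrap-HOP pattern these tests exercise; the import path and the lambda body are assumptions. `wrap` simply runs its body eagerly, and Dynamo captures it as torch.ops.higher_order.wrap with captured tensors lifted as extra subgraph inputs; with this change the subgraph outputs carry op-derived names (sin, neg) instead of child/child_1.

import torch
from torch._higher_order_ops.wrap import wrap  # assumed import path

freevar = torch.randn(3)

def fn(x):
    # `freevar` is a captured tensor, lifted as an input of wrap_body_0
    return wrap(lambda x: (x.sin(), freevar.cos()), x)

out = torch.compile(fn, backend="eager")(torch.randn(3))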

View File

@@ -899,14 +899,14 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None
child: "f32[8]" = mul * 2; mul = None
return (child,)
mul_1: "f32[8]" = mul * 2; mul = None
return (mul_1,)
class subgraph_1(torch.nn.Module):
def forward(self, a: "f32[8]", l_y_: "f32[8]"):
mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None
child: "f32[8]" = mul * 3; mul = None
return (child,)
mul_1: "f32[8]" = mul * 3; mul = None
return (mul_1,)
""",
)
@@ -983,20 +983,20 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None
x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
subgraph_1 = self.subgraph_0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', x, l_y_); subgraph_1 = x = None
x_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
subgraph_2 = self.subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', x_1, l_y_); subgraph_2 = x_1 = None
x_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', getitem_1, l_y_); subgraph_2 = getitem_1 = None
getitem_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
subgraph_3 = self.subgraph_0
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', x_2, l_y_); subgraph_3 = x_2 = None
x_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', getitem_2, l_y_); subgraph_3 = getitem_2 = None
getitem_3: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
subgraph_4 = self.subgraph_0
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', x_3, l_y_); subgraph_4 = x_3 = l_y_ = None
x_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
return (x_4,)
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', getitem_3, l_y_); subgraph_4 = getitem_3 = l_y_ = None
getitem_4: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
return (getitem_4,)
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"):
@@ -1495,9 +1495,9 @@ class GraphModule(torch.nn.Module):
class subgraph_0(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
child: "f32[8, 8]" = l_x_ * 2
child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (child, child_1)
mul: "f32[8, 8]" = l_x_ * 2
mul_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (mul, mul_1)
""",
)
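
A hedged sketch (not from this diff) of the repeated-region pattern behind the invoke_subgraph expected graph above, where one subgraph is reused five times and the chained results are now named getitem..getitem_4 rather than x..x_4; the decorator name is an assumption about the public API for marking a reusable compile region.

import torch

@torch.compiler.nested_compile_region  # assumed public decorator for invoke_subgraph regions
def gn(x, y):
    return torch.mul(x, y) * 2

def fn(x, y):
    for _ in range(5):
        x = gn(x, y)
    return x

out = torch.compile(fn, backend="eager")(torch.randn(8), torch.randn(8))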

View File

@@ -286,47 +286,31 @@ class GraphModule(torch.nn.Module):
l_self_modules_wo_parameters_weight_ = L_self_modules_wo_parameters_weight_
l_self_modules_w1_parameters_weight_ = L_self_modules_w1_parameters_weight_
l_self_modules_w2_parameters_weight_ = L_self_modules_w2_parameters_weight_
q: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wq_parameters_weight_, None); l_self_modules_wq_parameters_weight_ = None
k: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wk_parameters_weight_, None); l_self_modules_wk_parameters_weight_ = None
v: "f32[8, 16, 96]" = torch._C._nn.linear(l_x_, l_self_modules_wv_parameters_weight_, None); l_self_modules_wv_parameters_weight_ = None
unflatten: "f32[8, 16, 16, 6]" = q.unflatten(-1, (16, -1)); q = None
q_1: "f32[8, 16, 16, 6]" = unflatten.permute(0, 2, 1, 3); unflatten = None
unflatten_1: "f32[8, 16, 16, 6]" = k.unflatten(-1, (16, -1)); k = None
k_1: "f32[8, 16, 16, 6]" = unflatten_1.permute(0, 2, 1, 3); unflatten_1 = None
unflatten_2: "f32[8, 16, 16, 6]" = v.unflatten(-1, (16, -1)); v = None
v_1: "f32[8, 16, 16, 6]" = unflatten_2.permute(0, 2, 1, 3); unflatten_2 = None
subgraph_0 = self.subgraph_0
local_map_hop = torch.ops.higher_order.local_map_hop(subgraph_0, q_1, k_1, v_1); subgraph_0 = q_1 = k_1 = v_1 = None
o: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = o.permute(0, 2, 1, 3); o = None
o_1: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_2: "f32[8, 16, 96]" = torch._C._nn.linear(o_1, l_self_modules_wo_parameters_weight_, None); o_1 = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_2 + l_x_; o_2 = l_x_ = None
o_3: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_4: "f32[8, 16, 384]" = torch.nn.functional.relu(o_3); o_3 = None
o_5: "f32[8, 16, 96]" = torch._C._nn.linear(o_4, l_self_modules_w2_parameters_weight_, None); o_4 = l_self_modules_w2_parameters_weight_ = None
o_6: "f32[8, 16, 96]" = o0 + o_5; o0 = o_5 = None
return (o_6,)
getitem: "f32[8, 16, 16, 6]" = local_map_hop[0]; local_map_hop = None
permute_3: "f32[8, 16, 16, 6]" = getitem.permute(0, 2, 1, 3); getitem = None
o: "f32[8, 16, 96]" = permute_3.flatten(-2); permute_3 = None
o_1: "f32[8, 16, 96]" = torch._C._nn.linear(o, l_self_modules_wo_parameters_weight_, None); o = l_self_modules_wo_parameters_weight_ = None
o0: "f32[8, 16, 96]" = o_1 + l_x_; o_1 = l_x_ = None
o_2: "f32[8, 16, 384]" = torch._C._nn.linear(o0, l_self_modules_w1_parameters_weight_, None); l_self_modules_w1_parameters_weight_ = None
o_3: "f32[8, 16, 384]" = torch.nn.functional.relu(o_2); o_2 = None
o_4: "f32[8, 16, 96]" = torch._C._nn.linear(o_3, l_self_modules_w2_parameters_weight_, None); o_3 = l_self_modules_w2_parameters_weight_ = None
o_5: "f32[8, 16, 96]" = o0 + o_4; o0 = o_4 = None
return (o_5,)
class subgraph_0(torch.nn.Module):
def forward(self, q_1: "f32[1, 2, 4, 6]", k_1: "f32[1, 2, 16, 6]", v_1: "f32[1, 2, 16, 6]"):
out: "f32[1, 2, 4, 6]" = torch._C._nn.scaled_dot_product_attention(query = q_1, key = k_1, value = v_1, is_causal = False); q_1 = k_1 = v_1 = None
return (out,)
""",
return (out,)""",
ignore_empty_lines=True,
)

View File

@@ -28,7 +28,7 @@ import types
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Literal, Optional, TYPE_CHECKING
import torch._C
import torch.fx
@@ -215,7 +215,10 @@ def find_mismatched_vars(var, types, allow_none=False):
A set of variables whose type is not an instance of the specified types.
"""
mismatched_vars = set()
if isinstance(var, (TupleVariable, ListVariable)):
if isinstance(var, (list, tuple)):
for item in var:
mismatched_vars.update(find_mismatched_vars(item, types, allow_none))
elif isinstance(var, (TupleVariable, ListVariable)):
for item in var.items:
mismatched_vars.update(find_mismatched_vars(item, types, allow_none))
elif isinstance(var, ConstDictVariable):
@@ -248,6 +251,78 @@ def _make_inlined(tx: "InstructionTranslator", f):
return inline_call
def _call_function_and_unflatten_output_wrap_semantics(
tx: "InstructionTranslator",
fn: Any,
args: tuple[Any, ...],
kwargs: dict[str, Any],
flat_example_value: Any,
body_r: Optional[VariableTracker],
graph_output_vts: VariableTracker | tuple[VariableTracker, ...],
) -> Optional[VariableTracker]:
"""
Create HOP call node and reproxify output VTs for HOPs with wrap semantics.
This function is used by HOPs with wrap semantics (see speculate_subgraph_with_wrap_semantics)
to create the actual HOP call in the FX graph and properly handle the output variable trackers.
The key operation is "reproxifying" - updating the proxies of the original tensor VTs
(from body_r) to point to the HOP call outputs, ensuring the outer graph correctly
references the HOP outputs while allowing body_r to contain arbitrary Python objects.
Args:
tx: The instruction translator
fn: The HOP function to call
args: Arguments for the HOP call (typically includes the subgraph node)
kwargs: Keyword arguments for the HOP call
flat_example_value: Example value for the HOP output
body_r: The output VT structure that Dynamo continues tracing with (may be None)
graph_output_vts: Tensor/symint VTs that were actual graph outputs
Returns:
The body_r VT (unchanged), which Dynamo will continue tracing with
"""
from .builder import wrap_fx_proxy
# Store the invocation as a call
flat_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn,
args=args,
kwargs=kwargs,
),
example_value=flat_example_value,
)
# wrap_fx_proxy creates fresh variable trackers. However, the main program
# after the speculate subgraph can still use the original tensor vts that
# are still pointing to the nodes present in the subgraph. So, we reproxify
# the original tensor vts with the subgraph outputs. This way, whenever the
# outer graph uses an original vt, it uses the subgraph output.
#
# This is critical for maintaining the separation between:
# - `body_r`: The output VT structure that Dynamo continues tracing (may
# contain non-proxyable objects, nested structures, etc.)
# - `graph_output_vts`: Only the tensor/symint VTs that were actual graph
# outputs from speculate_subgraph
#
# By overwriting the proxies of VTs in `body_r` with the proxies from the
# HOP call, we ensure the outer graph correctly references the HOP outputs
# while still allowing `body_r` to contain arbitrary Python objects.
if body_r is not None:
for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items):
if isinstance(
orig_vt, (variables.SymNodeVariable, variables.TensorVariable)
):
assert isinstance(
subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable)
)
orig_vt.proxy = subgraph_vt.proxy
return body_r
def _call_function_and_unflatten_output(
tx, fn, args, kwargs, flat_example_value, ret_spec, body_r
):
@@ -935,6 +1010,417 @@ def _merge_graph_inputs(
return l_graph, r_graph, l_shared, r_shared, unique_l, unique_r
def speculate_subgraph_with_wrap_semantics(
tx: "InstructionTranslator",
f: VariableTracker,
sub_args: Sequence[VariableTracker],
sub_kwargs: Optional[dict[str, VariableTracker]],
description: str,
*,
# source_target is the .value of HigherOrderOpVariable and is the
# target of the proxy that we created for the higherOrderOperator.
source_target: Optional[HigherOrderOperator] = None,
enable_grad: Optional[bool] = None,
# TODO - We can probably just make everyone use automatic for wrap_semantics
set_subgraph_inputs: Literal[
"automatic", "semi_automatic", "flatten_manual", "manual"
] = "automatic",
# Make default False
restore_side_effects: bool = True,
under_activation_checkpoint: bool = False,
# TODO - supports input_mutation and aliasing should be False by default for strictness
supports_input_mutation: bool = True,
supports_aliasing: bool = True,
# Pass in an originating tracer - this is needed for preserving context
# across fwd-bwd for autograd.Function
tracer: Optional["torch._dynamo.output_graph.SubgraphTracer"] = None,
) -> tuple[
VariableTracker, # output: The VT that Dynamo continues tracing with
torch.fx.Graph, # graph: The FX graph representing the subgraph computation
dict[
torch.fx.Proxy, torch.fx.Proxy
], # lifted_freevars: Free variables lifted as inputs
VariableTracker
| tuple[
VariableTracker, ...
], # graph_output_vts: Tensor/symint VTs that are actual FX graph outputs
]:
"""
Speculate subgraph for Higher-Order Operators (HOPs) with wrap semantics.
## Wrap Semantics
Some HOPs have "wrap semantics", meaning the HOP at runtime essentially just runs
the subgraph with the inputs. For example:
- invoke_subgraph
- activation checkpointing (torch.utils.checkpoint.checkpoint)
- autograd.Function
- nested_compile_region
This is in contrast to control flow HOPs which do NOT follow wrap semantics:
- torch.cond (conditional execution based on predicate)
- torch.while_loop (iterative execution)
- torch.map (parallel execution over batch dimension)
For control flow HOPs, the HOP behavior is fundamentally different from just
running the body function once.
## Key Advantage: Disentangling VTs from Graph Outputs
Wrap semantics simplify HOP processing by allowing us to disentangle the output
variable trackers (VTs) from the HOP subgraph outputs. This mirrors typical
Dynamo processing where:
- VTs "run ahead" representing the program state for continued tracing
- The graph is a side data structure tracking computation seen so far
This separation is crucial for HOPs with non-proxyable outputs (e.g., custom
user-defined objects containing tensors). The function may return complex Python
objects for Dynamo to continue tracing, but only the tensor/symint VTs need to
be registered as actual FX graph outputs.
Example:
class Foo:
def __init__(self, a, b):
self.a = a # tensor
self.b = b # tensor
def gn(x):
return Foo(torch.sin(x), torch.cos(x))
result = some_hop(gn, x) # Returns Foo instance
out = result.a + result.b # Dynamo can continue tracing
Here, `output` VT is a UserDefinedObjectVariable wrapping Foo, but
`graph_output_vts` contains only the tensor VTs (a and b) that should be
actual FX graph outputs. This allows Dynamo to continue tracing with the
Foo object while the graph only needs to output the constituent tensors.
## Return Values
Unlike `speculate_subgraph`, this function returns:
- output: The VT that Dynamo continues tracing with (may be complex Python objects)
- graph: The FX graph representing the subgraph computation
- lifted_freevars: Free variables lifted as inputs to the subgraph
- graph_output_vts: Only the tensor/symint VTs that are actual FX graph outputs
The key difference is `graph_output_vts` instead of `treespec`, which gives more
flexibility for handling non-proxyable outputs.
"""
if sub_kwargs is None:
sub_kwargs = {}
assert set_subgraph_inputs in {
"automatic",
"semi_automatic",
"flatten_manual",
"manual",
}, "Please use one of the supported set_subgraph_inputs options."
# See NOTE [Temporary argument `set_subgraph_inputs`]
if sub_kwargs and set_subgraph_inputs != "automatic":
unimplemented(
gb_type="invalid set_subgraph_inputs and sub_kwargs settings",
context=f"set_subgraph_inputs: {set_subgraph_inputs}, sub_kwargs: {sub_kwargs}",
explanation="`sub_kwargs` cannot be used when `set_subgraph_inputs` is not set to 'automatic'.",
hints=[
"Use `set_subgraph_inputs='automatic'` when passing `sub_kwargs`.",
*graph_break_hints.USER_ERROR,
],
)
try:
# ensure guards on args get installed in parent subgraph
f, sub_args, sub_kwargs = LazyVariableTracker.realize_all(
(f, sub_args, sub_kwargs),
)
with tx.output.subtracer(source_target, tracer) as subtracer:
sub_args_names = maybe_positional_arg_names(f)
# User mismatch in the number of args. Will eventually lead to an error.
if sub_args_names is not None and len(sub_args_names) < len(sub_args):
sub_args_names = None
args = validate_args_and_maybe_create_graph_inputs(
sub_args,
subtracer,
tx,
set_subgraph_inputs,
description,
sub_args_names,
)
validate_args_and_maybe_create_graph_inputs(
sub_kwargs.values(),
subtracer,
tx,
set_subgraph_inputs="automatic",
description=description,
)
autograd_ctx = (
dynamo_enable_grad(tx, enable_grad)
if enable_grad is not None
else contextlib.nullcontext()
)
checkpoint_ctx = (
dynamo_under_activation_checkpoint(tx)
if under_activation_checkpoint
else contextlib.nullcontext()
)
# For handling side effects, we can make an argument that we don't
# have to do anything here. The side effects infra does a good job
# of graph breaking if we mutate any nonlocal or global variable
# while subtracing. As a result if tracing succeeds, side effects
# data structure will only contain read-only data structures that
# are put there for tracking purposes.
# But on the other hand, there is an argument that if we ever write
# a new side effect in Dynamo which does not go through the side
# effect infra, we can end up in bad state.
# Therefore we restore the side effects after tracing. The catch is
# that we have to special handle tensor variables. If we have seen a
# nonlocal variable tensor during subtracing, we want to keep a
# track of that tensor, so that later subtracing or the root tracer
# itself does not create a new proxy for the already observed tensor
# variable.
if restore_side_effects:
prev_side_effects = tx.output.side_effects.clone()
with autograd_ctx, checkpoint_ctx:
output = f.call_function(tx, args, sub_kwargs)
if restore_side_effects:
new_side_effects = tx.output.side_effects.clone()
prev_side_effects.track_runahead_tensor_and_symvar_side_effects(
new_side_effects
)
tx.output.side_effects = prev_side_effects
# NOTE: [Separation of graph outputs and output VTs]
# In Dynamo (outside of speculate_subgraph), VTs and the graph are
# separate concepts:
# - VTs (VariableTrackers) can "run ahead" and continue Dynamo tracing
# - The graph is just a side data structure tracking computation seen so far
#
# This separation is crucial for HOPs with non-proxyable outputs (e.g.,
# custom user-defined objects containing tensors). The function may return
# complex Python objects for Dynamo to continue tracing, but only the
# tensor/symint VTs need to be registered as actual graph outputs.
#
# Example:
# class Foo:
# def __init__(self, a, b):
# self.a = a # tensor
# self.b = b # tensor
#
# def gn(x):
# return Foo(torch.sin(x), torch.cos(x))
#
# Here, `output` VT is a UserDefinedObjectVariable wrapping Foo, but
# `graph_output_vts` contains only the tensor VTs (a and b) that should
# be actual FX graph outputs.
# Collect only tensor and symint VTs that should be graph outputs.
# We walk the output structure and extract proxyable VTs.
graph_output_vts = []
output_types = (variables.TensorVariable, variables.SymNodeVariable)
def visit(vt):
if isinstance(vt, output_types):
graph_output_vts.append(vt)
VariableTracker.visit(visit, output)
graph_output_vts = tuple(graph_output_vts)
# NOTE - [Return subgraph intermediates as subgraph outputs]
# This helps HOPs which allow side effects. Consider the
# following example
#
# def gn(x, z):
# o = torch.matmul(x, x) @ x
# out = x.sin()
# z.append(out)
# return torch.cos(torch.sin(o))
# def fn(x):
# z = []
# out1 = torch.utils.checkpoint.checkpoint(
# gn,
# x,
# z,
# use_reentrant=False,
# )
# return out1, z[0]
#
# In this example, list `z` is in outer scope and gets appended
# in the subgraph with `out`. But `out` is not an output of the
# subgraph. This can cause issue because later on when the outer
# graph returns `z[0]` it needs to have access to the graph node
# `out`. To solve this problem, we just return all intermediates
# from the subgraph.
# TODO - Today this is supported only for AC. AC HOP gets
# desugared in AOTDispatcher so even though subgraph has extra
# unused outputs in Dynamo, it's ok even if we don't DCE them in
# Dynamo. As AOTDispatcher desugars/inlines the subgraph, the
# subgraph boundary disappears. And even for AC, today this only
# works when the skip_fwd_side_effects_in_bwd_under_checkpoint
# flag is True, i.e., only when we allow side-effects. But, we
# want this to be supported for other Hops as well, specifically
# nested_compile_region and autograd.Function. Today, it's safe
# because we error out on seeing a side-effect.
if under_activation_checkpoint:
extra_outputs = []
for out in subtracer.tracked_tensor_or_symint_vt:
if out not in set(graph_output_vts):
extra_outputs.append(out)
graph_output_vts = graph_output_vts + tuple(extra_outputs)
validate_subgraph_output_types(graph_output_vts)
# The output proxies might not belong to this SubgraphTracer
# (if they are free variables that were never lifted)
# so lift them here.
# output_proxies = output.as_proxy()
if isinstance(graph_output_vts, tuple):
output_proxies = [a.as_proxy() for a in graph_output_vts]
output_proxies = pytree.tree_map(
subtracer.maybe_lift_tracked_freevar_to_input, output_proxies
)
output_proxies = tuple(output_proxies)
else:
output_proxies = output.as_proxy()
output_proxies = pytree.tree_map(
subtracer.maybe_lift_tracked_freevar_to_input, output_proxies
)
tx.output.create_node(
"output",
"output",
(subtracer.create_arg((output_proxies,))),
{},
)
graph = tx.output.graph
graph.lint()
lifted_freevars = subtracer.lifted_freevars
# NOTE: [HigherOrderOperator subgraph input ordering]
# The input ordering of the higher order ops is determined by the order of
# the creation of the placeholder.
# Manually created inputs are created in validate_args_and_maybe_create_graph_inputs before
# speculating subgraph.
# During subgraph speculation, we may lift closured tensors and free symbols as inputs,
# their ordering is determined by the time they are lifted: earlier lifted ones precede later
# lifted ones.
#
# Suppose the placeholders are
# O1, O2, X1, O3, O4, X2, X3, O5 where Xs are lifted phs
# The following code re-order the placeholders to
# O1, O2, O3, O4, O5, X1, X2, X3
def move_lifted_freevars_phs_to_end(
graph: torch.fx.Graph, lifted_freevars: tuple[torch.fx.Node]
):
lifted_ph_set = {child_p.node for child_p in lifted_freevars.values()}
prev_phs = [n for n in graph.nodes if n.op == "placeholder"]
# No need to reorder when graph doesn't have args or doesn't
# have lifted freevars or all inputs are lifted freevars.
if (
len(prev_phs) == 0
or len(lifted_ph_set) == 0
or len(prev_phs) == len(lifted_ph_set)
):
return
# Step 1: find first X1
for x1 in prev_phs:
if x1 in lifted_ph_set:
break
assert x1 is not None and x1.op == "placeholder"
# Step 2: starting from the X1, skip Xs and prepend Os before X1.
cand_x = x1.next
while cand_x is not None and cand_x.op == "placeholder":
if cand_x in lifted_ph_set:
cand_x = cand_x.next
else:
nxt = cand_x.next
cand_x._remove_from_list()
x1.prepend(cand_x)
cand_x = nxt
# Step 3: assert that all placeholders are in the same order as
# in lifted_freevars.
after_phs = [node for node in graph.nodes if node.op == "placeholder"][
-len(lifted_freevars) :
]
assert len(after_phs) == len(lifted_freevars)
for child_proxy, ph in zip(lifted_freevars.values(), after_phs):
assert child_proxy.node is ph, (
"The order of placeholders is different from the order of lifted_freevars"
)
graph.lint()
if len(lifted_freevars) > 0:
move_lifted_freevars_phs_to_end(graph, lifted_freevars)
if not supports_input_mutation:
mutation_info = subtracer.has_input_mutation()
if mutation_info.has_mutation:
context = f"{mutation_info.msg} in\n {graph}"
unimplemented(
gb_type="Encountered input mutation during higher order op tracing",
context=context,
explanation=f"Higher order ops do not support input mutation. Found in {source_target.name()}",
hints=[
"Consider using the debug context to change user code to avoid mutation.",
"Please open an issue.",
],
)
if not supports_aliasing:
aliasing_info = subtracer.has_aliasing()
if aliasing_info.has_aliasing:
context = f"{aliasing_info.msg} in\n {graph}"
unimplemented(
gb_type="Encountered aliasing during higher order op tracing",
context=context,
explanation=f"Higher order ops do not support aliasing. Found in {source_target.name()}",
hints=[
"Replace `return input` with `return input.clone()` to avoid aliasing.",
"Consider using the debug context to change user code to avoid aliasing.",
"Please open an issue.",
],
)
# Return both the output VT and the graph output VTs separately:
# - `output`: The VT that Dynamo continues tracing with (may be
# complex Python objects, tuples, dicts, etc.)
# - `graph`: The FX graph representing the subgraph computation
# - `lifted_freevars`: Free variables lifted as inputs to the subgraph
# - `graph_output_vts`: Only the tensor/symint VTs that are actual
# FX graph outputs (basically the vts associated with graph outputs)
return (
output,
graph,
lifted_freevars,
graph_output_vts,
)
except Unsupported as ex:
f_name = f"{type(f).__name__}"
if isinstance(f, UserFunctionVariable):
f_name = f.get_name()
msg = (
f"speculate_subgraph: while introspecting {description}, we were unable "
f"to trace function `{f_name}` into a single graph. This means "
f"that Dynamo was unable to prove safety for this API and will "
f"fall back to eager-mode PyTorch, which could lead to a slowdown."
)
log.info(msg)
log.info(ex) # noqa: G200
raise ex
# See NOTE [HigherOrderOperator tracing design] for details of the design
def speculate_subgraph(
tx,
@@ -2439,10 +2925,11 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
# See NOTE [HigherOrderOperator tracing design] for more details
(
(body_r, treespec),
body_r,
body_graph,
body_lifted_freevars,
) = speculate_subgraph(
body_graph_output_vts,
) = speculate_subgraph_with_wrap_semantics(
tx,
fn_vt,
fn_args_vt,
@@ -2450,7 +2937,6 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
description,
source_target=self.value,
restore_side_effects=self.restore_side_effects,
should_flatten_outputs=True,
under_activation_checkpoint=under_activation_checkpoint,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
@@ -2472,13 +2958,22 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
lifted_args = tuple(arg for arg in body_lifted_freevars)
proxy_args = (body_node,) + lifted_args
example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
torch.fx.Node,
lambda a: a.meta["example_value"],
body_graph.find_nodes(op="output")[0].args[0],
)
return proxy_args, {}, example_value, body_r, treespec, body_gmod, body_name
return (
proxy_args,
{},
example_value,
body_r,
body_gmod,
body_name,
body_graph_output_vts,
)
def _call_function(
self,
@@ -2492,9 +2987,9 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
p_kwargs,
_example_value,
body_r,
treespec,
_,
_,
body_graph_output_vts,
) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "wrap")
if len(p_kwargs) > 0:
@@ -2507,20 +3002,14 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
],
)
flat_example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
return _call_function_and_unflatten_output(
return _call_function_and_unflatten_output_wrap_semantics(
tx,
self.value,
tuple(p_args),
p_kwargs,
flat_example_value,
treespec,
_example_value,
body_r,
body_graph_output_vts,
)
@@ -2939,9 +3428,9 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
_,
example_value,
_body_r,
out_spec,
checkpointed_gmod,
_,
body_graph_output_vts,
) = self.create_wrapped_node(
tx,
args[0],
@@ -2955,14 +3444,14 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
_, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs)
return _call_function_and_unflatten_output(
return _call_function_and_unflatten_output_wrap_semantics(
tx,
self.value,
p_args,
checkpoint_kwargs,
example_value,
out_spec,
_body_r,
body_graph_output_vts,
)
@@ -2993,9 +3482,9 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
_,
example_value,
_body_r,
out_spec,
gmod,
_,
body_graph_output_vts,
) = self.create_wrapped_node(
tx,
args[1],
@@ -3009,14 +3498,14 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
gmod_meta_key = "_dynamo_bypassing_wrapper_fn"
gmod.meta[gmod_meta_key] = func
return _call_function_and_unflatten_output(
return _call_function_and_unflatten_output_wrap_semantics(
tx,
self.value,
(gmod_meta_key,) + tuple(p_args),
{},
example_value,
out_spec,
_body_r,
body_graph_output_vts,
)
@@ -3745,29 +4234,23 @@ class BaseHOPVariable(WrapHigherOrderVariable):
p_kwargs,
example_value,
body_r,
treespec,
body_gmod,
body_name,
_,
_,
body_graph_output_vts,
) = self.create_wrapped_node(
tx, args[0], args[1:], {}, self.value._name, subgraph_name="subgraph"
)
assert len(p_kwargs) == 0
flat_example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
return _call_function_and_unflatten_output(
return _call_function_and_unflatten_output_wrap_semantics(
tx,
self.value,
p_args,
p_kwargs,
flat_example_value,
treespec,
example_value,
body_r,
body_graph_output_vts,
)
@@ -3850,9 +4333,9 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
p_kwargs,
example_value,
body_r,
treespec,
body_gmod,
_,
body_name,
body_graph_output_vts,
) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "invoke_subgraph")
if len(p_kwargs) > 0:
@@ -3865,25 +4348,19 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
],
)
flat_example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
p_args = (
p_args[0],
body_name,
*p_args[1:],
)
return _call_function_and_unflatten_output(
return _call_function_and_unflatten_output_wrap_semantics(
tx,
torch._higher_order_ops.invoke_subgraph,
tuple(p_args),
p_kwargs,
flat_example_value,
treespec,
example_value,
body_r,
body_graph_output_vts,
)
@@ -4038,9 +4515,9 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
p_kwargs,
example_value,
body_r,
treespec,
body_gmod,
body_name,
body_graph_output_vts,
) = self.create_wrapped_node(
tx, user_func, user_args, kwargs, self.value._name, subgraph_name="subgraph"
)
@@ -4098,16 +4575,16 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
)
assert len(p_kwargs) == 0
flat_example_value = pytree.tree_map_only(
torch.fx.Proxy,
lambda a: a.node.meta["example_value"],
body_r.as_proxy(),
)
# Step 5: Install local_map subgraph
p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()}
out = _call_function_and_unflatten_output(
tx, self.value, p_args, p_kwargs, flat_example_value, treespec, body_r
out = _call_function_and_unflatten_output_wrap_semantics(
tx,
self.value,
p_args,
p_kwargs,
example_value,
body_r,
body_graph_output_vts,
)
# Step 6: Restore inputs and outputs to global shapes