[dynamo][hops] Remove const outputs from the speculated subgraph (#161355)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161355
Approved by: https://github.com/zou3519
Animesh Jain
2025-08-28 11:32:55 -07:00
committed by PyTorch MergeBot
parent 9480cdc0b6
commit 6b1900c22f
6 changed files with 167 additions and 73 deletions


@ -2608,25 +2608,17 @@ class GraphModule(torch.nn.Module):
f, default_args_generator((x,)), arg_count, expected_opcount=3
)
def test_fallback_on_python_primitives_output(self):
def test_support_float_in_output(self):
counters.clear()
cnt = CompileCounter()
@torch.compile(backend=cnt)
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
return wrap(lambda x: [1, torch.sin(x), 2.0], x)
x = torch.randn(3)
result = f(x)
self.assertEqual(result, [1, torch.sin(x), 2.0])
self.assertEqual(cnt.frame_count, 0)
assert_dict_matches_regex(
self,
dict(counters["graph_break"]),
{
".*HigherOrderOperator body's output must consist of tensors or ints only but got": 1
},
)
def test_nested_tuple_output(self):
def f(x):


@ -1498,7 +1498,7 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_); subgraph_0 = l_x_ = None
getitem: "f32[8, 8]" = invoke_subgraph[0]
getitem_1: "f32[8, 8]" = invoke_subgraph[2]; invoke_subgraph = None
getitem_1: "f32[8, 8]" = invoke_subgraph[1]; invoke_subgraph = None
add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None
return (add,)
@ -1507,7 +1507,7 @@ class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[8, 8]"):
child: "f32[8, 8]" = l_x_ * 2
child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None
return (child, None, child_1)
return (child, child_1)
""",
)
@ -1520,16 +1520,16 @@ class GraphModule(torch.nn.Module):
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1); partitioned_fw_subgraph_0_0 = primals_1 = None
getitem: "f32[8, 8]" = invoke_subgraph_2[0]
getitem_2: "f32[8, 8]" = invoke_subgraph_2[2]; invoke_subgraph_2 = None
getitem_1: "f32[8, 8]" = invoke_subgraph_2[1]; invoke_subgraph_2 = None
add: "f32[8, 8]" = torch.ops.aten.add.Tensor(getitem, getitem_2); getitem = getitem_2 = None
add: "f32[8, 8]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return (add,)
class partitioned_fw_subgraph_0_0(torch.nn.Module):
def forward(self, primals_0: "f32[8, 8]"):
mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 2)
mul_1: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 3); primals_0 = None
return (mul, None, mul_1)
return (mul, mul_1)
""",
)
@ -1541,8 +1541,8 @@ class GraphModule(torch.nn.Module):
partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', tangents_1, tangents_1); partitioned_bw_subgraph_0_0 = tangents_1 = None
getitem_3: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
return (getitem_3,)
getitem_2: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
return (getitem_2,)
class partitioned_bw_subgraph_0_0(torch.nn.Module):
def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"):
@ -1888,6 +1888,37 @@ class GraphModule(torch.nn.Module):
""",
)
def test_return_size(self):
def run(dynamic):
torch.compiler.reset()
@nested_compile_region
def gn(x):
y = x + 1
z = x.shape
return y, z
def fn(x):
z0 = gn(x)
z1 = gn(x)
return z0[0] + z1[0], z0[1]
x = torch.randn(8, 8, requires_grad=True)
x_clone = x.detach().clone().requires_grad_(True)
ref = fn(x)
opt_fn = torch.compile(
fn, backend="inductor", fullgraph=True, dynamic=dynamic
)
res = opt_fn(x_clone)
self.assertEqual(ref, res)
ref[0].sum().backward()
res[0].sum().backward()
self.assertEqual(x.grad, x_clone.grad)
run(dynamic=True)
run(dynamic=False)
def test_different_symint(self):
"""
Tests that the same subgraph called with different symints uses different graphs


@ -248,3 +248,33 @@ def wrap_inline_with_error_on_graph_break(
return fn(*args, **kwargs)
return wrapper
def filter_out_const_values(tup: tuple[Any, ...], masks: list[bool]) -> tuple[Any, ...]:
"""
masks is a list of bools, where True means the corresponding element in tup
is a const value. Filter out the const values.
"""
out = []
for mask_idx, mask in enumerate(masks):
if not mask:
out.append(tup[mask_idx])
return tuple(out)
def insert_const_values_with_mask(
tup: tuple[Any, ...], masks: list[bool], values: tuple[Any, ...]
) -> tuple[Any, ...]:
"""
masks and values have the same length. For indices where the mask is True, fill
in the corresponding entry from values; for the remaining indices, take the
elements of tup in order.
"""
out = []
idx = 0
for mask_idx, mask in enumerate(masks):
if mask:
out.append(values[mask_idx])
else:
out.append(tup[idx])
idx += 1
return tuple(out)
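
The two helpers are inverses of each other for a fixed mask. A minimal round-trip sketch of how they compose (illustrative values only, not taken from this PR; strings stand in for the fx proxies):

from torch._dynamo.external_utils import (
    filter_out_const_values,
    insert_const_values_with_mask,
)

# Pretend the flattened subgraph output was (proxy_a, 1, proxy_b, 2.0).
original = ("proxy_a", 1, "proxy_b", 2.0)
masks = [False, True, False, True]   # True marks a const value
consts = (None, 1, None, 2.0)        # consts kept at their original indices

filtered = filter_out_const_values(original, masks)                # ("proxy_a", "proxy_b")
restored = insert_const_values_with_mask(filtered, masks, consts)  # back to the original
assert restored == original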


@ -28,7 +28,7 @@ import types
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import torch._C
import torch.fx
@ -74,10 +74,28 @@ hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile")
@dataclass
class OutputSpec:
"""
The treespec of the output of the speculated subgraph and other metadata.
Contains the treespec of the output of the speculated subgraph, along with
the information needed to mask out the constant values from the output
during flattening and to insert them back during unflattening. Cleaning the
constants out of the graph makes the graph simpler for AOTDispatcher and
Inductor.
"""
treespec: pytree.TreeSpec
# List of True/False flags identifying the locations of const values in the
# subgraph output. True means the value at that index is a constant.
masks_to_filter_const_values: Optional[list[bool]] = None
# The actual constant values that were present in the subgraph output. Note
# that this list has the same length as the mask; we only look at the
# indices where the mask is True.
const_values: Optional[list[Any]] = None
def __post_init__(self):
if (
self.masks_to_filter_const_values is not None
or self.const_values is not None
):
assert len(self.masks_to_filter_const_values) == len(self.const_values)
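For intuition, a hedged sketch of the spec that would describe a flat subgraph output such as (tensor_a, None, tensor_b, 3), where the None and the int 3 are the constants (illustrative values, not taken from this PR; OutputSpec is the dataclass defined above):

import torch.utils._pytree as pytree

spec = OutputSpec(
    treespec=pytree.tree_structure((0, 0, 0, 0)),  # flat 4-tuple output
    masks_to_filter_const_values=[False, True, False, True],
    const_values=[None, None, None, 3],            # consts kept at their own indices
)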
def raise_hard_error_if_graph_break(reason):
@ -242,6 +260,15 @@ def _call_function_and_unflatten_output(
example_value=flat_example_value,
)
if ret_spec.masks_to_filter_const_values:
from torch._dynamo.external_utils import insert_const_values_with_mask
# During flattening, we removed the constant values. To ensure that Dynamo
# can keep tracing correctly, insert the constant values back into the output.
flat_variable = _make_inlined(tx, insert_const_values_with_mask)(
flat_variable, ret_spec.masks_to_filter_const_values, ret_spec.const_values
)
# Transform variable back into a list (previously made into a tuple by
# speculate_subgraph function) so as to respect the pytree API typing.
flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {})
@ -646,6 +673,9 @@ def speculate_subgraph(
set_subgraph_inputs="automatic",
restore_side_effects=True,
should_flatten_outputs=False,
# If should_flatten_outputs is True, `remove_consts_from_outputs` removes the
# const outputs from the subgraph output.
remove_consts_from_outputs=True,
under_activation_checkpoint=False,
# TODO - supports input_mutation and aliasing should be False by default for strictness
supports_input_mutation=True,
@ -736,15 +766,38 @@ def speculate_subgraph(
tx.output.side_effects = prev_side_effects
treespec = None
masks_to_filter_const_values = None
const_values = None
if should_flatten_outputs:
from torch._dynamo.external_utils import filter_out_const_values
# Flatten the speculated subgraph output.
output, treespec = _make_inlined(tx, pytree.tree_flatten)(
output
).unpack_var_sequence(tx)
# Actually, transform the list (returned by flatten) into a tuple
# for dynamo consistency.
output = BuiltinVariable(tuple).call_function(tx, [output], {})
if remove_consts_from_outputs:
# Filter out the constants and save them into the output spec. Filtering
# out constants makes the graph simpler for the backends. We
# need to ensure that, after unflattening, the constants are
# inserted back at the right positions so that Dynamo tracing can
# continue; this is done by insert_const_values_with_mask using
# the masks and values recorded in the OutputSpec.
output_proxies = output.as_proxy()
masks_to_filter_const_values = pytree.tree_map(
lambda x: not isinstance(x, torch.fx.Proxy), output_proxies
)
const_values = pytree.tree_map(
lambda x: None if isinstance(x, torch.fx.Proxy) else x,
output_proxies,
)
output = _make_inlined(tx, filter_out_const_values)(
output, masks_to_filter_const_values
)
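# Illustrative example (hypothetical proxy names, not from this PR): if
# output_proxies is (Proxy(child), 1, Proxy(child_1), 2.0), then
# masks_to_filter_const_values is (False, True, False, True),
# const_values is (None, 1, None, 2.0), and the filtered output is
# (Proxy(child), Proxy(child_1)).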
# Register output to graph
# Modeled off of compile_and_call_fx_graph
# TODO: support pytree output
@ -753,7 +806,12 @@ def speculate_subgraph(
if always_restore:
# Nothing left to do here
return (
(output, OutputSpec(treespec)),
(
output,
OutputSpec(
treespec, masks_to_filter_const_values, const_values
),
),
tx.output.graph,
subtracer.lifted_freevars,
)
@ -872,7 +930,12 @@ def speculate_subgraph(
)
return (
(output, OutputSpec(treespec)),
(
output,
OutputSpec(
treespec, masks_to_filter_const_values, const_values
),
),
graph,
lifted_freevars,
)
@ -1070,6 +1133,8 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
"cond",
source_target=self.value,
should_flatten_outputs=True,
# TODO - removing consts from control flow ops needs more work
remove_consts_from_outputs=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
@ -1381,6 +1446,8 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
source_target=self.value,
set_subgraph_inputs="flatten_manual",
should_flatten_outputs=True,
# TODO - removing consts from control flow ops needs more work
remove_consts_from_outputs=False,
supports_input_mutation=False,
supports_aliasing=False,
)
@ -1926,6 +1993,8 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
source_target=self.value,
set_subgraph_inputs="flatten_manual",
should_flatten_outputs=True,
# TODO - removing consts from control flow ops needs more work
remove_consts_from_outputs=False,
supports_input_mutation=self.supports_input_mutation,
supports_aliasing=self.supports_aliasing,
)
@ -2479,8 +2548,6 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
from torch._higher_order_ops.wrap import TagActivationCheckpoint
from torch.utils.checkpoint import noop_context_fn
from .builder import wrap_fx_proxy
context_fn = None
if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn:
ctx = kwargs.pop("context_fn")
@ -2520,27 +2587,15 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
_, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs)
# Store the invocation as a call
variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(p_args),
kwargs=checkpoint_kwargs,
),
example_value=example_value,
return _call_function_and_unflatten_output(
tx,
self.value,
p_args,
checkpoint_kwargs,
example_value,
out_spec,
)
if out_spec is None:
return variable
# Transform variable back into a list (previously made into a tuple by
# speculate_subgraph function) so as to respect the pytree API typing.
variable = BuiltinVariable(list).call_function(tx, [variable], {})
return _make_inlined(tx, pytree.tree_unflatten)(variable, out_spec.treespec)
class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
def __init__(self, hop, source) -> None:
@ -2552,8 +2607,6 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
from .builder import wrap_fx_proxy
func_var = args[0]
if isinstance(func_var, torch._dynamo.variables.UserFunctionVariable):
@ -2571,7 +2624,7 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
_,
example_value,
_body_r,
treespec,
out_spec,
gmod,
_,
) = self.create_wrapped_node(
@ -2587,27 +2640,15 @@ class DynamoBypassingWrapperHigherOrderVariable(WrapHigherOrderVariable):
gmod_meta_key = "_dynamo_bypassing_wrapper_fn"
gmod.meta[gmod_meta_key] = func
# Store the invocation as a call
variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=(gmod_meta_key,) + tuple(p_args),
kwargs={},
),
example_value=example_value,
return _call_function_and_unflatten_output(
tx,
self.value,
(gmod_meta_key,) + tuple(p_args),
{},
example_value,
out_spec,
)
if treespec is None:
return variable
# Transform variable back into a list (previously made into a tuple by
# speculate_subgraph function) so as to respect the pytree API typing.
variable = BuiltinVariable(list).call_function(tx, [variable], {})
return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec.treespec)
class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(


@ -45,7 +45,7 @@ invoke_subgraph_counter = 0
@dataclass
class OutputMetadata:
num_fw_outs: Optional[int] = None
indexes_with_none: set[int] = field(default_factory=set)
indexes_with_symint: set[int] = field(default_factory=set)
indexes_with_no_grad: set[int] = field(default_factory=set)
@ -258,8 +258,8 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
output_metadata.num_fw_outs = num_fw_outs
for idx, fw_out in enumerate(fw_outs):
if fw_out is None:
output_metadata.indexes_with_none.add(idx)
if isinstance(fw_out, torch.SymInt):
output_metadata.indexes_with_symint.add(idx)
elif not fw_out.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)
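# Illustrative example (hypothetical outputs, not from this PR): for
# fw_outs == (t0, s0, t1) where t0 requires grad, s0 is a SymInt, and t1
# does not require grad, indexes_with_symint == {1} and
# indexes_with_no_grad == {2}.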
@ -331,8 +331,8 @@ def get_output_metadata(subgraph, *operands):
output_metadata.num_fw_outs = num_fw_outs
for idx, fw_out in enumerate(fw_outs):
if fw_out is None:
output_metadata.indexes_with_none.add(idx)
if isinstance(fw_out, torch.SymInt):
output_metadata.indexes_with_symint.add(idx)
elif not fw_out.requires_grad:
output_metadata.indexes_with_no_grad.add(idx)
return output_metadata
@ -428,10 +428,10 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
*operands,
)
# Check that None is at expected indexes.
# Check that ints (coming from symints) are at the expected indexes.
for idx, o in enumerate(out):
if o is None:
assert idx in output_metadata.indexes_with_none
if isinstance(o, int):
assert idx in output_metadata.indexes_with_symint
return out
@ -452,7 +452,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
filtered_grad_outs = []
for idx, o in enumerate(grad_outs):
if o is None:
assert idx in output_metadata.indexes_with_none
assert idx in output_metadata.indexes_with_symint
elif idx in output_metadata.indexes_with_no_grad:
# Deliberately skip over the grad_outs which we know should be
# None because the corresponding fwd_out does not require_grad.