[export][cond] support merging constant ints as unbacked symint (#152742)

@pianpwk points out that this will be helpful to address several data-dependent issues in huggingface [models](e23705e557/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py (L332)) with the following pattern:
```python
idx = 0 if u0 else 1
return x[idx]
```
We can preserve the conditional by expressing it with a `cond`.
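A minimal sketch of what this enables, adapted from the test added in this PR; the branches return different constant ints, and export merges them into a single unbacked SymInt constrained to the range of those values:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # The two branches return different constant ints; with this change
        # they are merged into a single unbacked SymInt (e.g. u0 in [0, 1]).
        idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())
        return x[idx]

ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
```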

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152742
Approved by: https://github.com/zou3519
Author: Yidi Wu
Date: 2025-05-21 11:30:28 -07:00
Committed by: PyTorch MergeBot
Parent: 025c5cc048
Commit: fc859077a0
8 changed files with 163 additions and 39 deletions


@ -3134,14 +3134,14 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
)
pred = torch.tensor(True)
for pytree_in in [(1,), ("string",), (1.0,)]:
for pytree_in in [("string",), (1.0,)]:
with self.assertRaisesRegex(
RuntimeError,
r"Expect operands to be a tuple of possibly nested dict/list/tuple",
):
fn(pred, pytree_in)
for pytree_in in [(1,), ("string",), (1.0,)]:
for pytree_in in [("string",), (1.0,)]:
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
r"Cond doesn't work unless it is captured completely with torch.compile",


@ -1355,6 +1355,98 @@ graph():
M()(torch.randn(7))
torch.export.export(M(), (torch.randn(7),), strict=strict)
def test_cond_branches_return_constant_int(self):
class M(torch.nn.Module):
def forward(self, x):
idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())
return x[idx]
args = (torch.randn(3, 3),)
m = M()
ep = export(M(), args)
if self._testMethodName == "test_cond_branches_return_constant_int":
self.assertExpectedInline(
normalize_gm(ep.module().print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, x):
x: "f32[3, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, ()); gt = true_graph_0 = false_graph_0 = None
getitem_1: "Sym(u0)" = cond[0]; cond = None
ge_1: "Sym(u0 >= 0)" = getitem_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
le_1: "Sym(u0 <= 1)" = getitem_1 <= 1
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 1 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
select: "f32[3]" = torch.ops.aten.select.int(x, 0, getitem_1); x = getitem_1 = None
return pytree.tree_unflatten((select,), self._out_spec)
class true_graph_0(torch.nn.Module):
def forward(self):
return (0,)
class false_graph_0(torch.nn.Module):
def forward(self):
return (1,)
""", # noqa: B950
)
self.assertEqual(m(*args), ep.module()(*args))
def test_cond_branches_return_same_int(self):
class M(torch.nn.Module):
def forward(self, x):
idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 0, tuple())
return x[idx]
args = (torch.randn(3, 3),)
m = M()
ep = export(M(), args)
# Ideally, we could remove the cond at the front end directly
# since its output is not used anyway. But we can only do this early
# optimization if all the outputs are the same constant, which
# would complicate the output check, so we keep the cond in the graph
# and let downstream passes DCE it.
if self._testMethodName == "test_cond_branches_return_same_int":
self.assertExpectedInline(
normalize_gm(ep.module().print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, x):
x: "f32[3, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, ()); gt = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = getitem = None
select: "f32[3]" = torch.ops.aten.select.int(x, 0, 0); x = None
return pytree.tree_unflatten((select,), self._out_spec)
class true_graph_0(torch.nn.Module):
def forward(self):
return (0,)
class false_graph_0(torch.nn.Module):
def forward(self):
return (0,)
""", # noqa: B950
)
self.assertEqual(m(*args), ep.module()(*args))
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_cond_contains_unbacked_no_escape(self):
class M(torch.nn.Module):


@ -8328,10 +8328,10 @@ class GraphModule(torch.nn.Module):
_ = self._check_export_ret_graph_str(model, args, dynamic_shapes)
@skipIfTorchDynamo(
"Skip because _merge_tensors is not intended for dynamo to compile"
"Skip because _merge_output is not intended for dynamo to compile"
)
def test_merge_tensors(self):
from torch._higher_order_ops.cond import _merge_tensors
def test_merge_output(self):
from torch._higher_order_ops.cond import _merge_output
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@ -8376,7 +8376,7 @@ class GraphModule(torch.nn.Module):
with fake_mode:
t1 = torch.empty_strided(size1, stride1)
t2 = torch.empty_strided(size2, stride2)
out = _merge_tensors(t1, t2, fake_mode)
out = _merge_output(t1, t2, fake_mode)
self.assertEqual(str(tuple(out.size())), merged_size)
self.assertEqual(str(tuple(out.stride())), merged_stride)


@ -1067,9 +1067,16 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
supports_aliasing=self.supports_aliasing,
)
if not only_consist_of(ret_val, (TensorVariable,)):
if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)):
unimplemented(
"Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non tensors.",
"Expected branches to return a possibly nested pytree of tensors "
"or constant ints but it consists of others.",
)
for ret in ret_val.unpack_var_sequence(tx):
if isinstance(ret, ConstantVariable) and ret.python_type() is not int:
unimplemented(
"Expected branches to return a possibly nested pytree of tensors "
f"or constant ints but it consists of others {ret.python_type()}.",
)
return ret_val, ret_treespec, ret_graph, ret_lifted_freevars


@ -1641,11 +1641,13 @@ class GraphModuleDeserializer(metaclass=Final):
self.module,
self.serialized_name_to_node,
self.serialized_name_to_meta,
self.unbacked_symbols
)
self.graph = torch.fx.Graph()
self.module = torch.nn.Module()
self.serialized_name_to_node = {}
self.serialized_name_to_meta = {}
self.unbacked_symbols: set[sympy.Symbol] = set()
try:
yield
finally:
@ -1654,6 +1656,7 @@ class GraphModuleDeserializer(metaclass=Final):
self.module,
self.serialized_name_to_node,
self.serialized_name_to_meta,
self.unbacked_symbols
) = saved
def deserialize_extension_operator(self, serialized_target: str):
@ -2184,7 +2187,7 @@ class GraphModuleDeserializer(metaclass=Final):
self.symbol_name_to_range = {}
# we also need to bump unbacked sym[float,int] counters in the
# shape env to accommodate unbacked symbols in the exported program
self.unbacked_symbols: set[sympy.Symbol] = set()
self.unbacked_symbols = set()
count_unbacked_symfloat, count_unbacked_symint = -1, -1
unbacked_symfloat_prefix, unbacked_symint_prefix = (
prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT]
@ -2422,27 +2425,34 @@ class GraphModuleDeserializer(metaclass=Final):
# Check single value return
if len(serialized_node.outputs) == 0:
return
if (
len(serialized_node.outputs) == 1
and serialized_node.outputs[0].type == "as_tensor"
):
# If it is a HOP node and it returns a tuple containing a single element
# we manually insert a getitem node to ensure the graph is consistent
# For BC, getattr() will return True if `is_single_tensor_return` doesn't exist
# as prior to adding this field, it is guaranteed to have a single tensor return
# when the serialized_node has length=1 outputs and of type `as_tensor`.
if (
"torch.ops.higher_order" in serialized_node.target
and "torch.ops.higher_order" in serialized_node.target
and not getattr(serialized_node, "is_hop_single_tensor_return", True)
):
def _deserialize_hop_with_single_return(serialized_node, fx_node):
meta_val: list[Any] = []
arg = None
if serialized_node.outputs[0].type == "as_tensor":
arg = serialized_node.outputs[0].as_tensor
elif isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument, SymFloatArgument)):
arg = serialized_node.outputs[0].value
deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
assert arg is not None
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
fx_node.meta["val"] = tuple(meta_val)
self.serialized_name_to_node[fx_node.name] = fx_node
return
return _deserialize_hop_with_single_return(serialized_node, fx_node)
if (
len(serialized_node.outputs) == 1
and serialized_node.outputs[0].type == "as_tensor"
):
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
return
elif len(serialized_node.outputs) == 1 and isinstance(


@ -42,6 +42,7 @@ __all__ = [
"while_loop",
"invoke_subgraph",
"scan",
"map",
"flex_attention",
"flex_attention_backward",
"hints_wrapper",


@ -99,7 +99,9 @@ def cond(
false_fn (Callable): A callable function (a -> b) that is within the
scope that is being traced. The true branch and false branch must
have consistent input and outputs, meaning the inputs have to be
the same, and the outputs have to be the same type and shape.
the same, and the outputs have to be the same type and shape. Int
output is also allowed. We'll make the output dynamic by turning it
into a symint.
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the
true/false functions. It can be empty if true_fn/false_fn doesn't require input. Defaults to ().
@ -429,7 +431,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
merged_outs = []
for true_out, false_out in zip(flat_true_outs, flat_false_outs):
merged_outs.append(_merge_tensors(true_out, false_out, mode))
merged_outs.append(_merge_output(true_out, false_out, mode))
return pytree.tree_unflatten(merged_outs, true_out_spec)
@ -451,8 +453,10 @@ def check_tensor_meta_match(
)
def _merge_tensors(
a: Optional[torch.Tensor], b: Optional[torch.Tensor], mode: FakeTensorMode
def _merge_output(
a: Optional[Union[torch.Tensor, int]],
b: Optional[Union[torch.Tensor, int]],
mode: FakeTensorMode,
):
from torch.fx.experimental.symbolic_shapes import (
has_free_unbacked_symbols,
@ -463,6 +467,28 @@ def _merge_tensors(
assert a is None and b is None, (a, b)
return None
def min_max(s0, s1):
def _bound(s0, lower_bound: bool):
if isinstance(s0, int):
return s0
r = mode.shape_env.var_to_range.get( # type: ignore[union-attr]
s0.node.expr,
torch.utils._sympy.value_ranges.ValueRanges.unknown(),
)
return r.lower if lower_bound else r.upper
return min(_bound(s0, True), _bound(s1, True)), max(
_bound(s0, False), _bound(s1, False)
)
if type(a) is int and type(b) is int:
if a == b:
return a
assert mode.shape_env is not None
merged_out = mode.shape_env.create_unbacked_symint()
mode.shape_env.constrain_symbol_range(merged_out.node.expr, *min_max(a, b))
return merged_out
assert type(a) is FakeTensor and type(b) is FakeTensor, (a, type(a), b, type(b))
# Note: we don't check size, stride because
@ -517,21 +543,6 @@ def _merge_tensors(
):
merged_size.append(s0)
else:
def min_max(s0, s1):
def _bound(s0, lower_bound: bool):
if isinstance(s0, int):
return s0
r = mode.shape_env.var_to_range.get( # type: ignore[union-attr]
s0.node.expr,
torch.utils._sympy.value_ranges.ValueRanges.unknown(),
)
return r.lower if lower_bound else r.upper
return min(_bound(s0, True), _bound(s1, True)), max(
_bound(s0, False), _bound(s1, False)
)
assert mode.shape_env is not None
new_size = mode.shape_env.create_unbacked_symint()
mode.shape_env.constrain_symbol_range(new_size.node.expr, *min_max(s0, s1))
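For reference, a rough standalone sketch of the constant-int merging behavior added above (it mirrors the int path of `_merge_output` rather than calling it, and assumes a `FakeTensorMode` backed by a `ShapeEnv`):

```python
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())

def merge_ints(a: int, b: int):
    # Equal constants stay constant; differing constants become an
    # unbacked SymInt whose range covers both values, as in the diff above.
    if a == b:
        return a
    out = mode.shape_env.create_unbacked_symint()
    mode.shape_env.constrain_symbol_range(out.node.expr, min(a, b), max(a, b))
    return out

print(merge_ints(0, 0))  # 0
print(merge_ints(0, 1))  # an unbacked SymInt, e.g. u0, constrained to [0, 1]
```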


@ -792,6 +792,9 @@ def check_input_alias_and_mutation_return_ouputs(
# has a persistent fake mode but fake tensors can be created
# outside of the tracing context (e.g. in testing).
# Instead, we just look at fake_args fake tensor mode
if len(fake_args) == 0:
return torch.fx.experimental.symbolic_shapes.ShapeEnv()
prev_fake_mode = None
for arg in fake_args:
if isinstance(arg, torch.Tensor):