[export][cond] support merging constant ints as unbacked symint (#152742)
@pianpwk points out that this will be helpful to address several data-dependent issues in Hugging Face [models](e23705e557/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py (L332)) with the following pattern:
```python
idx = 0 if u0 else 1
return x[idx]
```
We can preserve the conditional by expressing it as a `torch.cond` whose branches return the constant ints.
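For illustration, here is a condensed version of the new export test added in this PR (see the test diff below); the module, shapes, and equality check are taken from that test, so treat this as a sketch of the intended usage rather than additional documentation:

```python
import torch
from torch.export import export


class M(torch.nn.Module):
    def forward(self, x):
        # Branches return different constant ints; the result is merged into an
        # unbacked SymInt (u0) constrained to [0, 1] in the exported graph.
        idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())
        return x[idx]


args = (torch.randn(3, 3),)
ep = export(M(), args)
print(ep.module().print_readable(print_output=False))
assert torch.equal(M()(*args), ep.module()(*args))
```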
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152742
Approved by: https://github.com/zou3519
```
@@ -3134,14 +3134,14 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
        )

        pred = torch.tensor(True)
        for pytree_in in [(1,), ("string",), (1.0,)]:
        for pytree_in in [("string",), (1.0,)]:
            with self.assertRaisesRegex(
                RuntimeError,
                r"Expect operands to be a tuple of possibly nested dict/list/tuple",
            ):
                fn(pred, pytree_in)

        for pytree_in in [(1,), ("string",), (1.0,)]:
        for pytree_in in [("string",), (1.0,)]:
            with self.assertRaisesRegex(
                torch._dynamo.exc.UncapturedHigherOrderOpError,
                r"Cond doesn't work unless it is captured completely with torch.compile",
```
```
@@ -1355,6 +1355,98 @@ graph():
        M()(torch.randn(7))
        torch.export.export(M(), (torch.randn(7),), strict=strict)

    def test_cond_branches_return_constant_int(self):
        class M(torch.nn.Module):
            def forward(self, x):
                idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())
                return x[idx]

        args = (torch.randn(3, 3),)
        m = M()
        ep = export(M(), args)
        if self._testMethodName == "test_cond_branches_return_constant_int":
            self.assertExpectedInline(
                normalize_gm(ep.module().print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[3, 3]";

        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        sum_1: "f32[]" = torch.ops.aten.sum.default(x)
        gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None

        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, ()); gt = true_graph_0 = false_graph_0 = None

        getitem_1: "Sym(u0)" = cond[0]; cond = None

        ge_1: "Sym(u0 >= 0)" = getitem_1 >= 0
        _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
        le_1: "Sym(u0 <= 1)" = getitem_1 <= 1
        _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 1 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None

        select: "f32[3]" = torch.ops.aten.select.int(x, 0, getitem_1); x = getitem_1 = None
        return pytree.tree_unflatten((select,), self._out_spec)

    class true_graph_0(torch.nn.Module):
        def forward(self):
            return (0,)

    class false_graph_0(torch.nn.Module):
        def forward(self):
            return (1,)
""", # noqa: B950
            )
        self.assertEqual(m(*args), ep.module()(*args))

    def test_cond_branches_return_same_int(self):
        class M(torch.nn.Module):
            def forward(self, x):
                idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 0, tuple())
                return x[idx]

        args = (torch.randn(3, 3),)
        m = M()
        ep = export(M(), args)
        # Ideally, we could remove the cond at the front end directly,
        # since it's not used anyway. But we can only do this early
        # optimization if all the outputs are the same constants, which
        # would complicate the output check, so we keep the cond in the
        # graph and let downstream passes DCE it.
        if self._testMethodName == "test_cond_branches_return_same_int":
            self.assertExpectedInline(
                normalize_gm(ep.module().print_readable(print_output=False)),
                """\
class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[3, 3]";

        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        sum_1: "f32[]" = torch.ops.aten.sum.default(x)
        gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None

        true_graph_0 = self.true_graph_0
        false_graph_0 = self.false_graph_0
        cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, ()); gt = true_graph_0 = false_graph_0 = None
        getitem = cond[0]; cond = getitem = None

        select: "f32[3]" = torch.ops.aten.select.int(x, 0, 0); x = None
        return pytree.tree_unflatten((select,), self._out_spec)

    class true_graph_0(torch.nn.Module):
        def forward(self):
            return (0,)

    class false_graph_0(torch.nn.Module):
        def forward(self):
            return (0,)
""", # noqa: B950
            )

        self.assertEqual(m(*args), ep.module()(*args))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_cond_contains_unbacked_no_escape(self):
        class M(torch.nn.Module):
```
```
@@ -8328,10 +8328,10 @@ class GraphModule(torch.nn.Module):
        _ = self._check_export_ret_graph_str(model, args, dynamic_shapes)

    @skipIfTorchDynamo(
        "Skip because _merge_tensors is not intended for dynamo to compile"
        "Skip because _merge_output is not intended for dynamo to compile"
    )
    def test_merge_tensors(self):
        from torch._higher_order_ops.cond import _merge_tensors
    def test_merge_output(self):
        from torch._higher_order_ops.cond import _merge_output
        from torch._subclasses.fake_tensor import FakeTensorMode
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
```
```
@@ -8376,7 +8376,7 @@ class GraphModule(torch.nn.Module):
        with fake_mode:
            t1 = torch.empty_strided(size1, stride1)
            t2 = torch.empty_strided(size2, stride2)
            out = _merge_tensors(t1, t2, fake_mode)
            out = _merge_output(t1, t2, fake_mode)
        self.assertEqual(str(tuple(out.size())), merged_size)
        self.assertEqual(str(tuple(out.stride())), merged_stride)
```
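For concreteness, a hedged sketch of the tensor path this test exercises; the shapes and strides are illustrative values of my own rather than the test's parametrization, and the printed result assumes `_merge_output` keeps the behavior shown in this diff:

```python
import torch
from torch._higher_order_ops.cond import _merge_output
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with fake_mode:
    t1 = torch.empty_strided((3, 4), (4, 1))
    t2 = torch.empty_strided((5, 4), (4, 1))
    # dim 0 differs -> replaced by an unbacked symint; dim 1 matches -> kept as 4
    out = _merge_output(t1, t2, fake_mode)
print(tuple(out.size()))  # e.g. (u0, 4)
```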
```
@@ -1067,10 +1067,17 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
            supports_aliasing=self.supports_aliasing,
        )

        if not only_consist_of(ret_val, (TensorVariable,)):
        if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)):
            unimplemented(
                "Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non tensors.",
                "Expected branches to return a possibly nested pytree of tensors "
                "or constant ints but it consists of others.",
            )
        for ret in ret_val.unpack_var_sequence(tx):
            if isinstance(ret, ConstantVariable) and ret.python_type() is not int:
                unimplemented(
                    "Expected branches to return a possibly nested pytree of tensors "
                    f"or constant ints but it consists of others {ret.python_type()}.",
                )
        return ret_val, ret_treespec, ret_graph, ret_lifted_freevars

    (true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch(
```
```
@@ -1641,11 +1641,13 @@ class GraphModuleDeserializer(metaclass=Final):
            self.module,
            self.serialized_name_to_node,
            self.serialized_name_to_meta,
            self.unbacked_symbols
        )
        self.graph = torch.fx.Graph()
        self.module = torch.nn.Module()
        self.serialized_name_to_node = {}
        self.serialized_name_to_meta = {}
        self.unbacked_symbols: set[sympy.Symbol] = set()
        try:
            yield
        finally:
```
```
@@ -1654,6 +1656,7 @@ class GraphModuleDeserializer(metaclass=Final):
                self.module,
                self.serialized_name_to_node,
                self.serialized_name_to_meta,
                self.unbacked_symbols
            ) = saved

    def deserialize_extension_operator(self, serialized_target: str):
```
```
@@ -2184,7 +2187,7 @@ class GraphModuleDeserializer(metaclass=Final):
        self.symbol_name_to_range = {}
        # we also need to bump unbacked sym[float,int] counters in the
        # shape env to accommodate unbacked symbols in the exported program
        self.unbacked_symbols: set[sympy.Symbol] = set()
        self.unbacked_symbols = set()
        count_unbacked_symfloat, count_unbacked_symint = -1, -1
        unbacked_symfloat_prefix, unbacked_symint_prefix = (
            prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT]
```
```
@@ -2422,27 +2425,34 @@ class GraphModuleDeserializer(metaclass=Final):
        # Check single value return
        if len(serialized_node.outputs) == 0:
            return

        if (
            len(serialized_node.outputs) == 1
            and serialized_node.outputs[0].type == "as_tensor"
            and "torch.ops.higher_order" in serialized_node.target
            and not getattr(serialized_node, "is_hop_single_tensor_return", True)
        ):
        # If it is a HOP node and it returns a tuple containing a single element,
        # we manually insert a getitem node to ensure the graph is consistent.
        # For BC, getattr() will return True if `is_hop_single_tensor_return` doesn't exist,
        # as prior to adding this field, it is guaranteed to have a single tensor return
        # when the serialized_node has length=1 outputs and of type `as_tensor`.
        if (
            "torch.ops.higher_order" in serialized_node.target
            and not getattr(serialized_node, "is_hop_single_tensor_return", True)
        ):
            def _deserialize_hop_with_single_return(serialized_node, fx_node):
                meta_val: list[Any] = []
                arg = serialized_node.outputs[0].as_tensor
                arg = None
                if serialized_node.outputs[0].type == "as_tensor":
                    arg = serialized_node.outputs[0].as_tensor
                elif isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument, SymFloatArgument)):
                    arg = serialized_node.outputs[0].value
                deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
                assert arg is not None
                self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
                fx_node.meta["val"] = tuple(meta_val)
                self.serialized_name_to_node[fx_node.name] = fx_node
                return

            return _deserialize_hop_with_single_return(serialized_node, fx_node)

        if (
            len(serialized_node.outputs) == 1
            and serialized_node.outputs[0].type == "as_tensor"
        ):

            self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
            return
        elif len(serialized_node.outputs) == 1 and isinstance(
```
```
@@ -42,6 +42,7 @@ __all__ = [
    "while_loop",
    "invoke_subgraph",
    "scan",
    "map",
    "flex_attention",
    "flex_attention_backward",
    "hints_wrapper",
```
```
@@ -99,7 +99,9 @@ def cond(
        false_fn (Callable): A callable function (a -> b) that is within the
            scope that is being traced. The true branch and false branch must
            have consistent input and outputs, meaning the inputs have to be
            the same, and the outputs have to be the same type and shape.
            the same, and the outputs have to be the same type and shape. Int
            output is also allowed. We'll make the output dynamic by turning it
            into a symint.

        operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the
            true/false functions. It can be empty if true_fn/false_fn doesn't require input. Defaults to ().
```
```
@@ -429,7 +431,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):

    merged_outs = []
    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        merged_outs.append(_merge_tensors(true_out, false_out, mode))
        merged_outs.append(_merge_output(true_out, false_out, mode))
    return pytree.tree_unflatten(merged_outs, true_out_spec)
```
```
@@ -451,8 +453,10 @@ def check_tensor_meta_match(
    )


def _merge_tensors(
    a: Optional[torch.Tensor], b: Optional[torch.Tensor], mode: FakeTensorMode
def _merge_output(
    a: Optional[Union[torch.Tensor, int]],
    b: Optional[Union[torch.Tensor, int]],
    mode: FakeTensorMode,
):
    from torch.fx.experimental.symbolic_shapes import (
        has_free_unbacked_symbols,
```
```
@@ -463,6 +467,28 @@ def _merge_tensors(
        assert a is None and b is None, (a, b)
        return None

    def min_max(s0, s1):
        def _bound(s0, lower_bound: bool):
            if isinstance(s0, int):
                return s0
            r = mode.shape_env.var_to_range.get( # type: ignore[union-attr]
                s0.node.expr,
                torch.utils._sympy.value_ranges.ValueRanges.unknown(),
            )
            return r.lower if lower_bound else r.upper

        return min(_bound(s0, True), _bound(s1, True)), max(
            _bound(s0, False), _bound(s1, False)
        )

    if type(a) is int and type(b) is int:
        if a == b:
            return a
        assert mode.shape_env is not None
        merged_out = mode.shape_env.create_unbacked_symint()
        mode.shape_env.constrain_symbol_range(merged_out.node.expr, *min_max(a, b))
        return merged_out

    assert type(a) is FakeTensor and type(b) is FakeTensor, (a, type(a), b, type(b))

    # Note: we don't check size, stride because
```
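A minimal sketch of the new constant-int path added above, assuming `_merge_output` keeps the signature shown in this diff: equal ints pass through unchanged, while differing ints are replaced by a fresh unbacked symint whose range covers both values.

```python
import torch
from torch._higher_order_ops.cond import _merge_output
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())
merged = _merge_output(0, 1, mode)  # differing ints -> unbacked symint with range [0, 1]
print(merged)                       # e.g. u0
assert _merge_output(2, 2, mode) == 2  # equal ints are returned unchanged
```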
```
@@ -517,21 +543,6 @@ def _merge_tensors(
        ):
            merged_size.append(s0)
        else:

            def min_max(s0, s1):
                def _bound(s0, lower_bound: bool):
                    if isinstance(s0, int):
                        return s0
                    r = mode.shape_env.var_to_range.get( # type: ignore[union-attr]
                        s0.node.expr,
                        torch.utils._sympy.value_ranges.ValueRanges.unknown(),
                    )
                    return r.lower if lower_bound else r.upper

                return min(_bound(s0, True), _bound(s1, True)), max(
                    _bound(s0, False), _bound(s1, False)
                )

            assert mode.shape_env is not None
            new_size = mode.shape_env.create_unbacked_symint()
            mode.shape_env.constrain_symbol_range(new_size.node.expr, *min_max(s0, s1))
```
```
@@ -792,6 +792,9 @@ def check_input_alias_and_mutation_return_ouputs(
    # has a persistent fake mode but fake tensors can be created
    # outside of the tracing context (e.g. in testing).
    # Instead, we just look at fake_args fake tensor mode
    if len(fake_args) == 0:
        return torch.fx.experimental.symbolic_shapes.ShapeEnv()

    prev_fake_mode = None
    for arg in fake_args:
        if isinstance(arg, torch.Tensor):
```
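The comment above describes the lookup strategy: take the fake tensor mode from the fake arguments and fall back to a fresh `ShapeEnv` when there are none. A minimal hedged sketch of that idea follows; the helper name is illustrative and not part of this diff.

```python
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import ShapeEnv


def shape_env_from_fake_args(fake_args):
    # Prefer the FakeTensorMode attached to a fake tensor argument;
    # fall back to a fresh ShapeEnv when there are no tensor args.
    for arg in fake_args:
        if isinstance(arg, FakeTensor):
            return arg.fake_mode.shape_env
    return ShapeEnv()
```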