Allow unbacked to unbacked replacements if rhs unbacked symbols are all inputs (#163652)

This partially solves the issue https://github.com/pytorch/pytorch/issues/163641. We do not need to ban unbacked-to-unbacked replacements if all rhs symbols are inputs, since we know those symbols are seen by the whole program.

This issue was found while I was tracing some vLLM models with unbacked symbols, namely Qwen/Qwen2-1.5B-Instruct; doing those replacements makes the reasoning logic easier.

As for the similar data-dependent pattern, I am thinking of creating a set of replacements that we apply only during static evaluation,
instead of applying none, to make reasoning better.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163652
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Laith Sakka
2025-09-25 14:29:02 -07:00
committed by PyTorch MergeBot
parent 2a45f30ae7
commit b42e81def5
6 changed files with 138 additions and 11 deletions

View File

@ -468,7 +468,7 @@ inline Tensor _sum_to(
// if we assume no reduction due to unbacked we ensure that at runtime.
TORCH_MAYBE_SYM_CHECK(
sym_eq(shape[i - leading_dims], sizes[i]),
"non-reduction path was assumed due to unabcked symbols expected those two sizes to be the same:",
"non-reduction path was assumed due to unbacked symbols expected those two sizes to be the same:",
shape[i - leading_dims],
", ",
sizes[i])

View File

@ -8005,7 +8005,10 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
torch._dynamo.decorators.mark_unbacked(b, 1)
func(a, b)
func(torch.rand(4, 5), torch.rand(4, 5))
with self.assertRaises(RuntimeError):
# This does not raise an error right now because of a recompilation.
# https://github.com/pytorch/pytorch/issues/163785
# with self.assertRaises(AssertionError):
# func(torch.rand(1, 1), torch.rand(2, 1))
func(torch.rand(1, 1), torch.rand(2, 1))
@torch._dynamo.config.patch(capture_scalar_outputs=True)

View File

@ -3318,9 +3318,108 @@ class TestUnbacked(TestCase):
torch._dynamo.decorators.mark_unbacked(b, 0)
func(a, b)
with self.assertRaises(RuntimeError):
# inductor adds the check sometimes itself so it will be reflected
# as AssertionError.
with self.assertRaises((AssertionError, RuntimeError)):
func(a, torch.rand(2, 1))
@skipIfTorchDynamo("mark_unbacked is not traceable")
def test_do_not_guard_unbacked_inputs(self):
# Compiles a function whose inputs have both dimensions marked unbacked and
# asserts that the captured guards log contains no "SYMBOLIC_SHAPE_GUARD".
@torch.compile(fullgraph=True, dynamic=True, backend="inductor")
def func(a, b):
# The expand result is discarded; the call is only there so that the
# shapes of `a` and `b` are related during tracing.
a.expand(b.shape)
return a * 10
a = torch.rand(1, 1)
b = torch.rand(1, 1)
# Mark every dimension of both inputs as unbacked.
torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(a, 1)
torch._dynamo.decorators.mark_unbacked(b, 0)
torch._dynamo.decorators.mark_unbacked(b, 1)
# Capture the "guards" artifact of the torch._dynamo.guards logger as text.
log_stream, ctx = logs_to_string("torch._dynamo.guards", "guards")
with ctx():
func(a, b)
# Call again with differently shaped (4, 5) inputs.
func(torch.rand(4, 5), torch.rand(4, 5))
# Drop the first four lines of the captured log before inspecting it
# (presumably header lines -- TODO confirm against the log format).
guards = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertFalse("SYMBOLIC_SHAPE_GUARD" in guards)
@skipIfTorchDynamo("mark_unbacked is not traceable")
def test_div_unabacked_eq_input_tensors(self):
# Smoke test: there is no explicit assertion; the test passes if the
# fullgraph compile and the call at the end complete without raising.
@torch.compile(fullgraph=True)
def func(a, b):
# x and y are the leading sizes of the two tensor inputs; dim 0 of both
# tensors is marked unbacked below.
x = a.size()[0]
y = b.size()[0]
# State that the two sizes are equal before branching on them.
torch._check(x == y)
# Both conditions are floor divisions of x by y; presumably they can be
# decided at trace time once x == y is known -- TODO confirm.
if x // y == 1:
a = a * 10
if 2 * x // y == 2:
a = a * 20
return a
a = torch.randn(10, 10)
b = torch.randn(10, 20)
torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(b, 0)
func(a, b)
# The names inside the unbacked_sources string match the parameter names of
# `func` below (presumably this makes the int arguments x and y unbacked --
# TODO confirm against the config documentation).
@torch.compiler.config.patch(unbacked_sources="L['x'],L['y']")
def test_div_unabacked_eq_input_ints(self):
# Smoke test: there is no explicit assertion; the test passes if the
# fullgraph compile and the call at the end complete without raising.
@torch.compile(fullgraph=True)
def func(x, y):
a = torch.rand(1)
# State that the two integer inputs are equal before branching on them.
torch._check(x == y)
if x // y == 1:
a = a * 10
if 2 * x // y == 2:
a = a * 20
return a
func(10, 10)
@skipIfTorchDynamo("mark_unbacked is not traceable")
# The name inside the unbacked_sources string matches the variable `y` defined
# in the test body (presumably this makes it unbacked -- TODO confirm).
@torch.compiler.config.patch(unbacked_sources="L['y']")
def test_div_unabacked_eq_globals(self):
# Same floor-division pattern as the other test_div_unabacked_eq_* tests,
# but `func` takes no arguments: `tensor` and `y` are defined here, outside
# of `func`. There is no explicit assertion; the test passes if the call at
# the end completes without raising.
tensor = torch.rand(10, 44)
y = 10
@torch.compile(fullgraph=True, dynamic=True)
def func():
a = torch.rand(1)
# x is the leading size of `tensor`, whose dim 0 is marked unbacked below.
x = tensor.size()[0]
torch._check(x == y)
if x // y == 1:
a = a * 10
if 2 * x // y == 2:
a = a * 20
return a
torch._dynamo.decorators.mark_unbacked(tensor, 0)
func()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_div_unabacked_eq_item(self):
# Variant of the floor-division pattern where x and y come from .item()
# calls on the input tensors rather than from input sizes. There is no
# explicit assertion; the test passes if the call at the end completes
# without raising.
@torch.compile(fullgraph=True)
def func(a, b):
x = a.item()
y = b.item()
torch._check(x == y)
# TODO we should not need those torch checks.
# The two checks below state the same conditions that the `if`
# statements test afterwards.
torch._check(x // y == 1)
torch._check(2 * x // y == 2)
if x // y == 1:
a = a * 10
if 2 * x // y == 2:
a = a * 20
return a
a = torch.tensor([1])
b = torch.tensor([1])
func(a, b)
class TestUbackedOps(TestCase):
@fresh_cache()

View File

@ -41,7 +41,9 @@ def process_inputs(
return x
source = ConstantSource(f"sym_{idx}")
return shape_env.create_symintnode(
shape_env.create_symbol(x, source), hint=x, source=source
shape_env.create_symbol(x, source),
hint=x,
source=source,
)
if isinstance(x, torch.ScriptObject):
return torch._library.fake_class_registry.maybe_to_fake_obj(

View File

@ -1858,7 +1858,7 @@ class GraphLowering(torch.fx.Interpreter):
shape_env = V.graph.sizevars.shape_env
# An input can be unbacked symint i.e.: when mark_unabcked is used.
# An input can be unbacked symint i.e.: when mark_unbacked is used.
# in that case add it to new_unbacked_defs.
if (
n.op == "placeholder"

View File

@ -967,6 +967,16 @@ def free_unbacked_symbols(x: IterateExprs) -> OrderedSet[sympy.Symbol]:
)
def _free_non_source_unbacked_symbols(
    x: IterateExprs, unbacked_inputs: OrderedSet[sympy.Symbol]
) -> OrderedSet[sympy.Symbol]:
    """Return the free unbacked symbols of ``x`` that are not graph inputs.

    ``unbacked_inputs`` is the set of unbacked symbols known to be inputs;
    everything in ``free_unbacked_symbols(x)`` outside that set is returned.
    The remaining symbols are the ones that originated from data-dependent
    operations rather than from mark_unbacked calls.
    """
    return free_unbacked_symbols(x) - unbacked_inputs
# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
# setup!
def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]:
@ -3713,6 +3723,8 @@ class ShapeEnv:
self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
self.source_name_to_debug_name: dict[str, str] = {}
self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
# A set of unbacked symbols that are inputs (i.e: not data dependent).
self.unbacked_inputs: OrderedSet[sympy.Symbol] = OrderedSet()
self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
self.var_to_hint_override: dict[sympy.Symbol, int] = {}
# Maps a source to the *original* symbol that was assigned to it
@ -4853,7 +4865,6 @@ class ShapeEnv:
self._log_create_unbacked_symbol(
"create_unbacked_symint", symbol, vr, source, sym_node=sym_node
)
return SymInt(sym_node)
def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
@ -4971,6 +4982,9 @@ class ShapeEnv:
if dynamic_dim in (DimDynamic.SIZE_LIKE_UNBACKED, DimDynamic.OBLIVIOUS_SIZE):
out = self.create_unbacked_symint(source).node.expr
self._constrain_range_for_size(out)
self.unbacked_inputs.add(out)
if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][
source_name
@ -5652,6 +5666,7 @@ class ShapeEnv:
# if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
# This does a lot of work: it covers duck sizing and equality guards.
all_exprs: list[list[str]] = [[] for _ in langs]
self.dim_constraints = DimConstraints(
symbol_to_source,
self.var_to_val,
@ -5828,6 +5843,7 @@ class ShapeEnv:
is not None
):
continue
issue_guard(guard)
# Because there are guards that export's constraint solver can suggest good fixes for, that we may have
@ -5839,6 +5855,7 @@ class ShapeEnv:
if self._maybe_evaluate_static(ra.expr, axioms=()) is not None:
continue
expr = self.simplify(ra.expr)
self.dim_constraints.add(expr)
# 3. Every symbol must be within its value range (this handles 0/1
@ -6645,6 +6662,7 @@ class ShapeEnv:
Adds or updates a replacement for a symbol.
Use this instead of `self.replacements[a] = tgt`.
"""
if tgt == self.replacements.get(a, None):
return
@ -6902,7 +6920,10 @@ class ShapeEnv:
):
raise NotImplementedError
# Never replace unbacked symbols with other unbacked symbols.
# Never replace unbacked symbols with other unbacked symbols that are
# not function arguments (e.g., symbols created by mark_unbacked are fine
# as replacements for other unbacked symbols, but those coming from
# .item() calls are not).
# This is error prone because you can cause references to
# unbacked symbols to time travel backwards. E.g.,
#
@ -6918,8 +6939,10 @@ class ShapeEnv:
# dependencies for substitutions, so ban it entirely.
def trivial_solve(lhs: sympy.Expr, rhs: sympy.Expr) -> bool:
if isinstance(lhs, sympy.Symbol):
if free_unbacked_symbols(lhs) and not free_unbacked_symbols(
rhs
if free_unbacked_symbols(
lhs
) and not _free_non_source_unbacked_symbols(
rhs, self.unbacked_inputs
):
return True
if symbol_is_type(lhs, SymT.FLOAT):
@ -7408,7 +7431,6 @@ class ShapeEnv:
forcing_spec: bool = False,
) -> sympy.Basic:
# TODO: split conjunctions and evaluate them separately
if isinstance(
orig_expr,
(sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
@ -7753,6 +7775,7 @@ class ShapeEnv:
expr = canonicalize_bool_expr(expr)
stack = CapturedTraceback.extract(skip=1)
ra = RuntimeAssert(expr, msg, stack)
# TODO: Do this in a way that is less janky than int(s.name[1:])
cands = sorted(
(s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)),