mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Allow unbacked to unbacked replacements if rhs unbacked symbols are all inputs (#163652)
This partially solve the issue https://github.com/pytorch/pytorch/issues/163641. We do not need to ban unbacked to unbacked replacement if all rhs symbols are inputs since we know those symbols are seen by the whole program. This issue was found as i was tracing some vllm models with unbacked, namely Qwen/Qwen2-1.5B-Instruct it makes reasoning logic easier to do those replacements. as for data dependent similar pattern, I am thinking to create a set of replacements that we apply only during static eval instead of none. to make reasoning better. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163652 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
2a45f30ae7
commit
b42e81def5
@ -468,7 +468,7 @@ inline Tensor _sum_to(
|
||||
// if we assume no reduction due to unbacked we ensure that at runtime.
|
||||
TORCH_MAYBE_SYM_CHECK(
|
||||
sym_eq(shape[i - leading_dims], sizes[i]),
|
||||
"non-reduction path was assumed due to unabcked symbols expected those two sizes to be the same:",
|
||||
"non-reduction path was assumed due to unbacked symbols expected those two sizes to be the same:",
|
||||
shape[i - leading_dims],
|
||||
", ",
|
||||
sizes[i])
|
||||
|
@ -8005,8 +8005,11 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
||||
torch._dynamo.decorators.mark_unbacked(b, 1)
|
||||
func(a, b)
|
||||
func(torch.rand(4, 5), torch.rand(4, 5))
|
||||
with self.assertRaises(RuntimeError):
|
||||
func(torch.rand(1, 1), torch.rand(2, 1))
|
||||
# This does not raise an error right now because of a recompilation.
|
||||
# https://github.com/pytorch/pytorch/issues/163785
|
||||
# with self.assertRaises(AssertionError):
|
||||
# func(torch.rand(1, 1), torch.rand(2, 1))
|
||||
func(torch.rand(1, 1), torch.rand(2, 1))
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_sym_constrain_range_on_replaced_unbacked_symbol(self):
|
||||
|
@ -3318,9 +3318,108 @@ class TestUnbacked(TestCase):
|
||||
torch._dynamo.decorators.mark_unbacked(b, 0)
|
||||
func(a, b)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
# inductor adds the check sometimes itself so it will be reflected
|
||||
# as AssertionError.
|
||||
with self.assertRaises((AssertionError, RuntimeError)):
|
||||
func(a, torch.rand(2, 1))
|
||||
|
||||
@skipIfTorchDynamo("mark_unbacked is not traceable")
|
||||
def test_do_not_guard_unbacked_inputs(self):
|
||||
@torch.compile(fullgraph=True, dynamic=True, backend="inductor")
|
||||
def func(a, b):
|
||||
a.expand(b.shape)
|
||||
return a * 10
|
||||
|
||||
a = torch.rand(1, 1)
|
||||
b = torch.rand(1, 1)
|
||||
|
||||
torch._dynamo.decorators.mark_unbacked(a, 0)
|
||||
torch._dynamo.decorators.mark_unbacked(a, 1)
|
||||
torch._dynamo.decorators.mark_unbacked(b, 0)
|
||||
torch._dynamo.decorators.mark_unbacked(b, 1)
|
||||
|
||||
log_stream, ctx = logs_to_string("torch._dynamo.guards", "guards")
|
||||
with ctx():
|
||||
func(a, b)
|
||||
func(torch.rand(4, 5), torch.rand(4, 5))
|
||||
|
||||
guards = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
||||
self.assertFalse("SYMBOLIC_SHAPE_GUARD" in guards)
|
||||
|
||||
@skipIfTorchDynamo("mark_unbacked is not traceable")
|
||||
def test_div_unabacked_eq_input_tensors(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
def func(a, b):
|
||||
x = a.size()[0]
|
||||
y = b.size()[0]
|
||||
torch._check(x == y)
|
||||
if x // y == 1:
|
||||
a = a * 10
|
||||
if 2 * x // y == 2:
|
||||
a = a * 20
|
||||
return a
|
||||
|
||||
a = torch.randn(10, 10)
|
||||
b = torch.randn(10, 20)
|
||||
|
||||
torch._dynamo.decorators.mark_unbacked(a, 0)
|
||||
torch._dynamo.decorators.mark_unbacked(b, 0)
|
||||
func(a, b)
|
||||
|
||||
@torch.compiler.config.patch(unbacked_sources="L['x'],L['y']")
|
||||
def test_div_unabacked_eq_input_ints(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
def func(x, y):
|
||||
a = torch.rand(1)
|
||||
torch._check(x == y)
|
||||
if x // y == 1:
|
||||
a = a * 10
|
||||
if 2 * x // y == 2:
|
||||
a = a * 20
|
||||
return a
|
||||
|
||||
func(10, 10)
|
||||
|
||||
@skipIfTorchDynamo("mark_unbacked is not traceable")
|
||||
@torch.compiler.config.patch(unbacked_sources="L['y']")
|
||||
def test_div_unabacked_eq_globals(self):
|
||||
tensor = torch.rand(10, 44)
|
||||
y = 10
|
||||
|
||||
@torch.compile(fullgraph=True, dynamic=True)
|
||||
def func():
|
||||
a = torch.rand(1)
|
||||
x = tensor.size()[0]
|
||||
torch._check(x == y)
|
||||
if x // y == 1:
|
||||
a = a * 10
|
||||
if 2 * x // y == 2:
|
||||
a = a * 20
|
||||
return a
|
||||
|
||||
torch._dynamo.decorators.mark_unbacked(tensor, 0)
|
||||
func()
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_div_unabacked_eq_item(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
def func(a, b):
|
||||
x = a.item()
|
||||
y = b.item()
|
||||
torch._check(x == y)
|
||||
# TODO we should not need those torch checks.
|
||||
torch._check(x // y == 1)
|
||||
torch._check(2 * x // y == 2)
|
||||
if x // y == 1:
|
||||
a = a * 10
|
||||
if 2 * x // y == 2:
|
||||
a = a * 20
|
||||
return a
|
||||
|
||||
a = torch.tensor([1])
|
||||
b = torch.tensor([1])
|
||||
func(a, b)
|
||||
|
||||
|
||||
class TestUbackedOps(TestCase):
|
||||
@fresh_cache()
|
||||
|
@ -41,7 +41,9 @@ def process_inputs(
|
||||
return x
|
||||
source = ConstantSource(f"sym_{idx}")
|
||||
return shape_env.create_symintnode(
|
||||
shape_env.create_symbol(x, source), hint=x, source=source
|
||||
shape_env.create_symbol(x, source),
|
||||
hint=x,
|
||||
source=source,
|
||||
)
|
||||
if isinstance(x, torch.ScriptObject):
|
||||
return torch._library.fake_class_registry.maybe_to_fake_obj(
|
||||
|
@ -1858,7 +1858,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
shape_env = V.graph.sizevars.shape_env
|
||||
|
||||
# An input can be unbacked symint i.e.: when mark_unabcked is used.
|
||||
# An input can be unbacked symint i.e.: when mark_unbacked is used.
|
||||
# in that case add it to new_unbacked_defs.
|
||||
if (
|
||||
n.op == "placeholder"
|
||||
|
@ -967,6 +967,16 @@ def free_unbacked_symbols(x: IterateExprs) -> OrderedSet[sympy.Symbol]:
|
||||
)
|
||||
|
||||
|
||||
def _free_non_source_unbacked_symbols(
|
||||
x: IterateExprs, unbacked_inputs: OrderedSet[sympy.Symbol]
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
"""Unbacked symbols that are not inputs to the graph. These are symbols that originated from
|
||||
data-dependent operations as opposed to mark_unbacked calls."""
|
||||
unbacked_symbols = free_unbacked_symbols(x)
|
||||
non_source_symbols = unbacked_symbols - unbacked_inputs
|
||||
return non_source_symbols
|
||||
|
||||
|
||||
# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
|
||||
# setup!
|
||||
def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]:
|
||||
@ -3713,6 +3723,8 @@ class ShapeEnv:
|
||||
self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
|
||||
self.source_name_to_debug_name: dict[str, str] = {}
|
||||
self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
|
||||
# A set of unbacked symbols that are inputs (i.e: not data dependent).
|
||||
self.unbacked_inputs: OrderedSet[sympy.Symbol] = OrderedSet()
|
||||
self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
|
||||
self.var_to_hint_override: dict[sympy.Symbol, int] = {}
|
||||
# Maps a source to the *original* symbol that was assigned to it
|
||||
@ -4853,7 +4865,6 @@ class ShapeEnv:
|
||||
self._log_create_unbacked_symbol(
|
||||
"create_unbacked_symint", symbol, vr, source, sym_node=sym_node
|
||||
)
|
||||
|
||||
return SymInt(sym_node)
|
||||
|
||||
def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
|
||||
@ -4971,6 +4982,9 @@ class ShapeEnv:
|
||||
if dynamic_dim in (DimDynamic.SIZE_LIKE_UNBACKED, DimDynamic.OBLIVIOUS_SIZE):
|
||||
out = self.create_unbacked_symint(source).node.expr
|
||||
self._constrain_range_for_size(out)
|
||||
|
||||
self.unbacked_inputs.add(out)
|
||||
|
||||
if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
|
||||
symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][
|
||||
source_name
|
||||
@ -5652,6 +5666,7 @@ class ShapeEnv:
|
||||
# if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
|
||||
# This does a lot of work: it covers duck sizing and equality guards.
|
||||
all_exprs: list[list[str]] = [[] for _ in langs]
|
||||
|
||||
self.dim_constraints = DimConstraints(
|
||||
symbol_to_source,
|
||||
self.var_to_val,
|
||||
@ -5828,6 +5843,7 @@ class ShapeEnv:
|
||||
is not None
|
||||
):
|
||||
continue
|
||||
|
||||
issue_guard(guard)
|
||||
|
||||
# Because there are guards that export's constraint solver can suggest good fixes for, that we may have
|
||||
@ -5839,6 +5855,7 @@ class ShapeEnv:
|
||||
if self._maybe_evaluate_static(ra.expr, axioms=()) is not None:
|
||||
continue
|
||||
expr = self.simplify(ra.expr)
|
||||
|
||||
self.dim_constraints.add(expr)
|
||||
|
||||
# 3. Every symbol must be within its value range (this handles 0/1
|
||||
@ -6645,6 +6662,7 @@ class ShapeEnv:
|
||||
Adds or updates a replacement for a symbol.
|
||||
Use this instead of `self.replacements[a] = tgt`.
|
||||
"""
|
||||
|
||||
if tgt == self.replacements.get(a, None):
|
||||
return
|
||||
|
||||
@ -6902,7 +6920,10 @@ class ShapeEnv:
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
# Never replace unbacked symbols with other unbacked symbols.
|
||||
# Never replace unbacked symbols with other unbacked symbols that are
|
||||
# not function arguments. (ex:mark_unbacked symbols are fine to replace
|
||||
# other unbacked, but not those coming from .item() calls).
|
||||
|
||||
# This is error prone because you can cause references to
|
||||
# unbacked symbols to time travel backwards. E.g.,
|
||||
#
|
||||
@ -6918,8 +6939,10 @@ class ShapeEnv:
|
||||
# dependencies for substitutions, so ban it entirely.
|
||||
def trivial_solve(lhs: sympy.Expr, rhs: sympy.Expr) -> bool:
|
||||
if isinstance(lhs, sympy.Symbol):
|
||||
if free_unbacked_symbols(lhs) and not free_unbacked_symbols(
|
||||
rhs
|
||||
if free_unbacked_symbols(
|
||||
lhs
|
||||
) and not _free_non_source_unbacked_symbols(
|
||||
rhs, self.unbacked_inputs
|
||||
):
|
||||
return True
|
||||
if symbol_is_type(lhs, SymT.FLOAT):
|
||||
@ -7408,7 +7431,6 @@ class ShapeEnv:
|
||||
forcing_spec: bool = False,
|
||||
) -> sympy.Basic:
|
||||
# TODO: split conjunctions and evaluate them separately
|
||||
|
||||
if isinstance(
|
||||
orig_expr,
|
||||
(sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
|
||||
@ -7753,6 +7775,7 @@ class ShapeEnv:
|
||||
expr = canonicalize_bool_expr(expr)
|
||||
stack = CapturedTraceback.extract(skip=1)
|
||||
ra = RuntimeAssert(expr, msg, stack)
|
||||
|
||||
# TODO: Do this in a way that is less janky than int(s.name[1:])
|
||||
cands = sorted(
|
||||
(s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)),
|
||||
|
Reference in New Issue
Block a user