Fix issues with generalized_scatter and setitem allocated unbacked symbols. (#164341)

Three fixes:
1. When doing t[u0] +=1  if u0 is unbacked we could allocate a new unbacked symbol during the the indexing of t[u0] (when we fake trace setitem), namely because meta_select does allocate a new unbacked symbol for the storage offset when we do not know if u0>=0 or u0<0.  but the output size/stride of setitem(), does not depend on that new symbol. it's self consumed in setitem so we shall ignore it.

2. Also when we trace through generalized_scatter the applications of the views could allocate unbacked symints
but those do not effect final output, we also shall ignore them.

3.Before accessing strides in lowering we shall materialize.

Address  https://github.com/pytorch/pytorch/issues/114293 and https://github.com/pytorch/pytorch/issues/131911

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164341
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Laith Sakka
2025-10-17 11:00:15 -07:00
committed by PyTorch MergeBot
parent de09bab4b6
commit c6a8db0b9a
4 changed files with 72 additions and 7 deletions

View File

@ -3398,7 +3398,7 @@ class TestUnbacked(TestCase):
self.assertFalse("SYMBOLIC_SHAPE_GUARD" in guards)
@skipIfTorchDynamo("mark_unbacked is not traceable")
def test_div_unabacked_eq_input_tensors(self):
def test_div_unbacked_eq_input_tensors(self):
@torch.compile(fullgraph=True)
def func(a, b):
x = a.size()[0]
@ -3418,7 +3418,7 @@ class TestUnbacked(TestCase):
func(a, b)
@torch.compiler.config.patch(unbacked_sources="L['x'],L['y']")
def test_div_unabacked_eq_input_ints(self):
def test_div_unbacked_eq_input_ints(self):
@torch.compile(fullgraph=True)
def func(x, y):
a = torch.rand(1)
@ -3433,7 +3433,7 @@ class TestUnbacked(TestCase):
@skipIfTorchDynamo("mark_unbacked is not traceable")
@torch.compiler.config.patch(unbacked_sources="L['y']")
def test_div_unabacked_eq_globals(self):
def test_div_unbacked_eq_globals(self):
tensor = torch.rand(10, 44)
y = 10
@ -3452,7 +3452,7 @@ class TestUnbacked(TestCase):
func()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_div_unabacked_eq_item(self):
def test_div_unbacked_eq_item(self):
@torch.compile(fullgraph=True)
def func(a, b):
x = a.item()
@ -4270,6 +4270,37 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
result_compiled = compiled_program()
self.assertEqual(result_original, result_compiled)
def test_unbacked_item_set_item(self):
def my_arithmetic(a, b):
wrk = torch.zeros(a.size(0))
for i in range(a.size(0)):
idx = b[i].item()
wrk[idx] += 1
return wrk
compiled = torch.compile(my_arithmetic, fullgraph=True, disable=False)
a = torch.randn([9])
b = torch.ones(9, dtype=torch.int32)
compiled(a, b)
self.assertEqual(compiled(a, b), my_arithmetic(a, b))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_item_set_item2(self):
def accumulate(X0, start):
start = start.item()
N = 3
result = X0[start]
for i in range(0, N):
result += X0[start + 1 + i]
return result
compiled = torch.compile(accumulate, fullgraph=True)
X0 = torch.randn(10, 10)
self.assertEqual(
accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1]))
)
instantiate_parametrized_tests(TestUnbacked)

View File

@ -23,6 +23,7 @@ import operator
import textwrap
import traceback
import types
from contextlib import nullcontext
from typing import TYPE_CHECKING
import sympy
@ -1109,7 +1110,19 @@ class TensorVariable(VariableTracker):
# value.requires_grad is True => self.has_grad_fn becomes True
# Not sure if __setitem__ can ever save activations, disabling just in case
with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
# Ignore fresh unbacked symbols that could arise from the internal indexing (selection),
# that happen in code like t[idx] += 1 when idx is unbacked. Namely the selection
# during 'setitem'.
# When the selection happens if idx is unbacked we allocate a new unbacked symbol for the
# storage offset in select_meta, but the output of the operation 'setitem' does not depend
# on the selection.
with (
torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(),
tx.fake_mode.shape_env.ignore_fresh_unbacked_symbols()
if tx.fake_mode and tx.fake_mode.shape_env
else nullcontext(),
):
get_fake_value(proxy.node, tx, allow_non_graph_fake=False)
vt = value

View File

@ -4,6 +4,7 @@ import logging
import operator
from collections import defaultdict
from collections.abc import Sequence
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, cast
@ -12,6 +13,7 @@ import torch.fx.node
from torch._C._dynamo.guards import compute_overlapping_tensors
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger
from torch._guards import detect_fake_mode
from torch._higher_order_ops.triton_kernel_wrap import (
kernel_side_table,
triton_kernel_wrapper_functional,
@ -78,6 +80,14 @@ def _inplace_generalized_scatter(
lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
(view.args, view.kwargs),
)
# slice and select can allocate new unbacked symints, but those won't be reflected
# in the output of this function, hence shall be ignored.
fake_mode = detect_fake_mode(fake_args)
with (
fake_mode.shape_env.ignore_fresh_unbacked_symbols()
if fake_mode and fake_mode.shape_env
else nullcontext()
):
tmp = view.target(tmp, *fake_args, **fake_kwargs)
try:
tmp.copy_(src)

View File

@ -1956,6 +1956,9 @@ def select(x, dim, idx):
# Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
# we use as_strided instead.
# Removing this branch will cause test_unbacked_select_index_with_check to fail.
# before accessing size, stride, and offset we need to realize.
x.realize()
new_size = x.get_size()
new_stride = x.get_stride()
new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index
@ -1979,6 +1982,8 @@ def select(x, dim, idx):
assert len(unbacked_bindings) == 1, unbacked_bindings
unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))
# before accessing size, stride, and offset we need to realize.
x.realize()
new_size = x.get_size()
new_stride = x.get_stride()
new_storage_offset = unbacked_offset_sym
@ -3159,8 +3164,14 @@ def select_scatter(x, src, dim: int, index: int):
assert x.get_dtype() == src.get_dtype()
x_loader = x.make_loader()
dim = _validate_dim(x, dim, 0)
if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
if V.graph.sizevars.guard_or_false(sympy.Lt(index, 0)):
index = index + x.get_size()[dim]
elif V.graph.sizevars.guard_or_false(sympy.Ge(index, 0)):
pass
else:
# unbacked index
return fallback_handler(aten.select_scatter.default)(x, src, dim, index)
V.graph.sizevars.check_leq(0, index) # type: ignore[arg-type]
V.graph.sizevars.check_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
src = expand(unsqueeze(src, dim), x.get_size())