mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Fix issues with generalized_scatter and setitem allocated unbacked symbols. (#164341)
Three fixes:

1. When doing t[u0] += 1 where u0 is unbacked, we can allocate a new unbacked symbol during the indexing of t[u0] (when we fake-trace setitem), because meta_select allocates a new unbacked symbol for the storage offset when we do not know whether u0 >= 0 or u0 < 0. But the output size/stride of setitem() does not depend on that new symbol; it is consumed entirely inside setitem, so we shall ignore it.
2. When we trace through generalized_scatter, applying the views can allocate unbacked symints, but those do not affect the final output, so we shall ignore them as well.
3. Before accessing strides in lowering, we shall materialize.

Addresses https://github.com/pytorch/pytorch/issues/114293 and https://github.com/pytorch/pytorch/issues/131911

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164341
Approved by: https://github.com/bobrenjc93
committed by PyTorch MergeBot
parent de09bab4b6
commit c6a8db0b9a
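
To make fix 1 concrete, here is a minimal reproducer in the spirit of the test_unbacked_item_set_item test added in this PR (capture_scalar_outputs is enabled explicitly here in case it is not the default in your build):

    import torch

    torch._dynamo.config.capture_scalar_outputs = True

    @torch.compile(fullgraph=True)
    def my_arithmetic(a, b):
        wrk = torch.zeros(a.size(0))
        for i in range(a.size(0)):
            idx = b[i].item()  # idx becomes an unbacked SymInt during tracing
            wrk[idx] += 1      # setitem: the internal select mints a fresh
                               # unbacked symbol for the storage offset
        return wrk

    a = torch.randn(9)
    b = torch.ones(9, dtype=torch.int32)
    print(my_arithmetic(a, b))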
@@ -3398,7 +3398,7 @@ class TestUnbacked(TestCase):
         self.assertFalse("SYMBOLIC_SHAPE_GUARD" in guards)
 
     @skipIfTorchDynamo("mark_unbacked is not traceable")
-    def test_div_unabacked_eq_input_tensors(self):
+    def test_div_unbacked_eq_input_tensors(self):
         @torch.compile(fullgraph=True)
         def func(a, b):
             x = a.size()[0]
@@ -3418,7 +3418,7 @@ class TestUnbacked(TestCase):
         func(a, b)
 
     @torch.compiler.config.patch(unbacked_sources="L['x'],L['y']")
-    def test_div_unabacked_eq_input_ints(self):
+    def test_div_unbacked_eq_input_ints(self):
         @torch.compile(fullgraph=True)
         def func(x, y):
             a = torch.rand(1)
@@ -3433,7 +3433,7 @@ class TestUnbacked(TestCase):
 
     @skipIfTorchDynamo("mark_unbacked is not traceable")
     @torch.compiler.config.patch(unbacked_sources="L['y']")
-    def test_div_unabacked_eq_globals(self):
+    def test_div_unbacked_eq_globals(self):
         tensor = torch.rand(10, 44)
         y = 10
 
@@ -3452,7 +3452,7 @@ class TestUnbacked(TestCase):
         func()
 
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_div_unabacked_eq_item(self):
+    def test_div_unbacked_eq_item(self):
         @torch.compile(fullgraph=True)
         def func(a, b):
             x = a.item()
@@ -4270,6 +4270,37 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
         result_compiled = compiled_program()
         self.assertEqual(result_original, result_compiled)
 
+    def test_unbacked_item_set_item(self):
+        def my_arithmetic(a, b):
+            wrk = torch.zeros(a.size(0))
+            for i in range(a.size(0)):
+                idx = b[i].item()
+                wrk[idx] += 1
+
+            return wrk
+
+        compiled = torch.compile(my_arithmetic, fullgraph=True, disable=False)
+        a = torch.randn([9])
+        b = torch.ones(9, dtype=torch.int32)
+        compiled(a, b)
+        self.assertEqual(compiled(a, b), my_arithmetic(a, b))
+
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_unbacked_item_set_item2(self):
+        def accumulate(X0, start):
+            start = start.item()
+            N = 3
+            result = X0[start]
+            for i in range(0, N):
+                result += X0[start + 1 + i]
+            return result
+
+        compiled = torch.compile(accumulate, fullgraph=True)
+        X0 = torch.randn(10, 10)
+        self.assertEqual(
+            accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1]))
+        )
+
 
 instantiate_parametrized_tests(TestUnbacked)
 
@@ -23,6 +23,7 @@ import operator
 import textwrap
 import traceback
 import types
+from contextlib import nullcontext
 from typing import TYPE_CHECKING
 
 import sympy
@@ -1109,7 +1110,19 @@ class TensorVariable(VariableTracker):
         # value.requires_grad is True => self.has_grad_fn becomes True
 
         # Not sure if __setitem__ can ever save activations, disabling just in case
-        with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
+
+        # Ignore fresh unbacked symbols that can arise from the internal
+        # indexing (selection) in code like t[idx] += 1 when idx is unbacked,
+        # namely the selection during 'setitem'.
+        # If idx is unbacked, the selection allocates a new unbacked symbol
+        # for the storage offset in select_meta, but the output of 'setitem'
+        # does not depend on the selection.
+        with (
+            torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(),
+            tx.fake_mode.shape_env.ignore_fresh_unbacked_symbols()
+            if tx.fake_mode and tx.fake_mode.shape_env
+            else nullcontext(),
+        ):
             get_fake_value(proxy.node, tx, allow_non_graph_fake=False)
 
         vt = value
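
For readers unfamiliar with the context manager used above: ShapeEnv.ignore_fresh_unbacked_symbols() suppresses the pending-unbacked-symbol bookkeeping for symbols allocated inside the with-block, so they never need to be bound in the traced graph. A minimal sketch of the guarded pattern, where run_ignoring_fresh_unbacked is a hypothetical helper written only for illustration:

    from contextlib import nullcontext

    def run_ignoring_fresh_unbacked(fake_mode, fn, *args, **kwargs):
        # hypothetical helper: run fn, discarding any unbacked symbols it mints
        ctx = (
            fake_mode.shape_env.ignore_fresh_unbacked_symbols()
            if fake_mode and fake_mode.shape_env
            else nullcontext()
        )
        with ctx:
            return fn(*args, **kwargs)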
@@ -4,6 +4,7 @@ import logging
 import operator
 from collections import defaultdict
 from collections.abc import Sequence
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Any, Callable, cast
 
@@ -12,6 +13,7 @@ import torch.fx.node
 from torch._C._dynamo.guards import compute_overlapping_tensors
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger
+from torch._guards import detect_fake_mode
 from torch._higher_order_ops.triton_kernel_wrap import (
     kernel_side_table,
     triton_kernel_wrapper_functional,
@@ -78,7 +80,15 @@ def _inplace_generalized_scatter(
             lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
             (view.args, view.kwargs),
         )
-        tmp = view.target(tmp, *fake_args, **fake_kwargs)
+        # slice and select can allocate new unbacked symints, but those won't be reflected
+        # in the output of this function, hence shall be ignored.
+        fake_mode = detect_fake_mode(fake_args)
+        with (
+            fake_mode.shape_env.ignore_fresh_unbacked_symbols()
+            if fake_mode and fake_mode.shape_env
+            else nullcontext()
+        ):
+            tmp = view.target(tmp, *fake_args, **fake_kwargs)
     try:
         tmp.copy_(src)
     except RuntimeError as e:
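
As context for this hunk: _inplace_generalized_scatter replays a chain of view ops on the destination and then copies src into the resulting view, so any unbacked symbols minted while building the intermediate views never escape into the function's output. Roughly, as a simplified sketch rather than the actual implementation:

    def inplace_generalized_scatter_sketch(inp, src, view_ops):
        # simplified sketch of the reinplacing pattern this hunk guards
        tmp = inp
        for view in view_ops:
            # aten.slice / aten.select with an unbacked index can mint fresh
            # unbacked symbols (e.g. for the storage offset) ...
            tmp = view.target(tmp, *view.args, **view.kwargs)
        tmp.copy_(src)  # write through the view into inp's storage
        return inp      # ... but the result is inp, so they never surface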
@@ -1956,6 +1956,9 @@ def select(x, dim, idx):
         # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
         # we use as_strided instead.
         # Removing this branch will cause test_unbacked_select_index_with_check to fail.
+
+        # before accessing size, stride, and offset we need to realize.
+        x.realize()
         new_size = x.get_size()
         new_stride = x.get_stride()
         new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index
@@ -1979,6 +1982,8 @@ def select(x, dim, idx):
         assert len(unbacked_bindings) == 1, unbacked_bindings
         unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))
 
+        # before accessing size, stride, and offset we need to realize.
+        x.realize()
         new_size = x.get_size()
         new_stride = x.get_stride()
         new_storage_offset = unbacked_offset_sym
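
Both select branches now realize x before reading its layout: in Inductor lowering, get_stride() and get_layout().offset are only reliable once the lazy IR node has been materialized into a buffer with a concrete layout, which is what realize() forces. Schematically:

    x.realize()                      # lazy IR -> buffer with a concrete layout
    new_size = x.get_size()
    new_stride = x.get_stride()      # not guaranteed to be available pre-realize
    offset = x.get_layout().offset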
@@ -3159,8 +3164,14 @@ def select_scatter(x, src, dim: int, index: int):
     assert x.get_dtype() == src.get_dtype()
     x_loader = x.make_loader()
     dim = _validate_dim(x, dim, 0)
-    if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
+    if V.graph.sizevars.guard_or_false(sympy.Lt(index, 0)):
         index = index + x.get_size()[dim]
+    elif V.graph.sizevars.guard_or_false(sympy.Ge(index, 0)):
+        pass
+    else:
+        # unbacked index
+        return fallback_handler(aten.select_scatter.default)(x, src, dim, index)
+
     V.graph.sizevars.check_leq(0, index)  # type: ignore[arg-type]
     V.graph.sizevars.check_lt(index, x.get_size()[dim])  # type: ignore[arg-type]
     src = expand(unsqueeze(src, dim), x.get_size())
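
The select_scatter change replaces evaluate_expr, which would have to guard on the sign of an unbacked index (or fail with a data-dependent error), with guard_or_false, which returns the predicate's value when it is statically decidable and False otherwise. With an unbacked index neither index < 0 nor index >= 0 can be proven, so both checks answer False and the lowering falls back to the ATen kernel. Schematic semantics only, not the real implementation:

    def guard_or_false_sketch(sizevars, expr):
        # True only if expr is statically known to hold; unknown
        # (e.g. unbacked) predicates fall through to False
        return bool(sizevars.statically_known_true(expr))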