Fix silent incorrectness arising from incorrect alias information (#152011)

Fixes #136662

There are two problems:
1) canonicalize_view_scatter_ops adds new nodes to the graph. These new
   nodes make the alias information on the graph wrong. To fix this, we
   run FakeTensorUpdater on the graph again.
2) FakeTensorUpdater computes wrong alias information: it skips nodes
   that it thinks have "equivalent" FakeTensor metadata, but it must not
   do so if any user of the node can alias the node. For example, given
   `x = foo(...); y = x.view(...)`: if a pass replaces `foo` with a new
   `bar` node and sets `bar.meta["val"]` correctly, FakeTensorUpdater
   still needs to update y's `meta["val"]` to be a view of the new `bar`
   node, as sketched below.
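
To make problem 2 concrete, here is a minimal eager-mode sketch of the
hazard (an illustration, not the pass itself; `old_x` and `new_x` stand
in for the outputs of the `foo` and `bar` nodes):

    import torch

    old_x = torch.empty(4)   # output of the old `foo` node
    y = old_x.view(2, 2)     # y's recorded meta["val"] aliases old_x
    new_x = torch.empty(4)   # output of the replacement `bar` node

    # y's sizes/strides/dtype are unchanged, so a naive "is the metadata
    # equivalent?" check would skip y. But y's storage still belongs to
    # old_x, not new_x:
    assert y.untyped_storage().data_ptr() == old_x.untyped_storage().data_ptr()
    assert y.untyped_storage().data_ptr() != new_x.untyped_storage().data_ptr()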

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152011
Approved by: https://github.com/yf225
Author: rzou
Date: 2025-06-26 21:06:54 -07:00
Committed by: PyTorch MergeBot
Parent: 75f3e5a88d
Commit: 43523bf168
4 changed files with 96 additions and 5 deletions


@@ -413,6 +413,31 @@ class TestReinplacingPassCorrectness(InductorTestCase):
         # Both list inputs failed to reinplace. So we should have emitted clones for them.
         self.assertEqual(post_grad_graphs.count("aten.clone"), 2)
 
+    def test_generalized_scatter(self):
+        # This is an integration test for the reinplacing pass.
+        def fn(x_1):
+            a = torch.ones([2, 3])
+            c = torch.ones(2)
+            a[:, 0].copy_(c)
+            d = a.clone()
+            e = torch.ops.aten.as_strided.default(d, [2], [3], 0)
+            f = e.clone()
+            g = torch.zeros(2)
+            e.copy_(g)
+            h = torch.zeros(2, 3)
+            h[:, 0].copy_(f)
+            add_1 = d + h
+            return add_1
+
+        x = torch.randn(2, 3)
+        expected = fn(x)
+        result = torch.compile(fn, fullgraph=True, backend="inductor")(x)
+        self.assertEqual(result, expected)
+
     @parametrize(
         "factory_op",
         [


@@ -200,7 +200,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         # Keep these last, since they introduce mutation. Look at
         # ./fx_passes/README.md for a discussion of mutation invariants.
         GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
-            reinplace_inplaceable_ops
+            functools.partial(reinplace_inplaceable_ops, fake_tensor_updater),
         )
         GraphTransformObserver(
             gm, "decompose_triton_kernel_wrapper_functional"


@@ -759,8 +759,15 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
             graph.erase_node(node)
 
 
-def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
+def reinplace_inplaceable_ops(
+    fake_tensor_updater: torch._inductor.fx_utils.FakeTensorUpdater,
+    graph: torch.fx.Graph,
+) -> None:
     with enable_python_dispatcher():
         canonicalize_view_scatter_ops(graph)
+        # canonicalize_view_scatter_ops adds new operations to the graph.
+        # We run fake_tensor_updater to update the alias information.
+        # Correct alias information is required for `reinplace_inplaceable_ops_core`.
+        fake_tensor_updater.incremental_update()
         reinplace_inplaceable_ops_core(graph)
         decompose_generalized_scatter(graph)
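
Why the `incremental_update()` call matters here: nodes freshly inserted
into an FX graph carry no `meta["val"]`, so alias/storage queries on
them are wrong until FakeTensor metadata is repropagated. A minimal
standalone sketch (not `canonicalize_view_scatter_ops` itself):

    import torch
    import torch.fx

    def f(a):
        return a.clone()

    gm = torch.fx.symbolic_trace(f)
    clone_node = next(n for n in gm.graph.nodes if n.op == "call_method")

    # Insert a new node the way a canonicalization pass might.
    with gm.graph.inserting_after(clone_node):
        new_node = gm.graph.call_function(
            torch.ops.aten.view.default, (clone_node, [-1])
        )

    # The new node has no FakeTensor metadata yet; any alias analysis
    # that reads new_node.meta["val"] would be wrong until it is
    # repopulated.
    print("val" in new_node.meta)  # False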


@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import contextlib
 import operator
 from collections import defaultdict
 from typing import Any, Callable, Optional
@@ -88,6 +89,7 @@ class FakeTensorUpdater:
         return (node, node.target, id(node.args), id(node.kwargs))
 
     def incremental_update(self):
+        """Update FakeTensors on self.graph. We will try to do the minimum amount of work."""
         existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
         for node in self.graph.nodes:
             existing_storages[get_node_storage(node)] += 1
@@ -95,14 +97,15 @@
         def is_intlist_same(new, old):
             return statically_known_true(sym_eq(new, old))
 
-        def is_fake_tensor_same(new, old):
+        def is_fake_tensor_same(new, old, *, node):
             if type(new) != type(old):
                 return False
             if isinstance(new, (list, tuple)):
                 if len(new) != len(old):
                     return False
                 return all(
-                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
+                    is_fake_tensor_same(new_i, old_i, node=node)
+                    for new_i, old_i in zip(new, old)
                 )
             if new is None:
                 return old is None
@@ -132,12 +135,61 @@
             if get_storage(new) == get_storage(old):
                 return True
 
+            def any_user_may_alias(node):
+                if not isinstance(node.meta["val"], torch.Tensor):
+                    # analysis too complicated on lists, can support in the future
+                    return True
+                for user in node.users:
+                    if not (
+                        isinstance(
+                            user.target,
+                            (torch._ops.OpOverload, torch._ops.HigherOrderOperator),
+                        )
+                        or user.target
+                        == torch._inductor.fx_passes.reinplace._generalized_scatter
+                    ):
+                        return True
+                    if isinstance(user.target, torch._ops.HigherOrderOperator):
+                        # HOPs that survive until Inductor are all non-aliasing HOPs.
+                        # We will likely never support HOPs that alias.
+                        continue
+                    # Strategy: do a FakeTensor prop and see if the storages alias.
+                    # If Inductor ever gets tighter invariants on OpOverloads
+                    # (that is, we ban things like torch.ops.aten.reshape calls
+                    # in the graph), then this could just be a fast schema lookup.
+                    is_valid, args, kwargs = get_fake_args_kwargs(user)
+                    if not is_valid:
+                        return True
+                    with (
+                        V.fake_mode,
+                        enable_python_dispatcher(),
+                        contextlib.ExitStack() as stack,
+                    ):
+                        # Ignore unbacked symbols (if they exist): we're making
+                        # this FakeTensor and then throwing it away.
+                        shape_env = V.fake_mode.shape_env
+                        if shape_env is not None:
+                            stack.enter_context(
+                                shape_env.ignore_fresh_unbacked_symbols()
+                            )
+                        new_fake_tensor = user.target(*args, **kwargs)
+                    if not isinstance(new_fake_tensor, torch.Tensor):
+                        # analysis too complicated on lists, can support in the future
+                        return True
+                    if get_storage(new_fake_tensor) == get_storage(node.meta["val"]):
+                        return True
+                return False
+
-            # This is the case where it returns a completely fresh storage that's used nowhere else.
+            # If the FakeTensor's storage is fresh and none of the node's users
+            # can alias it, then we don't need to update this node.
             if (
                 existing_storages[get_storage(old)] == 1
                 and get_storage(new) not in existing_storages
+                and not any_user_may_alias(node)
             ):
                 return True
 
             return False
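
The storage-comparison idea in `any_user_may_alias` can be shown
standalone. A simplified sketch using real tensors (the pass does this
with FakeTensors under `V.fake_mode`; `get_storage` in this file
compares `untyped_storage()._cdata`):

    import torch

    def output_may_alias_input(op, *args):
        # Run the op and check whether its output shares storage with any
        # tensor input, mirroring (in simplified form) any_user_may_alias.
        out = op(*args)
        if not isinstance(out, torch.Tensor):
            return True  # lists etc.: assume the worst, as the pass does
        return any(
            isinstance(a, torch.Tensor)
            and out.untyped_storage()._cdata == a.untyped_storage()._cdata
            for a in args
        )

    x = torch.randn(2, 3)
    print(output_may_alias_input(torch.ops.aten.view.default, x, [6]))  # True
    print(output_may_alias_input(torch.ops.aten.clone.default, x))      # False
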
@@ -149,10 +201,16 @@ def should_process_node(node):
             return node.op == "call_function" and (
                 isinstance(node.target, torch._ops.OpOverload)
                 or node.target == operator.getitem
+                or node.target
+                == torch._inductor.fx_passes.reinplace._generalized_scatter
             )
 
         to_process = OrderedSet[int]()
         for node in self.graph.nodes:
+            # NB: Be very careful about skipping nodes (via continues) here
+            # and ask for a careful review when changing this code. The
+            # consequence of incorrect FakeTensor metadata is difficult-to-debug
+            # silent incorrectness.
             if (
                 self.hash_node(node) in self.processed_hashes
                 and id(node) not in to_process
@@ -167,8 +225,9 @@
                 continue
 
             with V.fake_mode, enable_python_dispatcher():
                 new_fake_tensor = node.target(*args, **kwargs)
             if "val" in node.meta and is_fake_tensor_same(
-                new_fake_tensor, node.meta["val"]
+                new_fake_tensor, node.meta["val"], node=node
             ):
                 continue