mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix silent incorrectness arising from incorrect alias information (#152011)
Fixes #136662 There are two problems: 1) canonicalize_view_scatter_ops adds some new nodes into the graph. These new nodes cause the alias info on the graph to be wrong. To fix this, we try to run FakeTensorUpdater on the graph again. 2) FakeTensorUpdater's alias information is wrong. It tries to skip nodes that it thinks have "equivalent" FakeTensor metadata. It should not be allowed to do this if any users of the node can alias the node. The example is if we have `x = foo(...); y = x.view(...)`. If the user replaces `foo` with a new `bar` node and sets bar.meta["val"] correctly, then FakeTensorUpdater still needs to update y's meta["val"] to be a view of the new bar node. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152011 Approved by: https://github.com/yf225
This commit is contained in:
@ -413,6 +413,31 @@ class TestReinplacingPassCorrectness(InductorTestCase):
|
||||
# Both list inputs failed to reinplace. So we should have emitted clones for them.
|
||||
self.assertEqual(post_grad_graphs.count("aten.clone"), 2)
|
||||
|
||||
def test_generalized_scatter(self):
|
||||
# This is an integration test for the reinplacing pass.
|
||||
def fn(x_1):
|
||||
a = torch.ones([2, 3])
|
||||
c = torch.ones(2)
|
||||
a[:, 0].copy_(c)
|
||||
|
||||
d = a.clone()
|
||||
e = torch.ops.aten.as_strided.default(d, [2], [3], 0)
|
||||
f = e.clone()
|
||||
|
||||
g = torch.zeros(2)
|
||||
e.copy_(g)
|
||||
|
||||
h = torch.zeros(2, 3)
|
||||
h[:, 0].copy_(f)
|
||||
|
||||
add_1 = d + h
|
||||
return add_1
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
expected = fn(x)
|
||||
result = torch.compile(fn, fullgraph=True, backend="inductor")(x)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@parametrize(
|
||||
"factory_op",
|
||||
[
|
||||
|
@ -200,7 +200,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||
# Keep these last, since they introduces mutation. Look at
|
||||
# ./fx_passes/README.md for a discussion of mutation invariants.
|
||||
GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
|
||||
reinplace_inplaceable_ops
|
||||
functools.partial(reinplace_inplaceable_ops, fake_tensor_updater),
|
||||
)
|
||||
GraphTransformObserver(
|
||||
gm, "decompose_triton_kernel_wrapper_functional"
|
||||
|
@ -759,8 +759,15 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
|
||||
graph.erase_node(node)
|
||||
|
||||
|
||||
def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
|
||||
def reinplace_inplaceable_ops(
|
||||
fake_tensor_updater: torch._inductor.fx_utils.FakeTensorUpdater,
|
||||
graph: torch.fx.Graph,
|
||||
) -> None:
|
||||
with enable_python_dispatcher():
|
||||
canonicalize_view_scatter_ops(graph)
|
||||
# canonicalize_view_scatter_ops adds new operations to the graph.
|
||||
# We run fake_tensor_updater to update the alias information.
|
||||
# Correct alias information is required for `reinplace_inplaceable_ops_core`.
|
||||
fake_tensor_updater.incremental_update()
|
||||
reinplace_inplaceable_ops_core(graph)
|
||||
decompose_generalized_scatter(graph)
|
||||
|
@ -1,4 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Optional
|
||||
@ -88,6 +89,7 @@ class FakeTensorUpdater:
|
||||
return (node, node.target, id(node.args), id(node.kwargs))
|
||||
|
||||
def incremental_update(self):
|
||||
"""Update FakeTensors on self.graph. We will try to do the minimum amount of work."""
|
||||
existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
|
||||
for node in self.graph.nodes:
|
||||
existing_storages[get_node_storage(node)] += 1
|
||||
@ -95,14 +97,15 @@ class FakeTensorUpdater:
|
||||
def is_intlist_same(new, old):
|
||||
return statically_known_true(sym_eq(new, old))
|
||||
|
||||
def is_fake_tensor_same(new, old):
|
||||
def is_fake_tensor_same(new, old, *, node):
|
||||
if type(new) != type(old):
|
||||
return False
|
||||
if isinstance(new, (list, tuple)):
|
||||
if len(new) != len(old):
|
||||
return False
|
||||
return all(
|
||||
is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
|
||||
is_fake_tensor_same(new_i, old_i, node=node)
|
||||
for new_i, old_i in zip(new, old)
|
||||
)
|
||||
if new is None:
|
||||
return old is None
|
||||
@ -132,12 +135,61 @@ class FakeTensorUpdater:
|
||||
if get_storage(new) == get_storage(old):
|
||||
return True
|
||||
|
||||
def any_user_may_alias(node):
|
||||
if not isinstance(node.meta["val"], torch.Tensor):
|
||||
# analysis too complicated on lists, can support in the future
|
||||
return True
|
||||
for user in node.users:
|
||||
if not (
|
||||
isinstance(
|
||||
user.target,
|
||||
(torch._ops.OpOverload, torch._ops.HigherOrderOperator),
|
||||
)
|
||||
or user.target
|
||||
== torch._inductor.fx_passes.reinplace._generalized_scatter
|
||||
):
|
||||
return True
|
||||
if isinstance(user.target, torch._ops.HigherOrderOperator):
|
||||
# HOPs that survive until inductor are all non-aliasing HOPs.
|
||||
# We will likely never support HOPs that are aliasing.
|
||||
continue
|
||||
# Strategy: do a FakeTensor prop, see if the storage aliases.
|
||||
# If Inductor ever gets tighter invariants on OpOverloads
|
||||
# (that is, we ban things like torch.ops.aten.reshape calls in the graph),
|
||||
# Then this could just be a fast schema lookup.
|
||||
is_valid, args, kwargs = get_fake_args_kwargs(user)
|
||||
if not is_valid:
|
||||
return True
|
||||
with (
|
||||
V.fake_mode,
|
||||
enable_python_dispatcher(),
|
||||
contextlib.ExitStack() as stack,
|
||||
):
|
||||
# Ignore unbacked symbols (if they exist): we're making
|
||||
# this FakeTensor and then throwing it away.
|
||||
shape_env = V.fake_mode.shape_env
|
||||
if shape_env is not None:
|
||||
stack.enter_context(
|
||||
shape_env.ignore_fresh_unbacked_symbols()
|
||||
)
|
||||
new_fake_tensor = user.target(*args, **kwargs)
|
||||
if not isinstance(new_fake_tensor, torch.Tensor):
|
||||
# analysis too complicated on lists, can support in the future
|
||||
return True
|
||||
if get_storage(new_fake_tensor) == get_storage(node.meta["val"]):
|
||||
return True
|
||||
return False
|
||||
|
||||
# This is the case where it returns a completely fresh storage that's used nowhere else.
|
||||
# If the FakeTensor's storage is fresh and none of the node's users can alias it, then
|
||||
# we don't need to update this node.
|
||||
if (
|
||||
existing_storages[get_storage(old)] == 1
|
||||
and get_storage(new) not in existing_storages
|
||||
and not any_user_may_alias(node)
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def should_process_node(node):
|
||||
@ -149,10 +201,16 @@ class FakeTensorUpdater:
|
||||
return node.op == "call_function" and (
|
||||
isinstance(node.target, torch._ops.OpOverload)
|
||||
or node.target == operator.getitem
|
||||
or node.target
|
||||
== torch._inductor.fx_passes.reinplace._generalized_scatter
|
||||
)
|
||||
|
||||
to_process = OrderedSet[int]()
|
||||
for node in self.graph.nodes:
|
||||
# NB: Be very careful about skipping nodes (via continues) here
|
||||
# and ask for a careful review when changing this code. The
|
||||
# consequence for incorrect FakeTensor metadata is difficult-to-debug
|
||||
# silent incorrectness.
|
||||
if (
|
||||
self.hash_node(node) in self.processed_hashes
|
||||
and id(node) not in to_process
|
||||
@ -167,8 +225,9 @@ class FakeTensorUpdater:
|
||||
continue
|
||||
with V.fake_mode, enable_python_dispatcher():
|
||||
new_fake_tensor = node.target(*args, **kwargs)
|
||||
|
||||
if "val" in node.meta and is_fake_tensor_same(
|
||||
new_fake_tensor, node.meta["val"]
|
||||
new_fake_tensor, node.meta["val"], node=node
|
||||
):
|
||||
continue
|
||||
|
||||
|
Reference in New Issue
Block a user