Rewrite unsafe_remove_auto_functionalized_pass using decompose_auto_functionalized (#134831)

`unsafe_remove_auto_functionalized_pass` can be rewritten using `decompose_auto_functionalized`. This way we do not have to update it every time we change `auto_functionalize` (e.g. https://github.com/pytorch/pytorch/pull/134409), and we avoid maintaining duplicate logic implemented in two different ways.
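
For context, a minimal usage sketch of the pass (hypothetical: the `mylib::accumulate` op and `Mod` module are made up for illustration, and the import path of the pass is assumed), mirroring the updated test below:

```python
import torch
from torch.export import export

# Assumed import path for the pass touched by this change.
from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)


# Hypothetical custom op that mutates its second argument; export() wraps such
# calls in the auto_functionalized HOP.
@torch.library.custom_op("mylib::accumulate", mutates_args=["state"])
def accumulate(x: torch.Tensor, state: torch.Tensor) -> None:
    state.add_(x)


@accumulate.register_fake
def _(x, state):
    return None


class Mod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("state", torch.zeros(3, 3))

    def forward(self, x):
        torch.ops.mylib.accumulate(x, self.state)
        return x + self.state


ep = export(Mod(), (torch.randn(3, 3),))
# Rewrites auto_functionalized(op, ...) back into direct calls to the original
# mutating op; no safety checks are performed.
inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
print(inplace_ep.graph)
```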

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134831
Approved by: https://github.com/zou3519
Laith Sakka authored 2024-08-29 22:33:26 -07:00; committed by PyTorch MergeBot
commit f5b0caee71, parent 351ba3e67c
3 changed files with 38 additions and 73 deletions


@@ -1166,20 +1166,19 @@ def forward(self, add_1):
         x = torch.randn([3, 3])
         ep = export(mod, (x,))
         inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
-        nodes = inplace_ep.graph.nodes
-        getitems = 0
-        for node in nodes:
-            if node.op == "call_function":
-                self.assertFalse(node.target is auto_functionalized)
-                if node.target is operator.getitem:
-                    getitems += 1
-        self.assertEqual(getitems, 2)  # tuple return of len 2
-
-        out_specs = inplace_ep.graph_signature.output_specs
-        self.assertEqual(out_specs[0].arg.name, "b_state")  # state
-        self.assertEqual(out_specs[1].arg.name, "getitem")  # tuple return 1
-        self.assertEqual(out_specs[2].arg.name, "getitem_1")  # tuple return 2
+        graph_text = str(inplace_ep.graph)
+        self.assertExpectedInline(
+            graph_text,
+            """\
+graph():
+    %b_state : [num_users=2] = placeholder[target=b_state]
+    %x : [num_users=1] = placeholder[target=x]
+    %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\
+default](args = (%x, %b_state), kwargs = {})
+    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {})
+    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {})
+    return (b_state, getitem_3, getitem_4)""",
+        )

     @unittest.skipIf(not TEST_CUDA, "requires cuda")
     def test_move_to_device_pass(self):


@@ -236,9 +236,13 @@ class Match:
         replacement graph.
         """
-        from torch._inductor.virtualized import V
+        from torch._inductor.virtualized import NullHandler, V

-        context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext
+        context = (
+            V.fake_mode
+            if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
+            else contextlib.nullcontext()
+        )
         with context:
             if trace_fn is None:


@@ -5,64 +5,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import operator
-from typing import List

 import torch
 from torch._higher_order_ops.auto_functionalize import (
     auto_functionalized,
     get_mutable_arg_names,
 )
+from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
 from torch.export import ExportedProgram


-def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes):
-    # Update every use of the HOP
-    for node in reversed(auto_functionalize_nodes):
-        func = node.args[0]
-        original_kwargs = node.kwargs
-        assert isinstance(func, torch._ops.OpOverload)
-        with ep.graph.inserting_before(node):
-            # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
-            new_node = ep.graph.call_function(func, kwargs=node.kwargs)
-            for k, v in node.meta.items():
-                new_node.meta[k] = v
-            # Replace auto_functionalize(func, args) with just func(args)
-            node.replace_all_uses_with(new_node)
-            mutable_args_names = get_mutable_arg_names(new_node.target)
-            # update the users of the auto_func node (the getitem nodes)
-            for user in list(new_node.users.keys()):
-                assert user.target == operator.getitem
-                # getitem corresponding to a mutated input, just replace all uses with the original input
-                if user.args[1] >= len(func._schema.returns):
-                    assert user.args[1] <= len(func._schema.returns) + len(
-                        mutable_args_names
-                    )
-                    # If the result of getitem was used in an output node, update the output spec with the correct name
-                    adjusted_index = user.args[1] - len(func._schema.returns)
-                    original_arg = original_kwargs[mutable_args_names[adjusted_index]]
-                    # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
-                    # of the getitem calls following the HOP.
-                    user.replace_all_uses_with(original_arg)
-            if len(func._schema.returns) == 1:
-                # If the function has 1 return then it will just directly return the
-                # result -- we don't need a getitem. So we can replace all the
-                # getitem(auto_functionalized, 0) with just the note itself.
-                for user in list(new_node.users.keys()):
-                    if user.args[1] == 0:
-                        user.replace_all_uses_with(new_node)
-            new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
-            ep.graph.erase_node(node)
-    ep.graph.eliminate_dead_code()
+def remove_self_clone(graph: torch.fx.Graph):
+    for node in graph.nodes:
+        if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
+            node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)


 def unsafe_remove_auto_functionalized_pass(
@@ -73,15 +30,20 @@ def unsafe_remove_auto_functionalized_pass(
     and modifies the calling EP inplace to have the original mutator op.

     This pass doesn't perform safety checks to make sure that this inplace mutation is safe.
     """
-    auto_functionalize_nodes: List[torch.fx.Node] = []
-    for module in ep.graph_module.modules():
-        if not isinstance(module, torch.fx.GraphModule):
-            continue
-        for node in ep.graph.nodes:
-            if node.op == "call_function" and node.target is auto_functionalized:
-                auto_functionalize_nodes.append(node)
-
-    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
-        _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes)
+    for module in ep.graph_module.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in ep.graph.nodes:
+            if node.op == "call_function" and node.target is auto_functionalized:
+                func = node.args[0]
+                assert isinstance(func, torch._ops.OpOverload)
+                mutable_args_names = get_mutable_arg_names(func)
+                # re-inplace everything
+                node.meta["only_clone_these_tensors"] = []
+
+    decompose_auto_functionalized(ep.graph)
+    remove_self_clone(ep.graph)
+    ep.graph.eliminate_dead_code()
     return ep
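
A rough, interpretive sketch of how the rewritten pass fits together (not part of the commit; graph lines are schematic):

```python
# Interpretive sketch, not part of the commit.
#
# Exported graph before the pass (schematically):
#     %af = auto_functionalized(op.default, x=%x, state=%b_state)
#     %getitem_* = operator.getitem(%af, ...)   # op returns + functionalized args
#
# The pass tags every auto_functionalized node with
#     node.meta["only_clone_these_tensors"] = []
# ("re-inplace everything"), so decompose_auto_functionalized replaces the HOP
# with a direct call to the original mutating OpOverload without cloning any of
# its mutable arguments:
#     %out = op.default(%x, %b_state)   # mutates b_state in place
#
# Any aten.copy_(dst, src) whose destination and source end up being the same
# node is a no-op self-copy; remove_self_clone erases it, and
# eliminate_dead_code drops whatever is left unused.
```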