Rewrite unsafe_remove_auto_functionalized_pass using decompose_auto_functionalized (#134831)

`unsafe_remove_auto_functionalized_pass` can be rewritten using `decompose_auto_functionalized`; this way we do not have to update it each time `auto_functionalize` changes (e.g. https://github.com/pytorch/pytorch/pull/134409), and we avoid maintaining duplicate logic implemented in two different ways.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134831
Approved by: https://github.com/zou3519
commit f5b0caee71 (parent 351ba3e67c)
committed by PyTorch MergeBot
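To see what the pass does end to end, here is a minimal sketch: the `mylib::accumulate` op, the module, and the prints are invented for illustration, and the pass's import path (not shown in this diff) is assumed to be `torch.export._remove_auto_functionalized_pass`. The real fixture lives in the test hunk below.

```python
import torch
from torch.export import export
from torch.export._remove_auto_functionalized_pass import (  # assumed path
    unsafe_remove_auto_functionalized_pass,
)

# A made-up in-place custom op, registered only so export() has
# something to auto-functionalize.
torch.library.define("mylib::accumulate", "(Tensor(a!) state, Tensor x) -> ()")

@torch.library.impl("mylib::accumulate", "CompositeExplicitAutograd")
def accumulate(state, x):
    state.add_(x)

@torch.library.register_fake("mylib::accumulate")
def accumulate_fake(state, x):
    return None

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("state", torch.zeros(3, 3))

    def forward(self, x):
        torch.ops.mylib.accumulate(self.state, x)
        return x + self.state

# export() functionalizes the mutation: the graph calls
# auto_functionalized(mylib.accumulate.default, ...) instead of mutating.
ep = export(Mod(), (torch.randn(3, 3),))

# The pass puts the raw mutating op back (unsafely: no aliasing checks).
inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
print(inplace_ep.graph)
```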
@@ -1166,20 +1166,19 @@ def forward(self, add_1):
        x = torch.randn([3, 3])
        ep = export(mod, (x,))
        inplace_ep = unsafe_remove_auto_functionalized_pass(ep)

        nodes = inplace_ep.graph.nodes
        getitems = 0
        for node in nodes:
            if node.op == "call_function":
                self.assertFalse(node.target is auto_functionalized)
                if node.target is operator.getitem:
                    getitems += 1
        self.assertEqual(getitems, 2)  # tuple return of len 2

        out_specs = inplace_ep.graph_signature.output_specs
        self.assertEqual(out_specs[0].arg.name, "b_state")  # state
        self.assertEqual(out_specs[1].arg.name, "getitem")  # tuple return 1
        self.assertEqual(out_specs[2].arg.name, "getitem_1")  # tuple return 2
        graph_text = str(inplace_ep.graph)
        self.assertExpectedInline(
            graph_text,
            """\
graph():
    %b_state : [num_users=2] = placeholder[target=b_state]
    %x : [num_users=1] = placeholder[target=x]
    %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\
default](args = (%x, %b_state), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {})
    return (b_state, getitem_3, getitem_4)""",
        )

    @unittest.skipIf(not TEST_CUDA, "requires cuda")
    def test_move_to_device_pass(self):
@@ -236,9 +236,13 @@ class Match:
         replacement graph.

         """
-        from torch._inductor.virtualized import V
+        from torch._inductor.virtualized import NullHandler, V

-        context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext
+        context = (
+            V.fake_mode
+            if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
+            else contextlib.nullcontext()
+        )

         with context:
             if trace_fn is None:
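Context for this hunk: when no fake mode is installed, `V.fake_mode` appears to yield a `NullHandler` instance rather than `None`, so the old `is not None` test never fell through to the fallback; and the old fallback named the `contextlib.nullcontext` class without calling it, which `with` cannot enter. A minimal, PyTorch-free repro of that second problem:

```python
import contextlib

# What the removed line effectively fell back to: the class object itself.
try:
    with contextlib.nullcontext:  # missing () -- not a context manager instance
        pass
except (TypeError, AttributeError):
    print("entering the nullcontext class fails")

# What the added line falls back to: a real context manager instance.
with contextlib.nullcontext():
    print("entering a nullcontext() instance works")
```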
@@ -5,64 +5,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import operator
-from typing import List
-
 import torch
 from torch._higher_order_ops.auto_functionalize import (
     auto_functionalized,
     get_mutable_arg_names,
 )
 from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
 from torch.export import ExportedProgram


-def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes):
-    # Update every use of the HOP
-    for node in reversed(auto_functionalize_nodes):
-        func = node.args[0]
-        original_kwargs = node.kwargs
-        assert isinstance(func, torch._ops.OpOverload)
-
-        with ep.graph.inserting_before(node):
-            # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
-            new_node = ep.graph.call_function(func, kwargs=node.kwargs)
-            for k, v in node.meta.items():
-                new_node.meta[k] = v
-
-            # Replace auto_functionalize(func, args) with just func(args)
-            node.replace_all_uses_with(new_node)
-
-            mutable_args_names = get_mutable_arg_names(new_node.target)
-
-            # update the users of the auto_func node (the getitem nodes)
-            for user in list(new_node.users.keys()):
-                assert user.target == operator.getitem
-                # getitem corresponding to a mutated input, just replace all uses with the original input
-                if user.args[1] >= len(func._schema.returns):
-                    assert user.args[1] <= len(func._schema.returns) + len(
-                        mutable_args_names
-                    )
-
-                    # If the result of getitem was used in an output node, update the output spec with the correct name
-                    adjusted_index = user.args[1] - len(func._schema.returns)
-                    original_arg = original_kwargs[mutable_args_names[adjusted_index]]
-
-                    # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
-                    # of the getitem calls following the HOP.
-                    user.replace_all_uses_with(original_arg)
-
-            if len(func._schema.returns) == 1:
-                # If the function has 1 return then it will just directly return the
-                # result -- we don't need a getitem. So we can replace all the
-                # getitem(auto_functionalized, 0) with just the node itself.
-                for user in list(new_node.users.keys()):
-                    if user.args[1] == 0:
-                        user.replace_all_uses_with(new_node)
-
-            new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
-            ep.graph.erase_node(node)
-
-    ep.graph.eliminate_dead_code()
+def remove_self_clone(graph: torch.fx.Graph):
+    for node in graph.nodes:
+        if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
+            node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)


 def unsafe_remove_auto_functionalized_pass(
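Once everything is re-inplaced, the write-backs that `decompose_auto_functionalized` emits degenerate into `copy_(x, x)` self-copies, which is exactly what the new `remove_self_clone` deletes. A small self-contained demo of that cleanup on a hand-traced graph (the function `f` is invented; `list()` is added so erasing nodes mid-iteration is safe):

```python
import torch
import torch.fx as fx

def f(x):
    # The residue left behind once every tensor is re-inplaced:
    # copying a tensor onto itself is a no-op.
    torch.ops.aten.copy_.default(x, x)
    return x + 1

gm = fx.symbolic_trace(f)

# Same predicate as remove_self_clone in the hunk above.
for node in list(gm.graph.nodes):
    if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
print(gm.graph)  # only placeholder, add, and output remain
```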
@@ -73,15 +30,20 @@ def unsafe_remove_auto_functionalized_pass(
     and modifies the calling EP inplace to have the original mutator op.
     This pass doesn't perform safety checks to make sure that this inplace mutation is safe.
     """
-    auto_functionalize_nodes: List[torch.fx.Node] = []
-    for module in ep.graph_module.modules():
-        if not isinstance(module, torch.fx.GraphModule):
-            continue
-        for node in ep.graph.nodes:
-            if node.op == "call_function" and node.target is auto_functionalized:
-                auto_functionalize_nodes.append(node)
-
-    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
-        _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes)
+    for module in ep.graph_module.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in ep.graph.nodes:
+            if node.op == "call_function" and node.target is auto_functionalized:
+                func = node.args[0]
+                assert isinstance(func, torch._ops.OpOverload)
+                mutable_args_names = get_mutable_arg_names(func)
+                # re-inplace everything
+                node.meta["only_clone_these_tensors"] = []
+    decompose_auto_functionalized(ep.graph)
+    remove_self_clone(ep.graph)
+    ep.graph.eliminate_dead_code()

     return ep
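Putting the pieces together: setting `only_clone_these_tensors = []` tells the inductor decomposition to clone nothing, so it emits the raw op followed by a `copy_` write-back per mutable argument, and each write-back is then a self-copy that `remove_self_clone` strips before dead-code elimination. Schematically (a hand-written sketch based on the hunks above, not tool output):

```python
# Before the pass (export()'s functionalized form):
#   af = auto_functionalized(op.default, x=x, state=b_state)
#   out, new_state = af[0], af[1]
#
# After decompose_auto_functionalized with only_clone_these_tensors=[]:
#   out = op.default(x, b_state)   # mutates b_state in place again
#   copy_(b_state, b_state)        # write-back collapses to a self-copy
#
# After remove_self_clone + eliminate_dead_code:
#   out = op.default(x, b_state)
```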