Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 12:54:11 +08:00
Handle aten.to at submodule boundaries (#153972)
Summary: #buildall

Test Plan: CI

Differential Revision: D74582970

When we decompose to inference IR, aten.to can sometimes disappear. As a result, the export module call graph tree starts to contain dead nodes, because the previous provenance tracking was insufficient. This PR fixes that. The caveat is that this won't work in general for tensor-subclass inputs to a submodule whose signature the user wants to preserve, because we always desugar the tensor subclass into its constituent tensors in inference IR, which makes it impossible to preserve the original calling convention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153972
Approved by: https://github.com/avikchaudhuri
committed by PyTorch MergeBot
parent d42c11819f
commit 370fc49dde
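For orientation, here is a minimal sketch of the scenario this commit fixes, adapted from the tests added in the diff below (the module names are illustrative; `export` is `torch.export.export`):

import torch
from torch.export import export


class Bar(torch.nn.Module):
    def forward(self, x):
        return x.sum()


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()

    def forward(self, x):
        # aten.to sits exactly at the submodule boundary: its output is
        # the input to the preserved submodule "bar".
        return self.bar(x.to(torch.float)).sum()


inp = torch.randn(4, 4)
ep = export(Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",))
# Decomposing to inference IR can no-op the aliasing aten.to, so the node
# that the "bar" call signature referenced disappears. Before this fix, the
# module call graph kept a dead reference; now it is remapped to aten.to's
# input.
ep = ep.run_decompositions({})
assert torch.allclose(ep.module()(inp), Foo()(inp))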
@@ -6700,6 +6700,109 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
        self.assertEqual(ep.module()(*inputs), model(*inputs))

    def test_export_aten_to_unflatten(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sum()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bar = Bar()

            def forward(self, x):
                to = x.to(torch.float)
                return self.bar(to).sum()

        inp = torch.randn(4, 4)

        ep = export(
            Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
        )
        mod = ep.module()
        self.assertTrue(torch.allclose(mod(inp), Foo()(inp)))

    @testing.expectedFailureLegacyExportNonStrict
    @testing.expectedFailureLegacyExportStrict
    @testing.expectedFailureRetraceabilityNonStrict  # when we retrace, ep.module() is hierarchical
    @testing.expectedFailureRetraceability  # when we retrace, ep.module() is hierarchical
    def test_export_aten_to_unflatten_subclass(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sum()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bar = Bar()
                self.param = torch.nn.Parameter(
                    TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
                )

            def forward(self, x):
                to = self.param.to(torch.float)
                return (self.bar(to).sum() + x.sum()).get_elem_a()

        inp = torch.randn(4, 4)

        with self.assertRaisesRegex(
            ValueError, "It looks like p_param is a tensor subclass."
        ):
            export(
                Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
            ).run_decompositions({})

    def test_export_aten_to_unflatten_subclass_pre_dispatch(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sum()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bar = Bar()
                self.param = torch.nn.Parameter(
                    TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
                )

            def forward(self, x):
                to = self.param.to(torch.float)
                return (self.bar(to).sum() + x.sum()).get_elem_a()

        inp = torch.randn(4, 4)

        ep = export_for_training(
            Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
        )
        unflat = unflatten(ep).bar
        self.assertExpectedInline(
            str(unflat.graph).strip(),
            """\
graph():
    %_positional_arg_0 : [num_users=1] = placeholder[target=_positional_arg_0]
    %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
    %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (((%_positional_arg_0,), {}), %_spec_0), kwargs = {})
    %to : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%to,), kwargs = {})
    %_spec_1 : [num_users=1] = get_attr[target=_spec_1]
    %tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
    return tree_unflatten""",
        )

        with self.assertRaisesRegex(
            ValueError, "It looks like p_param is a tensor subclass."
        ):
            ep.run_decompositions()

    def test_float_conversion(self):
        class Module(torch.nn.Module):
            def forward(self, x):
@@ -7,7 +7,7 @@ import functools
import operator
import types
import warnings
-from collections import namedtuple
+from collections import defaultdict, namedtuple
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
@@ -816,25 +816,118 @@ def _common_getitem_elimination_pass(


def _get_updated_module_call_graph(
    old_gm: torch.fx.GraphModule,
    old_graph_signature: ExportGraphSignature,
    gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
    old_module_call_graph: list[ModuleCallEntry],
):
    new_module_call_graph = copy.deepcopy(old_module_call_graph)

    old_nodes = {node.name: node for node in old_gm.graph.nodes}

    old_graph_params_buffers = {
        **old_graph_signature.inputs_to_parameters,
        **old_graph_signature.inputs_to_buffers,
    }
    new_graph_params_buffers = {
        **graph_signature.inputs_to_parameters,
        **graph_signature.inputs_to_buffers,
    }

    # Use node-level provenance metadata to create a map
    # from old node names to new node names.
    provenance: dict[str, str] = {}

    user_input_counter = 0
    old_user_input_names = [
        node.target for node in old_gm.graph.nodes if node.op == "placeholder"
    ]
    old_user_input_names = list(
        filter(
            lambda x: x not in old_graph_params_buffers
            and x not in old_graph_signature.input_tokens,
            old_user_input_names,
        )
    )
    new_user_input_names = [
        node.target for node in gm.graph.nodes if node.op == "placeholder"
    ]

    for node in gm.graph.nodes:
        if history := node.meta.get("from_node", []):
            provenance[history[-1].name] = node.name

        # For params and buffers, we might have applied a parametrization
        # rule, so their names might have changed. But for user inputs, we
        # know we must preserve the old name.
        elif node.op == "placeholder":
            if not (
                node.target in new_graph_params_buffers
                or node.target in graph_signature.input_tokens
            ):
                if node.target in new_user_input_names:
                    assert isinstance(node.name, str)
                    old_name = old_user_input_names[user_input_counter]
                    assert isinstance(old_name, str)
                    provenance[old_name] = node.name
                    user_input_counter += 1

    # For all the parameters and buffers, we first check whether they are
    # the result of parametrizations; if they are, we log them and error
    # later.
    old_param_to_desugared = defaultdict(list)
    for name, target in new_graph_params_buffers.items():
        # If the parameters are not parametrized, the naming won't change.
        if not target.startswith("parametrizations."):
            # If we are in strict mode, we can't just reuse the param names.
            if name in old_graph_params_buffers:
                provenance[name] = name
        else:
            old_target = ".".join(target.split(".")[1:-1])
            old_param_to_desugared[old_target].append(name)

    # Map old names to new names in module call signatures.
    for entry in new_module_call_graph:
        signature = entry.signature
        if signature is None:
            continue
        for x in [*signature.inputs, *signature.outputs]:
-           x.name = provenance.get(x.name, x.name)
            # We noticed that the submodule takes a subclass as input;
            # we can't preserve the signature here.
            if x.name in old_param_to_desugared:
                raise ValueError(
                    f"It looks like {x.name} is a tensor subclass. "
                    f"Preserving submodule that takes subclass parameter is not supported"
                    f" in inference IR because we desugar them, resulting in more tensors"
                )

            if x.name in provenance:
                x.name = provenance[x.name]

            # This can happen when aten.to is called at graph boundaries.
            # At the post-dispatch level, aten.to can be either a copy or
            # an alias. In the alias case, we no-op it, so it disappears
            # from the graph. If we detect such a case, we should reuse
            # the input to aten.to as the new input to the submodule.
            # Technically this can happen for other maybe-aliasing ops,
            # but aten.to is probably the most common one.
            elif x.name in old_nodes:
                old_node = old_nodes[x.name]
                if old_node.op == "call_function" and old_node.target in [
                    torch.ops.aten.to.dtype_layout,
                    torch.ops.aten.to.device,
                    torch.ops.aten.to.dtype,
                ]:
                    old_target = old_node.args[0].name
                    if old_target not in provenance:
                        raise ValueError(
                            f"It looks like {old_target} is a tensor subclass. "
                            f"Preserving submodule that takes subclass parameter is not supported"
                            f" in inference IR because we desugar them, resulting in more tensors"
                        )

                    x.name = provenance[old_target]

    return new_module_call_graph
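As an aside, a hypothetical before/after of the remap performed by the aten.to branch above (node and signature names are made up for illustration):

# Pre-dispatch graph (before run_decompositions), with "bar" preserved:
#   to = torch.ops.aten.to.dtype(x, torch.float)  # output feeds submodule "bar"
#   module call signature for "bar": inputs = [TensorArgument(name="to")]
#
# In inference IR the aliasing aten.to is no-op'd away, so no node named
# "to" survives and the signature would point at a dead node. The branch
# above follows aten.to back to its input and reuses that node's new name:
#   x.name = provenance[old_nodes["to"].args[0].name]  # e.g. "x"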
@@ -864,7 +957,10 @@ def _decompose_exported_program(
    # new nodes due to decompositions. So we need to update these signatures
    # in the decomposed exported program's module_call_graph.
    new_module_call_graph = _get_updated_module_call_graph(
        ep.graph_module,
        ep.graph_signature,
        gm,
        new_graph_signature,
        ep.module_call_graph,
    )