Handle aten.to at submodule boundaries (#153972)

Summary: #buildall

Test Plan: CI

Differential Revision: D74582970

When we decompose to inference IR, aten.to can sometimes disappear: at the post-dispatch level it can lower to either a copy or an alias, and in the alias case it is no-op'd out of the graph. As a result, the exported module call graph starts containing dead nodes, because the previous provenance tracking was insufficient. This PR fixes that. The caveat is that this won't work in general for tensor subclass inputs to a submodule whose call signature the user wants to preserve: in inference IR we always desugar the tensor subclass into its constituent tensors, which makes it impossible to preserve the original calling convention.
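
A minimal sketch of the user-facing pattern this handles, mirroring the new test below (module names are illustrative):

```python
import torch
from torch.export import export


class Bar(torch.nn.Module):
    def forward(self, x):
        return x.sum()


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()

    def forward(self, x):
        # x.to(...) may lower to an alias and get no-op'd away in inference IR,
        # leaving the preserved submodule signature pointing at a dead node.
        return self.bar(x.to(torch.float)).sum()


inp = torch.randn(4, 4)
ep = export(Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",))
# Decomposing to inference IR must remap the vanished aten.to output to its
# input so that ep.module() / unflatten still respect Bar's calling convention.
ep = ep.run_decompositions({})
assert torch.allclose(ep.module()(inp), Foo()(inp))
```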

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153972
Approved by: https://github.com/avikchaudhuri
Author: Tugsbayasgalan (Tugsuu) Manlaibaatar
Date: 2025-06-14 16:13:27 +00:00
Committed by: PyTorch MergeBot
Parent: d42c11819f
Commit: 370fc49dde
2 changed files with 201 additions and 2 deletions


@@ -6700,6 +6700,109 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
        self.assertEqual(ep.module()(*inputs), model(*inputs))

    def test_export_aten_to_unflatten(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sum()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bar = Bar()

            def forward(self, x):
                to = x.to(torch.float)
                return self.bar(to).sum()

        inp = torch.randn(4, 4)
        ep = export(
            Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
        )
        mod = ep.module()
        self.assertTrue(torch.allclose(mod(inp), Foo()(inp)))

    @testing.expectedFailureLegacyExportNonStrict
    @testing.expectedFailureLegacyExportStrict
    @testing.expectedFailureRetraceabilityNonStrict  # when we retrace, ep.module() is hierarchical
    @testing.expectedFailureRetraceability  # when we retrace, ep.module() is hierarchical
    def test_export_aten_to_unflatten_subclass(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sum()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bar = Bar()
                self.param = torch.nn.Parameter(
                    TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
                )

            def forward(self, x):
                to = self.param.to(torch.float)
                return (self.bar(to).sum() + x.sum()).get_elem_a()

        inp = torch.randn(4, 4)
        with self.assertRaisesRegex(
            ValueError, "It looks like p_param is a tensor subclass."
        ):
            export(
                Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
            ).run_decompositions({})

    def test_export_aten_to_unflatten_subclass_pre_dispatch(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sum()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bar = Bar()
                self.param = torch.nn.Parameter(
                    TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
                )

            def forward(self, x):
                to = self.param.to(torch.float)
                return (self.bar(to).sum() + x.sum()).get_elem_a()

        inp = torch.randn(4, 4)
        ep = export_for_training(
            Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
        )
        unflat = unflatten(ep).bar
        self.assertExpectedInline(
            str(unflat.graph).strip(),
            """\
graph():
    %_positional_arg_0 : [num_users=1] = placeholder[target=_positional_arg_0]
    %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
    %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (((%_positional_arg_0,), {}), %_spec_0), kwargs = {})
    %to : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%to,), kwargs = {})
    %_spec_1 : [num_users=1] = get_attr[target=_spec_1]
    %tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
    return tree_unflatten""",
        )

        with self.assertRaisesRegex(
            ValueError, "It looks like p_param is a tensor subclass."
        ):
            ep.run_decompositions()

    def test_float_conversion(self):
        class Module(torch.nn.Module):
            def forward(self, x):


@@ -7,7 +7,7 @@ import functools
import operator
import types
import warnings
from collections import namedtuple
from collections import defaultdict, namedtuple
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
@@ -816,25 +816,118 @@ def _common_getitem_elimination_pass(
def _get_updated_module_call_graph(
    old_gm: torch.fx.GraphModule,
    old_graph_signature: ExportGraphSignature,
    gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
    old_module_call_graph: list[ModuleCallEntry],
):
    new_module_call_graph = copy.deepcopy(old_module_call_graph)

    old_nodes = {node.name: node for node in old_gm.graph.nodes}

    old_graph_params_buffers = {
        **old_graph_signature.inputs_to_parameters,
        **old_graph_signature.inputs_to_buffers,
    }

    new_graph_params_buffers = {
        **graph_signature.inputs_to_parameters,
        **graph_signature.inputs_to_buffers,
    }

    # use node-level provenance metadata to create a map
    # from old node names to new node names
    provenance: dict[str, str] = {}

    user_input_counter = 0
    old_user_input_names = [
        node.target for node in old_gm.graph.nodes if node.op == "placeholder"
    ]
    old_user_input_names = list(
        filter(
            lambda x: x not in old_graph_params_buffers
            and x not in old_graph_signature.input_tokens,
            old_user_input_names,
        )
    )

    new_user_input_names = [
        node.target for node in gm.graph.nodes if node.op == "placeholder"
    ]
    for node in gm.graph.nodes:
        if history := node.meta.get("from_node", []):
            provenance[history[-1].name] = node.name
        # For params and buffers, we might have applied a parametrization rule,
        # so their names might have changed. But for user inputs, we know we
        # must preserve the old name.
        elif node.op == "placeholder":
            if not (
                node.target in new_graph_params_buffers
                or node.target in graph_signature.input_tokens
            ):
                if node.target in new_user_input_names:
                    assert isinstance(node.name, str)
                    old_name = old_user_input_names[user_input_counter]
                    assert isinstance(old_name, str)
                    provenance[old_name] = node.name
                    user_input_counter += 1

    # For all the parameters and buffers, we first check whether they are
    # the result of parametrizations; if they are, we log them and error later.
    old_param_to_desugared = defaultdict(list)
    for name, target in new_graph_params_buffers.items():
        # if the parameters are not parametrized, the naming won't change.
        if not target.startswith("parametrizations."):
            # If we are in strict mode, we can't just reuse the param names
            if name in old_graph_params_buffers:
                provenance[name] = name
        else:
            old_target = ".".join(target.split(".")[1:-1])
            old_param_to_desugared[old_target].append(name)
    # map old names to new names in module call signatures
    for entry in new_module_call_graph:
        signature = entry.signature
        if signature is None:
            continue
        for x in [*signature.inputs, *signature.outputs]:
            x.name = provenance.get(x.name, x.name)

            # The submodule takes a tensor subclass as input; we can't
            # preserve its signature here.
            if x.name in old_param_to_desugared:
                raise ValueError(
                    f"It looks like {x.name} is a tensor subclass. "
                    f"Preserving submodule that takes subclass parameter is not supported"
                    f" in inference IR because we desugar them, resulting in more tensors"
                )

            if x.name in provenance:
                x.name = provenance[x.name]
            # This can happen when aten.to is called at graph boundaries.
            # Basically aten.to at the post-dispatch level can be either a copy
            # or an alias. In the alias case, we no-op it, so it will
            # disappear from the graph. If we detect such a case, we should
            # reuse the input to aten.to as the new input to the submodule.
            # Technically this can happen for other maybe-aliasing ops,
            # but aten.to is probably the most common one.
            elif x.name in old_nodes:
                old_node = old_nodes[x.name]
                if old_node.op == "call_function" and old_node.target in [
                    torch.ops.aten.to.dtype_layout,
                    torch.ops.aten.to.device,
                    torch.ops.aten.to.dtype,
                ]:
                    old_target = old_node.args[0].name
                    if old_target not in provenance:
                        raise ValueError(
                            f"It looks like {old_target} is a tensor subclass. "
                            f"Preserving submodule that takes subclass parameter is not supported"
                            f" in inference IR because we desugar them, resulting in more tensors"
                        )

                    x.name = provenance[old_target]

    return new_module_call_graph
@@ -864,7 +957,10 @@ def _decompose_exported_program(
    # new nodes due to decompositions. So we need to update these signatures
    # in the decomposed exported program's module_call_graph.
    new_module_call_graph = _get_updated_module_call_graph(
        ep.graph_module,
        ep.graph_signature,
        gm,
        new_graph_signature,
        ep.module_call_graph,
    )