diff --git a/test/export/test_export.py b/test/export/test_export.py
index 4e50720cdcce..fc79e1f9f610 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -6700,6 +6700,109 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         self.assertEqual(ep.module()(*inputs), model(*inputs))
 
+    def test_export_aten_to_unflatten(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bar = Bar()
+
+            def forward(self, x):
+                to = x.to(torch.float)
+                return self.bar(to).sum()
+
+        inp = torch.randn(4, 4)
+
+        ep = export(
+            Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
+        )
+        mod = ep.module()
+        self.assertTrue(torch.allclose(mod(inp), Foo()(inp)))
+
+    @testing.expectedFailureLegacyExportNonStrict
+    @testing.expectedFailureLegacyExportStrict
+    @testing.expectedFailureRetraceabilityNonStrict  # when we retrace, ep.module() is hierarchical
+    @testing.expectedFailureRetraceability  # when we retrace, ep.module() is hierarchical
+    def test_export_aten_to_unflatten_subclass(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bar = Bar()
+                self.param = torch.nn.Parameter(
+                    TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
+                )
+
+            def forward(self, x):
+                to = self.param.to(torch.float)
+                return (self.bar(to).sum() + x.sum()).get_elem_a()
+
+        inp = torch.randn(4, 4)
+
+        with self.assertRaisesRegex(
+            ValueError, "It looks like p_param is a tensor subclass."
+        ):
+            export(
+                Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
+            ).run_decompositions({})
+
+    def test_export_aten_to_unflatten_subclass_pre_dispatch(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bar = Bar()
+                self.param = torch.nn.Parameter(
+                    TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
+                )
+
+            def forward(self, x):
+                to = self.param.to(torch.float)
+                return (self.bar(to).sum() + x.sum()).get_elem_a()
+
+        inp = torch.randn(4, 4)
+
+        ep = export_for_training(
+            Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
+        )
+        unflat = unflatten(ep).bar
+        self.assertExpectedInline(
+            str(unflat.graph).strip(),
+            """\
+graph():
+    %_positional_arg_0 : [num_users=1] = placeholder[target=_positional_arg_0]
+    %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
+    %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (((%_positional_arg_0,), {}), %_spec_0), kwargs = {})
+    %to : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
+    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%to,), kwargs = {})
+    %_spec_1 : [num_users=1] = get_attr[target=_spec_1]
+    %tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
+    return tree_unflatten""",
+        )
+
+        with self.assertRaisesRegex(
+            ValueError, "It looks like p_param is a tensor subclass."
+        ):
+            ep.run_decompositions()
+
     def test_float_conversion(self):
         class Module(torch.nn.Module):
             def forward(self, x):
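Note: the tests above hinge on the aliasing behavior of aten.to. When no conversion is needed, .to() returns an alias rather than a copy, so after decomposition the node can vanish from the submodule boundary. A minimal sketch of that behavior in plain eager mode, outside the export machinery:

import torch

x = torch.randn(4, 4)          # float32 by default
same = x.to(torch.float)       # no conversion needed: returns an alias of x
print(same is x)               # True; the op can be elided entirely
copied = x.to(torch.float64)   # real dtype change: produces a new tensor
print(copied is x)             # False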
diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py
index 3a9913539231..8d71bfcfd89c 100644
--- a/torch/export/exported_program.py
+++ b/torch/export/exported_program.py
@@ -7,7 +7,7 @@ import functools
 import operator
 import types
 import warnings
-from collections import namedtuple
+from collections import defaultdict, namedtuple
 from collections.abc import Iterator
 from contextlib import contextmanager
 from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
@@ -816,25 +816,118 @@ def _common_getitem_elimination_pass(
 
 
 def _get_updated_module_call_graph(
+    old_gm: torch.fx.GraphModule,
+    old_graph_signature: ExportGraphSignature,
     gm: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     old_module_call_graph: list[ModuleCallEntry],
 ):
     new_module_call_graph = copy.deepcopy(old_module_call_graph)
 
+    old_nodes = {node.name: node for node in old_gm.graph.nodes}
+
+    old_graph_params_buffers = {
+        **old_graph_signature.inputs_to_parameters,
+        **old_graph_signature.inputs_to_buffers,
+    }
+    new_graph_params_buffers = {
+        **graph_signature.inputs_to_parameters,
+        **graph_signature.inputs_to_buffers,
+    }
+
     # use node-level provenance metadata to create a map
     # from old node names to new node names
     provenance: dict[str, str] = {}
+
+    user_input_counter = 0
+    old_user_input_names = [
+        node.target for node in old_gm.graph.nodes if node.op == "placeholder"
+    ]
+    old_user_input_names = list(
+        filter(
+            lambda x: x not in old_graph_params_buffers
+            and x not in old_graph_signature.input_tokens,
+            old_user_input_names,
+        )
+    )
+    new_user_input_names = [
+        node.target for node in gm.graph.nodes if node.op == "placeholder"
+    ]
+
     for node in gm.graph.nodes:
         if history := node.meta.get("from_node", []):
             provenance[history[-1].name] = node.name
+        # For params and buffers, we might have applied a parametrization
+        # rule, so their names might have changed. But for user inputs, we
+        # know we must preserve the old name.
+        elif node.op == "placeholder":
+            if not (
+                node.target in new_graph_params_buffers
+                or node.target in graph_signature.input_tokens
+            ):
+                if node.target in new_user_input_names:
+                    assert isinstance(node.name, str)
+                    old_name = old_user_input_names[user_input_counter]
+                    assert isinstance(old_name, str)
+                    provenance[old_name] = node.name
+                    user_input_counter += 1
+
+    # For all the parameters and buffers, we first check whether they
+    # are the result of parametrizations; if they are, we record them
+    # and raise an error later.
+    old_param_to_desugared = defaultdict(list)
+    for name, target in new_graph_params_buffers.items():
+        # If the parameter is not parametrized, its name won't change.
+        if not target.startswith("parametrizations."):
+            # If we are in strict mode, we can't just reuse the param names
+            if name in old_graph_params_buffers:
+                provenance[name] = name
+        else:
+            old_target = ".".join(target.split(".")[1:-1])
+            old_param_to_desugared[old_target].append(name)
+
     # map old names to new names in module call signatures
     for entry in new_module_call_graph:
         signature = entry.signature
         if signature is None:
            continue
         for x in [*signature.inputs, *signature.outputs]:
-            x.name = provenance.get(x.name, x.name)
+            # If the submodule takes a subclass as input, we can't
+            # preserve the signature here.
+            if x.name in old_param_to_desugared:
+                raise ValueError(
+                    f"It looks like {x.name} is a tensor subclass. "
+                    f"Preserving a submodule that takes a subclass parameter is not supported"
+                    f" in inference IR because we desugar them, resulting in more tensors"
+                )
+
+            if x.name in provenance:
+                x.name = provenance[x.name]
+
+            # This can happen when aten.to is called at graph boundaries.
+            # Basically aten.to at the post-dispatch level can be either
+            # a copy or an alias. In the alias case, we no-op it, so it
+            # disappears from the graph. If we detect such a case, we
+            # should reuse the input to aten.to as the new input to the
+            # submodule. Technically this can happen for other
+            # maybe-aliasing ops, but aten.to is probably the most common one.
+            elif x.name in old_nodes:
+                old_node = old_nodes[x.name]
+                if old_node.op == "call_function" and old_node.target in [
+                    torch.ops.aten.to.dtype_layout,
+                    torch.ops.aten.to.device,
+                    torch.ops.aten.to.dtype,
+                ]:
+                    old_target = old_node.args[0].name
+                    if old_target not in provenance:
+                        raise ValueError(
+                            f"It looks like {old_target} is a tensor subclass. "
+                            f"Preserving a submodule that takes a subclass parameter is not supported"
+                            f" in inference IR because we desugar them, resulting in more tensors"
+                        )
+
+                    x.name = provenance[old_target]
 
     return new_module_call_graph
 
@@ -864,7 +957,10 @@ def _decompose_exported_program(
     # new nodes due to decompositions. So we need to update these signatures
     # in the decomposed exported program's module_call_graph.
     new_module_call_graph = _get_updated_module_call_graph(
+        ep.graph_module,
+        ep.graph_signature,
         gm,
+        new_graph_signature,
         ep.module_call_graph,
     )