Revert "Add node meta value into UnflattenedModule (#117686)"

This reverts commit cbf24ba962f72175ec1c71a25f3379f7d9149ec1.

Reverted https://github.com/pytorch/pytorch/pull/117686 on behalf of https://github.com/PaliC due to breaks internal modeling tests ([comment](https://github.com/pytorch/pytorch/pull/117686#issuecomment-1898939899))
This commit is contained in:
PyTorch MergeBot
2024-01-18 17:46:38 +00:00
parent 5aa895e53e
commit 7451dd0585
2 changed files with 0 additions and 39 deletions

View File

@ -456,38 +456,6 @@ class TestUnflatten(TestCase):
ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
self.assertTrue(torch.allclose(ep(*inp), mod(*inp)))
def test_unflattened_module_nodes_has_meta_val(self):
class SubMod(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x + x, x * x
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + sum(self.submod(x))
orig_eager = MyModule()
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
def check_meta(gm):
for n in gm.graph.nodes:
if n.op == "output":
continue
self.assertTrue(n.meta.get("val") is not None)
for m in unflattened.modules():
check_meta(m)
def test_placeholder_and_get_attr_ordering_after_unflattened(self):
class TransposeModule(torch.nn.Module):
def __init__(self):

View File

@ -605,19 +605,12 @@ class _ModuleFrame:
if parent_out is None:
return
parent_out.meta["val"] = (
graph_outputs.meta.get("val")
if isinstance(graph_outputs, torch.fx.Node)
else [o.meta.get("val") for o in graph_outputs]
)
if len(orig_outputs) == 1 and signature is None:
self.parent.node_map[orig_outputs[0]] = parent_out
else:
for i, orig_output in enumerate(orig_outputs):
# Use Proxy to record getitem access.
proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index]
proxy_out.meta["val"] = copy.copy(orig_output.meta["val"])
self.parent.node_map[orig_output] = proxy_out
if self.cached_graph_module is not None: