mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Add node meta value into UnflattenedModule (#117686)"
This reverts commit cbf24ba962f72175ec1c71a25f3379f7d9149ec1. Reverted https://github.com/pytorch/pytorch/pull/117686 on behalf of https://github.com/PaliC due to breaks internal modeling tests ([comment](https://github.com/pytorch/pytorch/pull/117686#issuecomment-1898939899))
This commit is contained in:
@ -456,38 +456,6 @@ class TestUnflatten(TestCase):
|
||||
ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
|
||||
self.assertTrue(torch.allclose(ep(*inp), mod(*inp)))
|
||||
|
||||
def test_unflattened_module_nodes_has_meta_val(self):
|
||||
class SubMod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x + x, x * x
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.submod = SubMod()
|
||||
|
||||
def forward(self, x):
|
||||
return x + sum(self.submod(x))
|
||||
|
||||
orig_eager = MyModule()
|
||||
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
|
||||
unflattened = unflatten(export_module)
|
||||
|
||||
inputs = (torch.rand(2, 3),)
|
||||
self.compare_outputs(orig_eager, unflattened, inputs)
|
||||
|
||||
def check_meta(gm):
|
||||
for n in gm.graph.nodes:
|
||||
if n.op == "output":
|
||||
continue
|
||||
self.assertTrue(n.meta.get("val") is not None)
|
||||
|
||||
for m in unflattened.modules():
|
||||
check_meta(m)
|
||||
|
||||
def test_placeholder_and_get_attr_ordering_after_unflattened(self):
|
||||
class TransposeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -605,19 +605,12 @@ class _ModuleFrame:
|
||||
if parent_out is None:
|
||||
return
|
||||
|
||||
parent_out.meta["val"] = (
|
||||
graph_outputs.meta.get("val")
|
||||
if isinstance(graph_outputs, torch.fx.Node)
|
||||
else [o.meta.get("val") for o in graph_outputs]
|
||||
)
|
||||
|
||||
if len(orig_outputs) == 1 and signature is None:
|
||||
self.parent.node_map[orig_outputs[0]] = parent_out
|
||||
else:
|
||||
for i, orig_output in enumerate(orig_outputs):
|
||||
# Use Proxy to record getitem access.
|
||||
proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index]
|
||||
proxy_out.meta["val"] = copy.copy(orig_output.meta["val"])
|
||||
self.parent.node_map[orig_output] = proxy_out
|
||||
|
||||
if self.cached_graph_module is not None:
|
||||
|
Reference in New Issue
Block a user