Preserve GrpahModule node stack trace after torch package deserializaion re-tracing (#155638)

Summary:
urrently the node.meta["stack_trace"] is not preserved when we torch package/load GraphModule, which means the original stack trace is lost. When we re-trace the packaged graph module, we just get a stack trace like fx-generated._0......

Adding the node.meta["stack_trace"] to torch packaged graph module

Test Plan:
```
buck2 run @//mode/dev-nosan fbcode//caffe2/test:package -- -r  TestPackageFX
```

Rollback Plan:

Differential Revision: D76379692

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155638
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu
2025-06-12 03:48:24 +00:00
committed by PyTorch MergeBot
parent ce9ba071fd
commit 9f5153b1a4
2 changed files with 39 additions and 0 deletions

View File

@ -187,6 +187,27 @@ class TestPackageFX(PackageTestCase):
input = torch.rand(2, 3)
self.assertEqual(loaded_traced(input), traced(input))
def test_package_gm_preserve_stack_trace(self):
class SimpleTest(torch.nn.Module):
def forward(self, x):
return torch.relu(x + 3.0)
st = SimpleTest()
traced = symbolic_trace(st)
for node in traced.graph.nodes:
node.meta["stack_trace"] = f"test_{node.name}"
f = BytesIO()
with PackageExporter(f) as pe:
pe.save_pickle("model", "model.pkl", traced)
f.seek(0)
pi = PackageImporter(f)
loaded_traced = pi.load_pickle("model", "model.pkl")
for node in loaded_traced.graph.nodes:
self.assertEqual(f"test_{node.name}", node.meta["stack_trace"])
if __name__ == "__main__":
run_tests()

View File

@ -204,6 +204,14 @@ def _deserialize_graph_module(
tracer_extras = body.get("_tracer_extras", {})
graph = KeepModules().trace(com, **tracer_extras)
# Recover node.meta["stack_trace"] after re-tracing
node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None)
if node_meta_stack_trace is not None:
del body["_graphmodule_graph_node_meta_stack_trace"]
for node in graph.nodes:
if node_meta_stack_trace.get(node.name, None) is not None:
node.meta["stack_trace"] = node_meta_stack_trace[node.name]
# Manually set Tracer class on the reconstructed Graph, to avoid
# referencing the private local subclass KeepModules.
graph._tracer_cls = tracer_cls
@ -859,6 +867,16 @@ class {module_name}(torch.nn.Module):
dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
del dict_without_graph["_graph"]
# Store node.meta["stack_trace"] so we can recover them after re-tracing during deserialization
node_meta_stack_trace = {
node.name: node.meta["stack_trace"]
for node in self.graph.nodes
if "stack_trace" in node.meta
}
dict_without_graph[
"_graphmodule_graph_node_meta_stack_trace"
] = node_meta_stack_trace
generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
python_code = self.recompile()
import_block = _format_import_block(python_code.globals, exporter.importer)