Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Preserve GraphModule node stack trace after torch package deserialization re-tracing (#155638)
Summary: Currently node.meta["stack_trace"] is not preserved when we torch.package and then load a GraphModule, which means the original stack trace is lost. When we re-trace the packaged graph module, we just get a stack trace like fx-generated._0...... This change stores node.meta["stack_trace"] in the torch-packaged graph module and restores it after re-tracing.

Test Plan:
```
buck2 run @//mode/dev-nosan fbcode//caffe2/test:package -- -r TestPackageFX
```

Rollback Plan:

Differential Revision: D76379692

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155638
Approved by: https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: ce9ba071fd
Commit: 9f5153b1a4
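For reference, a minimal round-trip sketch of the behavior this change enables, adapted from the test added in this commit (the SimpleTest module and the "test_{node.name}" strings are placeholders standing in for real stack traces):

```
from io import BytesIO

import torch
from torch.fx import symbolic_trace
from torch.package import PackageExporter, PackageImporter


class SimpleTest(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 3.0)


# Trace the module and tag every FX node with a stand-in stack trace.
traced = symbolic_trace(SimpleTest())
for node in traced.graph.nodes:
    node.meta["stack_trace"] = f"test_{node.name}"

# Package the GraphModule into an in-memory archive.
buf = BytesIO()
with PackageExporter(buf) as pe:
    pe.save_pickle("model", "model.pkl", traced)

# Load it back; deserialization re-traces the module, and with this change
# the stored stack traces are reapplied to the re-traced nodes.
buf.seek(0)
loaded = PackageImporter(buf).load_pickle("model", "model.pkl")
for node in loaded.graph.nodes:
    assert node.meta["stack_trace"] == f"test_{node.name}"
```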
@@ -187,6 +187,27 @@ class TestPackageFX(PackageTestCase):
         input = torch.rand(2, 3)
         self.assertEqual(loaded_traced(input), traced(input))
 
+    def test_package_gm_preserve_stack_trace(self):
+        class SimpleTest(torch.nn.Module):
+            def forward(self, x):
+                return torch.relu(x + 3.0)
+
+        st = SimpleTest()
+        traced = symbolic_trace(st)
+
+        for node in traced.graph.nodes:
+            node.meta["stack_trace"] = f"test_{node.name}"
+
+        f = BytesIO()
+        with PackageExporter(f) as pe:
+            pe.save_pickle("model", "model.pkl", traced)
+
+        f.seek(0)
+        pi = PackageImporter(f)
+        loaded_traced = pi.load_pickle("model", "model.pkl")
+        for node in loaded_traced.graph.nodes:
+            self.assertEqual(f"test_{node.name}", node.meta["stack_trace"])
+
 
 if __name__ == "__main__":
     run_tests()
@@ -204,6 +204,14 @@ def _deserialize_graph_module(
     tracer_extras = body.get("_tracer_extras", {})
     graph = KeepModules().trace(com, **tracer_extras)
 
+    # Recover node.meta["stack_trace"] after re-tracing
+    node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None)
+    if node_meta_stack_trace is not None:
+        del body["_graphmodule_graph_node_meta_stack_trace"]
+        for node in graph.nodes:
+            if node_meta_stack_trace.get(node.name, None) is not None:
+                node.meta["stack_trace"] = node_meta_stack_trace[node.name]
+
     # Manually set Tracer class on the reconstructed Graph, to avoid
     # referencing the private local subclass KeepModules.
     graph._tracer_cls = tracer_cls
@@ -859,6 +867,16 @@ class {module_name}(torch.nn.Module):
         dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
         del dict_without_graph["_graph"]
 
+        # Store node.meta["stack_trace"] so we can recover them after re-tracing during deserialization
+        node_meta_stack_trace = {
+            node.name: node.meta["stack_trace"]
+            for node in self.graph.nodes
+            if "stack_trace" in node.meta
+        }
+        dict_without_graph[
+            "_graphmodule_graph_node_meta_stack_trace"
+        ] = node_meta_stack_trace
+
         generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
         python_code = self.recompile()
         import_block = _format_import_block(python_code.globals, exporter.importer)