mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Preserve GrpahModule node stack trace after torch package deserializaion re-tracing (#155638)
Summary: urrently the node.meta["stack_trace"] is not preserved when we torch package/load GraphModule, which means the original stack trace is lost. When we re-trace the packaged graph module, we just get a stack trace like fx-generated._0...... Adding the node.meta["stack_trace"] to torch packaged graph module Test Plan: ``` buck2 run @//mode/dev-nosan fbcode//caffe2/test:package -- -r TestPackageFX ``` Rollback Plan: Differential Revision: D76379692 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155638 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
ce9ba071fd
commit
9f5153b1a4
@ -187,6 +187,27 @@ class TestPackageFX(PackageTestCase):
|
||||
input = torch.rand(2, 3)
|
||||
self.assertEqual(loaded_traced(input), traced(input))
|
||||
|
||||
def test_package_gm_preserve_stack_trace(self):
|
||||
class SimpleTest(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.relu(x + 3.0)
|
||||
|
||||
st = SimpleTest()
|
||||
traced = symbolic_trace(st)
|
||||
|
||||
for node in traced.graph.nodes:
|
||||
node.meta["stack_trace"] = f"test_{node.name}"
|
||||
|
||||
f = BytesIO()
|
||||
with PackageExporter(f) as pe:
|
||||
pe.save_pickle("model", "model.pkl", traced)
|
||||
|
||||
f.seek(0)
|
||||
pi = PackageImporter(f)
|
||||
loaded_traced = pi.load_pickle("model", "model.pkl")
|
||||
for node in loaded_traced.graph.nodes:
|
||||
self.assertEqual(f"test_{node.name}", node.meta["stack_trace"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -204,6 +204,14 @@ def _deserialize_graph_module(
|
||||
tracer_extras = body.get("_tracer_extras", {})
|
||||
graph = KeepModules().trace(com, **tracer_extras)
|
||||
|
||||
# Recover node.meta["stack_trace"] after re-tracing
|
||||
node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None)
|
||||
if node_meta_stack_trace is not None:
|
||||
del body["_graphmodule_graph_node_meta_stack_trace"]
|
||||
for node in graph.nodes:
|
||||
if node_meta_stack_trace.get(node.name, None) is not None:
|
||||
node.meta["stack_trace"] = node_meta_stack_trace[node.name]
|
||||
|
||||
# Manually set Tracer class on the reconstructed Graph, to avoid
|
||||
# referencing the private local subclass KeepModules.
|
||||
graph._tracer_cls = tracer_cls
|
||||
@ -859,6 +867,16 @@ class {module_name}(torch.nn.Module):
|
||||
dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
|
||||
del dict_without_graph["_graph"]
|
||||
|
||||
# Store node.meta["stack_trace"] so we can recover them after re-tracing during deserialization
|
||||
node_meta_stack_trace = {
|
||||
node.name: node.meta["stack_trace"]
|
||||
for node in self.graph.nodes
|
||||
if "stack_trace" in node.meta
|
||||
}
|
||||
dict_without_graph[
|
||||
"_graphmodule_graph_node_meta_stack_trace"
|
||||
] = node_meta_stack_trace
|
||||
|
||||
generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
|
||||
python_code = self.recompile()
|
||||
import_block = _format_import_block(python_code.globals, exporter.importer)
|
||||
|
Reference in New Issue
Block a user