# File: pytorch/torch/_export/exported_program.py
# Snapshot: 2024-06-08 18:16:33 +00:00
# Size: 52 lines, 1.4 KiB
# Language: Python

# mypy: allow-untyped-defs
import warnings
import torch
import torch.fx
# TODO(ycao): This is added to avoid breaking existing code temporarily.
# Remove when migration is done.
from torch.export.graph_signature import (
ExportBackwardSignature,
ExportGraphSignature,
)
from torch.export.exported_program import (
ExportedProgram,
ModuleCallEntry,
ModuleCallSignature,
)
# Public names of this module. Every entry is imported above from
# torch.export.graph_signature or torch.export.exported_program and is
# re-exported here so that existing imports of this module keep working
# while the migration mentioned in the TODO above is in progress.
__all__ = [
"ExportBackwardSignature",
"ExportGraphSignature",
"ExportedProgram",
"ModuleCallEntry",
"ModuleCallSignature",
]
def _create_graph_module_for_export(root, graph):
    """Build a ``torch.fx.GraphModule`` for ``graph`` rooted at ``root``.

    The normal path simply constructs ``torch.fx.GraphModule(root, graph)``.
    If that construction raises ``SyntaxError``, a warning is emitted and a
    ``GraphModule`` is built around an empty ``torch.fx.Graph`` instead, after
    which its ``_graph`` attribute is pointed at the original ``graph``.

    Args:
        root: Passed through as the first argument of ``torch.fx.GraphModule``.
        graph: The graph the returned module should carry.

    Returns:
        The constructed ``torch.fx.GraphModule``.
    """
    not_callable_notice = (
        "Unable to execute the generated python source code from "
        "the graph. The graph module will no longer be directly callable, "
        "but you can still run the ExportedProgram, and if needed, you can "
        "run the graph module eagerly using torch.fx.Interpreter."
    )

    try:
        return torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # Graphs that reference custom in-memory objects produce Python source
        # that cannot be parsed, because the object itself has no valid source
        # representation. Such a graph can still be run eagerly through
        # torch.fx.Interpreter, so the error is deliberately bypassed here.
        warnings.warn(not_callable_notice)

        # Construct the module around an empty graph (which needs no code
        # generation for the real nodes), then attach the real graph by
        # assigning the private attribute directly.
        graph_module = torch.fx.GraphModule(root, torch.fx.Graph())
        graph_module._graph = graph
        return graph_module