mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127838 Approved by: https://github.com/oulgen
52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
# mypy: allow-untyped-defs
|
|
import warnings
|
|
|
|
|
|
import torch
|
|
import torch.fx
|
|
|
|
|
|
# TODO(ycao): This is added to avoid breaking existing code temporarily.
|
|
# Remove when migration is done.
|
|
from torch.export.graph_signature import (
|
|
ExportBackwardSignature,
|
|
ExportGraphSignature,
|
|
)
|
|
|
|
from torch.export.exported_program import (
|
|
ExportedProgram,
|
|
ModuleCallEntry,
|
|
ModuleCallSignature,
|
|
)
|
|
|
|
|
|
|
|
# Public names of this module. Each entry is a class imported above from
# torch.export.graph_signature or torch.export.exported_program and
# re-exported from here.
__all__ = [
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
]
|
|
|
|
|
|
def _create_graph_module_for_export(root, graph):
    """Build a ``torch.fx.GraphModule`` for ``graph`` rooted at ``root``.

    The module is normally constructed directly from ``root`` and ``graph``.
    If that construction raises ``SyntaxError``, a warning is emitted and a
    module built around an empty ``torch.fx.Graph`` is returned instead, with
    ``graph`` assigned to its ``_graph`` attribute afterwards.

    Args:
        root: Passed through as the first argument of ``torch.fx.GraphModule``.
        graph: The graph the returned module should hold.

    Returns:
        The constructed ``torch.fx.GraphModule``.
    """
    try:
        return torch.fx.GraphModule(root, graph)
    except SyntaxError:
        # A graph that refers to custom in-memory objects produces Python
        # source that cannot be parsed, because those objects have no valid
        # source representation. The graph can still be executed eagerly with
        # torch.fx.Interpreter, so the error is downgraded to a warning.
        message = (
            "Unable to execute the generated python source code from "
            "the graph. The graph module will no longer be directly callable, "
            "but you can still run the ExportedProgram, and if needed, you can "
            "run the graph module eagerly using torch.fx.Interpreter."
        )
        warnings.warn(message)

        # Construct the module from an empty graph, then attach the real graph
        # by assigning the private attribute directly.
        graph_module = torch.fx.GraphModule(root, torch.fx.Graph())
        graph_module._graph = graph
        return graph_module
|