[export] Make draft_export public (#153219)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153219
Approved by: https://github.com/pianpwk
This commit is contained in:
angelayi
2025-05-14 02:18:32 +00:00
committed by PyTorch MergeBot
parent b15b870903
commit d51bc27378
5 changed files with 50 additions and 22 deletions

View File

@ -105,8 +105,7 @@ To call ``draft-export``, we can replace the ``torch.export`` line with the foll
::
from torch.export._draft_export import draft_export
ep = draft_export(M(), inp)
ep = torch.export.draft_export(M(), inp)
``ep`` is a valid ExportedProgram which can now be passed through further environments!

View File

@ -790,10 +790,9 @@ API Reference
.. autofunction:: export
.. autofunction:: save
.. autofunction:: load
.. autofunction:: draft_export
.. autofunction:: register_dataclass
.. autoclass:: torch.export.dynamic_shapes.Dim
.. autofunction:: torch.export.exported_program.default_decompositions
.. autofunction:: dims
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection
.. automethod:: dynamic_shapes
@ -805,22 +804,21 @@ API Reference
.. automethod:: verify
.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes
.. autoclass:: Constraint
.. autoclass:: ExportedProgram
.. attribute:: graph
.. attribute:: graph_signature
.. attribute:: state_dict
.. attribute:: constants
.. attribute:: range_constraints
.. attribute:: module_call_graph
.. attribute:: example_inputs
.. automethod:: module
.. automethod:: buffers
.. automethod:: named_buffers
.. automethod:: parameters
.. automethod:: named_parameters
.. automethod:: run_decompositions
.. autoclass:: ExportBackwardSignature
.. autoclass:: ExportGraphSignature
.. autoclass:: ModuleCallSignature
.. autoclass:: ModuleCallEntry
.. automodule:: torch.export.decomp_utils
.. autoclass:: CustomDecompTable
@ -830,9 +828,16 @@ API Reference
.. automethod:: materialize
.. automethod:: pop
.. automethod:: update
.. autofunction:: torch.export.exported_program.default_decompositions
.. automodule:: torch.export.exported_program
.. automodule:: torch.export.graph_signature
.. autoclass:: ExportGraphSignature
.. automethod:: replace_all_uses
.. automethod:: get_replace_hook
.. autoclass:: ExportBackwardSignature
.. autoclass:: InputKind
.. autoclass:: InputSpec
.. autoclass:: OutputKind
@ -840,12 +845,8 @@ API Reference
.. autoclass:: SymIntArgument
.. autoclass:: SymBoolArgument
.. autoclass:: SymFloatArgument
.. autoclass:: ExportGraphSignature
.. automethod:: replace_all_uses
.. automethod:: get_replace_hook
.. autoclass:: torch.export.graph_signature.CustomObjArgument
.. autoclass:: CustomObjArgument
.. py:module:: torch.export.dynamic_shapes
.. py:module:: torch.export.custom_ops

View File

@ -5,8 +5,8 @@ import unittest
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Dim, export
from torch.export._draft_export import draft_export, FailureType
from torch.export import Dim, draft_export, export
from torch.export._draft_export import FailureType
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase

View File

@ -52,6 +52,7 @@ __all__ = [
"FlatArgsAdapter",
"UnflattenedModule",
"AdditionalInputs",
"draft_export",
]
# To make sure export specific custom ops are loaded
@ -518,6 +519,32 @@ def load(
return ep
def draft_export(
mod: torch.nn.Module,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
preserve_module_call_signature: tuple[str, ...] = (),
strict: bool = False,
) -> ExportedProgram:
"""
A version of torch.export.export which is designed to consistently produce
an ExportedProgram, even if there are potential soundness issues, and to
generate a report listing the issues found.
"""
from ._draft_export import draft_export
return draft_export(
mod=mod,
args=args,
kwargs=kwargs,
dynamic_shapes=dynamic_shapes,
preserve_module_call_signature=preserve_module_call_signature,
strict=strict,
)
def register_dataclass(
cls: type[Any],
*,

View File

@ -17,9 +17,10 @@ from torch._export.passes.insert_custom_op_guards import (
insert_custom_op_guards,
OpProfile,
)
from torch.export import ExportedProgram
from torch.export._trace import _export
from torch.export.dynamic_shapes import _DimHint, _DimHintType, Dim
from ._trace import _export
from .dynamic_shapes import _DimHint, _DimHintType, Dim
from .exported_program import ExportedProgram
log = logging.getLogger(__name__)