mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[export] Make draft_export public (#153219)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/153219 Approved by: https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
b15b870903
commit
d51bc27378
@ -105,8 +105,7 @@ To call ``draft-export``, we can replace the ``torch.export`` line with the foll
|
||||
|
||||
::
|
||||
|
||||
from torch.export._draft_export import draft_export
|
||||
ep = draft_export(M(), inp)
|
||||
ep = torch.export.draft_export(M(), inp)
|
||||
|
||||
``ep`` is a valid ExportedProgram which can now be passed through further environments!
|
||||
|
||||
|
@ -790,10 +790,9 @@ API Reference
|
||||
.. autofunction:: export
|
||||
.. autofunction:: save
|
||||
.. autofunction:: load
|
||||
.. autofunction:: draft_export
|
||||
.. autofunction:: register_dataclass
|
||||
.. autoclass:: torch.export.dynamic_shapes.Dim
|
||||
.. autofunction:: torch.export.exported_program.default_decompositions
|
||||
.. autofunction:: dims
|
||||
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection
|
||||
|
||||
.. automethod:: dynamic_shapes
|
||||
@ -805,22 +804,21 @@ API Reference
|
||||
.. automethod:: verify
|
||||
|
||||
.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes
|
||||
.. autoclass:: Constraint
|
||||
.. autoclass:: ExportedProgram
|
||||
|
||||
.. attribute:: graph
|
||||
.. attribute:: graph_signature
|
||||
.. attribute:: state_dict
|
||||
.. attribute:: constants
|
||||
.. attribute:: range_constraints
|
||||
.. attribute:: module_call_graph
|
||||
.. attribute:: example_inputs
|
||||
.. automethod:: module
|
||||
.. automethod:: buffers
|
||||
.. automethod:: named_buffers
|
||||
.. automethod:: parameters
|
||||
.. automethod:: named_parameters
|
||||
.. automethod:: run_decompositions
|
||||
|
||||
.. autoclass:: ExportBackwardSignature
|
||||
.. autoclass:: ExportGraphSignature
|
||||
.. autoclass:: ModuleCallSignature
|
||||
.. autoclass:: ModuleCallEntry
|
||||
|
||||
|
||||
.. automodule:: torch.export.decomp_utils
|
||||
.. autoclass:: CustomDecompTable
|
||||
|
||||
@ -830,9 +828,16 @@ API Reference
|
||||
.. automethod:: materialize
|
||||
.. automethod:: pop
|
||||
.. automethod:: update
|
||||
.. autofunction:: torch.export.exported_program.default_decompositions
|
||||
|
||||
.. automodule:: torch.export.exported_program
|
||||
.. automodule:: torch.export.graph_signature
|
||||
.. autoclass:: ExportGraphSignature
|
||||
|
||||
.. automethod:: replace_all_uses
|
||||
.. automethod:: get_replace_hook
|
||||
|
||||
.. autoclass:: ExportBackwardSignature
|
||||
.. autoclass:: InputKind
|
||||
.. autoclass:: InputSpec
|
||||
.. autoclass:: OutputKind
|
||||
@ -840,12 +845,8 @@ API Reference
|
||||
.. autoclass:: SymIntArgument
|
||||
.. autoclass:: SymBoolArgument
|
||||
.. autoclass:: SymFloatArgument
|
||||
.. autoclass:: ExportGraphSignature
|
||||
|
||||
.. automethod:: replace_all_uses
|
||||
.. automethod:: get_replace_hook
|
||||
|
||||
.. autoclass:: torch.export.graph_signature.CustomObjArgument
|
||||
.. autoclass:: CustomObjArgument
|
||||
|
||||
.. py:module:: torch.export.dynamic_shapes
|
||||
.. py:module:: torch.export.custom_ops
|
||||
|
@ -5,8 +5,8 @@ import unittest
|
||||
|
||||
import torch
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.export import Dim, export
|
||||
from torch.export._draft_export import draft_export, FailureType
|
||||
from torch.export import Dim, draft_export, export
|
||||
from torch.export._draft_export import FailureType
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
|
||||
|
@ -52,6 +52,7 @@ __all__ = [
|
||||
"FlatArgsAdapter",
|
||||
"UnflattenedModule",
|
||||
"AdditionalInputs",
|
||||
"draft_export",
|
||||
]
|
||||
|
||||
# To make sure export specific custom ops are loaded
|
||||
@ -518,6 +519,32 @@ def load(
|
||||
return ep
|
||||
|
||||
|
||||
def draft_export(
|
||||
mod: torch.nn.Module,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
strict: bool = False,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
A version of torch.export.export which is designed to consistently produce
|
||||
an ExportedProgram, even if there are potential soundness issues, and to
|
||||
generate a report listing the issues found.
|
||||
"""
|
||||
from ._draft_export import draft_export
|
||||
|
||||
return draft_export(
|
||||
mod=mod,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
|
||||
def register_dataclass(
|
||||
cls: type[Any],
|
||||
*,
|
||||
|
@ -17,9 +17,10 @@ from torch._export.passes.insert_custom_op_guards import (
|
||||
insert_custom_op_guards,
|
||||
OpProfile,
|
||||
)
|
||||
from torch.export import ExportedProgram
|
||||
from torch.export._trace import _export
|
||||
from torch.export.dynamic_shapes import _DimHint, _DimHintType, Dim
|
||||
|
||||
from ._trace import _export
|
||||
from .dynamic_shapes import _DimHint, _DimHintType, Dim
|
||||
from .exported_program import ExportedProgram
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
Reference in New Issue
Block a user