mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ONNX] Hide draft export under a flag (#162225)
Use `TORCH_ONNX_ENABLE_DRAFT_EXPORT` to control whether draft_export should be used as a strategy in onnx export. Follow up of https://github.com/pytorch/pytorch/pull/161454 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162225 Approved by: https://github.com/xadupre, https://github.com/titaiwangms
This commit is contained in:
committed by
PyTorch MergeBot
parent
adae7f66aa
commit
3771380f83
@ -43,8 +43,8 @@ def _load_boolean_flag(
|
||||
return state
|
||||
|
||||
|
||||
PLACEHOLDER: bool = _load_boolean_flag(
|
||||
"TORCH_ONNX_PLACEHOLDER",
|
||||
this_will="do nothing",
|
||||
default=True,
|
||||
ENABLE_DRAFT_EXPORT: bool = _load_boolean_flag(
|
||||
"TORCH_ONNX_ENABLE_DRAFT_EXPORT",
|
||||
this_will="enable torch.export.draft_export as a strategy for capturing models",
|
||||
default=False,
|
||||
)
|
||||
|
@ -12,7 +12,7 @@ import pathlib
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.export import _draft_export
|
||||
from torch.onnx import _flags
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -251,7 +251,7 @@ class TorchExportDraftExportStrategy(CaptureStrategy):
|
||||
def _capture(
|
||||
self, model, args, kwargs, dynamic_shapes
|
||||
) -> torch.export.ExportedProgram:
|
||||
ep = _draft_export.draft_export(
|
||||
ep = torch.export.draft_export(
|
||||
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
|
||||
)
|
||||
report = ep._report # type: ignore[attr-defined]
|
||||
@ -263,24 +263,27 @@ class TorchExportDraftExportStrategy(CaptureStrategy):
|
||||
def _enter(self, model) -> None:
|
||||
model_repr = _take_first_line(repr(model))
|
||||
self._verbose_print(
|
||||
f"Obtain model graph for `{model_repr}` with `torch.export draft_export`..."
|
||||
f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`..."
|
||||
)
|
||||
|
||||
def _success(self, model) -> None:
|
||||
model_repr = _take_first_line(repr(model))
|
||||
self._verbose_print(
|
||||
f"Obtain model graph for `{model_repr}` with `torch.export draft_export`... ✅"
|
||||
f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`... ✅"
|
||||
)
|
||||
|
||||
def _failure(self, model, e) -> None:
|
||||
del e # Unused
|
||||
model_repr = _take_first_line(repr(model))
|
||||
self._verbose_print(
|
||||
f"Obtain model graph for `{model_repr}` with `torch.export draft_export`... ❌"
|
||||
f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`... ❌"
|
||||
)
|
||||
|
||||
|
||||
CAPTURE_STRATEGIES = (
|
||||
CAPTURE_STRATEGIES: tuple[type[CaptureStrategy], ...] = (
|
||||
TorchExportNonStrictStrategy, # strict=False is preferred over strict=True because it does not have dynamo issues
|
||||
TorchExportStrictStrategy,
|
||||
)
|
||||
|
||||
if _flags.ENABLE_DRAFT_EXPORT:
|
||||
CAPTURE_STRATEGIES = (*CAPTURE_STRATEGIES, TorchExportDraftExportStrategy)
|
||||
|
@ -1340,6 +1340,8 @@ def export(
|
||||
export_status.torch_export_non_strict = result.success
|
||||
elif strategy_class is _capture_strategies.TorchExportStrictStrategy:
|
||||
export_status.torch_export_strict = result.success
|
||||
elif strategy_class is _capture_strategies.TorchExportDraftExportStrategy:
|
||||
export_status.torch_export_draft_export = result.success
|
||||
|
||||
if result.exception is not None:
|
||||
failed_results.append(result)
|
||||
|
@ -22,7 +22,7 @@ class ExportStatus:
|
||||
torch_export_strict: bool | None = None
|
||||
# Whether torch.export.export(..., strict=False) succeeds
|
||||
torch_export_non_strict: bool | None = None
|
||||
# Whether torch.export._draft_export.draft_export() succeeds
|
||||
# Whether torch.export.draft_export() succeeds
|
||||
torch_export_draft_export: bool | None = None
|
||||
# Whether decomposition succeeds
|
||||
decomposition: bool | None = None
|
||||
@ -47,7 +47,7 @@ def _format_export_status(status: ExportStatus) -> str:
|
||||
f"```\n"
|
||||
f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n"
|
||||
f"{_status_emoji(status.torch_export_strict)} Obtain model graph with `torch.export.export(..., strict=True)`\n"
|
||||
f"{_status_emoji(status.torch_export_draft_export)} Obtain model graph with `torch.export._draft_export.draft_export`\n"
|
||||
f"{_status_emoji(status.torch_export_draft_export)} Obtain model graph with `torch.export.draft_export`\n"
|
||||
f"{_status_emoji(status.decomposition)} Decompose operators for ONNX compatibility\n"
|
||||
f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n"
|
||||
f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n"
|
||||
|
Reference in New Issue
Block a user