[ONNX] Hide draft export under a flag (#162225)

Use `TORCH_ONNX_ENABLE_DRAFT_EXPORT` to control whether draft_export should be used as a strategy in onnx export.

Follow up of https://github.com/pytorch/pytorch/pull/161454

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162225
Approved by: https://github.com/xadupre, https://github.com/titaiwangms
This commit is contained in:
Justin Chu
2025-09-05 19:54:50 +00:00
committed by PyTorch MergeBot
parent adae7f66aa
commit 3771380f83
4 changed files with 17 additions and 12 deletions

View File

@ -43,8 +43,8 @@ def _load_boolean_flag(
return state
PLACEHOLDER: bool = _load_boolean_flag(
"TORCH_ONNX_PLACEHOLDER",
this_will="do nothing",
default=True,
ENABLE_DRAFT_EXPORT: bool = _load_boolean_flag(
"TORCH_ONNX_ENABLE_DRAFT_EXPORT",
this_will="enable torch.export.draft_export as a strategy for capturing models",
default=False,
)

View File

@ -12,7 +12,7 @@ import pathlib
from typing import Any, Callable, TYPE_CHECKING
import torch
from torch.export import _draft_export
from torch.onnx import _flags
if TYPE_CHECKING:
@ -251,7 +251,7 @@ class TorchExportDraftExportStrategy(CaptureStrategy):
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
ep = _draft_export.draft_export(
ep = torch.export.draft_export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
)
report = ep._report # type: ignore[attr-defined]
@ -263,24 +263,27 @@ class TorchExportDraftExportStrategy(CaptureStrategy):
def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export draft_export`..."
f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`..."
)
def _success(self, model) -> None:
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export draft_export`... ✅"
f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`... ✅"
)
def _failure(self, model, e) -> None:
del e # Unused
model_repr = _take_first_line(repr(model))
self._verbose_print(
f"Obtain model graph for `{model_repr}` with `torch.export draft_export`... ❌"
f"Obtain model graph for `{model_repr}` with `torch.export.draft_export`... ❌"
)
CAPTURE_STRATEGIES = (
CAPTURE_STRATEGIES: tuple[type[CaptureStrategy], ...] = (
TorchExportNonStrictStrategy, # strict=False is preferred over strict=True because it does not have dynamo issues
TorchExportStrictStrategy,
)
if _flags.ENABLE_DRAFT_EXPORT:
CAPTURE_STRATEGIES = (*CAPTURE_STRATEGIES, TorchExportDraftExportStrategy)

View File

@ -1340,6 +1340,8 @@ def export(
export_status.torch_export_non_strict = result.success
elif strategy_class is _capture_strategies.TorchExportStrictStrategy:
export_status.torch_export_strict = result.success
elif strategy_class is _capture_strategies.TorchExportDraftExportStrategy:
export_status.torch_export_draft_export = result.success
if result.exception is not None:
failed_results.append(result)

View File

@ -22,7 +22,7 @@ class ExportStatus:
torch_export_strict: bool | None = None
# Whether torch.export.export(..., strict=False) succeeds
torch_export_non_strict: bool | None = None
# Whether torch.export._draft_export.draft_export() succeeds
# Whether torch.export.draft_export() succeeds
torch_export_draft_export: bool | None = None
# Whether decomposition succeeds
decomposition: bool | None = None
@ -47,7 +47,7 @@ def _format_export_status(status: ExportStatus) -> str:
f"```\n"
f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n"
f"{_status_emoji(status.torch_export_strict)} Obtain model graph with `torch.export.export(..., strict=True)`\n"
f"{_status_emoji(status.torch_export_draft_export)} Obtain model graph with `torch.export._draft_export.draft_export`\n"
f"{_status_emoji(status.torch_export_draft_export)} Obtain model graph with `torch.export.draft_export`\n"
f"{_status_emoji(status.decomposition)} Decompose operators for ONNX compatibility\n"
f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n"
f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n"