[ONNX] Preserve all legacy exporter params in fallback (#156659)

Fixes #151693

Previous to this PR, the fallback does not take care of all user parameters. This pr preserves them to ensure a smooth transition for users.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156659
Approved by: https://github.com/justinchuby
This commit is contained in:
Ti-Tai Wang
2025-06-24 05:28:50 +00:00
committed by PyTorch MergeBot
parent a6a8641c8a
commit a7b29c88b1
2 changed files with 18 additions and 0 deletions

View File

@ -364,6 +364,16 @@ def export(
if isinstance(args, torch.Tensor):
args = (args,)
# Prepare legacy export parameters for potential fallback
legacy_export_kwargs = {
"training": training,
"operator_export_type": operator_export_type,
"do_constant_folding": do_constant_folding,
"custom_opsets": custom_opsets,
"export_modules_as_functions": export_modules_as_functions,
"autograd_inlining": autograd_inlining,
}
return _compat.export_compat(
model,
args,
@ -386,6 +396,7 @@ def export(
dump_exported_program=dump_exported_program,
artifacts_dir=artifacts_dir,
fallback=fallback,
legacy_export_kwargs=legacy_export_kwargs,
)
else:
import warnings

View File

@ -66,6 +66,8 @@ def export_compat(
dump_exported_program: bool = False,
artifacts_dir: str | os.PathLike = ".",
fallback: bool = False,
# Legacy export parameters for fallback
legacy_export_kwargs: dict[str, Any] | None = None,
) -> _onnx_program.ONNXProgram:
if opset_version is None:
opset_version = _constants.TORCHLIB_OPSET
@ -151,6 +153,10 @@ def export_compat(
dynamic_axes = _dynamic_shapes.from_dynamic_shapes_to_dynamic_axes(
dynamic_shapes=dynamic_shapes, input_names=input_names, exception=e
)
# Use the legacy export kwargs prepared in __init__.py
if legacy_export_kwargs is None:
legacy_export_kwargs = {}
torch.onnx.utils.export(
model, # type: ignore[arg-type]
args,
@ -162,6 +168,7 @@ def export_compat(
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
**legacy_export_kwargs,
)
onnx_program = _onnx_program.ONNXProgram(ir.load(f), None)