mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Preserve all legacy exporter params in fallback (#156659)
Fixes #151693 Previous to this PR, the fallback does not take care of all user parameters. This pr preserves them to ensure a smooth transition for users. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156659 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
a6a8641c8a
commit
a7b29c88b1
@ -364,6 +364,16 @@ def export(
|
||||
|
||||
if isinstance(args, torch.Tensor):
|
||||
args = (args,)
|
||||
# Prepare legacy export parameters for potential fallback
|
||||
legacy_export_kwargs = {
|
||||
"training": training,
|
||||
"operator_export_type": operator_export_type,
|
||||
"do_constant_folding": do_constant_folding,
|
||||
"custom_opsets": custom_opsets,
|
||||
"export_modules_as_functions": export_modules_as_functions,
|
||||
"autograd_inlining": autograd_inlining,
|
||||
}
|
||||
|
||||
return _compat.export_compat(
|
||||
model,
|
||||
args,
|
||||
@ -386,6 +396,7 @@ def export(
|
||||
dump_exported_program=dump_exported_program,
|
||||
artifacts_dir=artifacts_dir,
|
||||
fallback=fallback,
|
||||
legacy_export_kwargs=legacy_export_kwargs,
|
||||
)
|
||||
else:
|
||||
import warnings
|
||||
|
@ -66,6 +66,8 @@ def export_compat(
|
||||
dump_exported_program: bool = False,
|
||||
artifacts_dir: str | os.PathLike = ".",
|
||||
fallback: bool = False,
|
||||
# Legacy export parameters for fallback
|
||||
legacy_export_kwargs: dict[str, Any] | None = None,
|
||||
) -> _onnx_program.ONNXProgram:
|
||||
if opset_version is None:
|
||||
opset_version = _constants.TORCHLIB_OPSET
|
||||
@ -151,6 +153,10 @@ def export_compat(
|
||||
dynamic_axes = _dynamic_shapes.from_dynamic_shapes_to_dynamic_axes(
|
||||
dynamic_shapes=dynamic_shapes, input_names=input_names, exception=e
|
||||
)
|
||||
# Use the legacy export kwargs prepared in __init__.py
|
||||
if legacy_export_kwargs is None:
|
||||
legacy_export_kwargs = {}
|
||||
|
||||
torch.onnx.utils.export(
|
||||
model, # type: ignore[arg-type]
|
||||
args,
|
||||
@ -162,6 +168,7 @@ def export_compat(
|
||||
opset_version=opset_version,
|
||||
dynamic_axes=dynamic_axes,
|
||||
keep_initializers_as_inputs=keep_initializers_as_inputs,
|
||||
**legacy_export_kwargs,
|
||||
)
|
||||
onnx_program = _onnx_program.ONNXProgram(ir.load(f), None)
|
||||
|
||||
|
Reference in New Issue
Block a user