testing infra and some fixes (#162183)

This PR is quite large in that it covers most of rough edges in the new strict export flow:

1. Handle nn_module_stack correctly now that we are tracing wrapper module
2. module_call_spec needs to get queried from source directly because we are not running the bytecode anymore.
3. Correct input and output handling.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162167
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-09-08 11:34:12 -07:00
committed by PyTorch MergeBot
parent a965f09793
commit d8b6622bb6
10 changed files with 520 additions and 78 deletions

View File

@ -70,6 +70,7 @@ def export_for_training(
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
prefer_deferred_runtime_asserts_over_guards: bool = False,
_use_new_tracer_experimental: bool = False,
) -> ExportedProgram:
"""
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -159,6 +160,7 @@ def export_for_training(
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_use_new_tracer_experimental=_use_new_tracer_experimental,
)
@ -171,6 +173,7 @@ def export(
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
prefer_deferred_runtime_asserts_over_guards: bool = False,
_use_new_tracer_experimental: bool = False,
) -> ExportedProgram:
"""
:func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@ -283,6 +286,7 @@ def export(
preserve_module_call_signature=preserve_module_call_signature,
pre_dispatch=True,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_use_new_tracer_experimental=_use_new_tracer_experimental,
)
except Exception as e:
draft_export_msg = (