[export] fix kwargs in run_decompositions() for training IR (#130553)

Re-exporting GraphModule expects all inputs to be in args, though not in pytree-flattened format. This avoids failing when we run with a fx.Interpreter subclass in [AOTAutograd tracing](973037be6a/torch/_functorch/_aot_autograd/traced_function_transforms.py (L760-L762)).

Removes 7 test failures for training IR export.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130553
Approved by: https://github.com/zhxchen17, https://github.com/ydwu4
This commit is contained in:
Pian Pawakapan
2024-07-11 22:53:18 +00:00
committed by PyTorch MergeBot
parent 26c2b92525
commit 18b7633bfb
2 changed files with 6 additions and 9 deletions

View File

@ -377,8 +377,12 @@ def _decompose_and_get_gm_with_new_signature_constants(
constant_attrs = _gather_constant_attrs(mod)
aten_export_artifact = _export_to_aten_ir(
mod,
fake_args_unwrapped[0],
fake_args_unwrapped[1],
# this requires empty kwargs, but not in pytree.flattened format
(
*fake_args_unwrapped[0],
*fake_args_unwrapped[1].values(),
),
{},
fake_params_buffers,
constant_attrs,
)