mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[export] fix kwargs in run_decompositions() for training IR (#130553)
Re-exporting GraphModule expects all inputs to be in args, though not in pytree-flattened format. This avoids failing when we run with a fx.Interpreter subclass in [AOTAutograd tracing](973037be6a/torch/_functorch/_aot_autograd/traced_function_transforms.py (L760-L762)
).
Removes 7 test failures for training IR export.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130553
Approved by: https://github.com/zhxchen17, https://github.com/ydwu4
This commit is contained in:
committed by
PyTorch MergeBot
parent
26c2b92525
commit
18b7633bfb
@ -377,8 +377,12 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||
constant_attrs = _gather_constant_attrs(mod)
|
||||
aten_export_artifact = _export_to_aten_ir(
|
||||
mod,
|
||||
fake_args_unwrapped[0],
|
||||
fake_args_unwrapped[1],
|
||||
# this requires empty kwargs, but not in pytree.flattened format
|
||||
(
|
||||
*fake_args_unwrapped[0],
|
||||
*fake_args_unwrapped[1].values(),
|
||||
),
|
||||
{},
|
||||
fake_params_buffers,
|
||||
constant_attrs,
|
||||
)
|
||||
|
Reference in New Issue
Block a user