Add support for models with mutated buffer on torch.onnx.dynamo_export (#112272)

This PR adds a unit test that leverages `torch.export.ExportedProgram` models that mutates registered buffers. Although the exporter already works out of the box in such scenario, the GraphModule and the exported ONNX model have extra outputs containing the mutated buffers. On future runs of the ONNX model, the mutated buffers are used as input to the model.

The aforementioned extra inputs and outputs are by design and the `ONNXProgram.model_signature` can be used to fetch detailed input/output schema for the exported model.

However, when we want to compare pytorch output to ONNX's, there is a mismatch between the schema because pytorch output does not include the mutated buffers present on the ONNX output.

This PR extends `onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)` so that the mutated buffers are prepended to the Pytorch output, matching the ONNX schema.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112272
Approved by: https://github.com/titaiwangms, https://github.com/BowenBao
This commit is contained in:
Thiago Crepaldi
2023-11-22 22:13:48 +00:00
committed by PyTorch MergeBot
parent 7daeb6509f
commit a76bb5d84d
4 changed files with 77 additions and 4 deletions

View File

@ -436,16 +436,18 @@ def _compare_pytorch_onnx_with_ort(
ref_input_args = input_args
ref_input_kwargs = input_kwargs
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
ref_model(*ref_input_args, **ref_input_kwargs)
)
# ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
# Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
# Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
ort_outputs = onnx_program(*input_args, **input_kwargs)
ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs)
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs)
if len(ref_outputs) != len(ort_outputs):
raise AssertionError(
f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
)
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
torch.testing.assert_close(
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol

View File

@ -942,6 +942,37 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
loaded_exported_program, (x,), skip_dynamic_shapes_check=True
)
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
"Unsupported FX nodes: {'call_function': ['aten.add_.Tensor']}. "
"github issue: https://github.com/pytorch/pytorch/issues/114406"
)
def test_exported_program_as_input_lifting_buffers_mutation(self):
for persistent in (True, False):
class CustomModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"my_buffer", torch.tensor(4.0), persistent=persistent
)
def forward(self, x, b):
output = x + b
(
self.my_buffer.add_(1.0) + 3.0
) # Mutate buffer through in-place addition
return output
inputs = (torch.rand((3, 3), dtype=torch.float32), torch.randn(3, 3))
model = CustomModule()
self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model, inputs, skip_dynamic_shapes_check=True
)
# Buffer will be mutated after the first iteration
self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model, inputs, skip_dynamic_shapes_check=True
)
def _parameterized_class_attrs_and_values_with_fake_options():
input_values = []

View File

@ -63,6 +63,10 @@ class TorchExport(exporter.FXGraphExtractor):
# tensor, etc), we flatten the collection and register each element as output.
options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())
options.fx_tracer.output_adapter.append_step(
io_adapter.PrependParamsAndBuffersAotAutogradOutputStep(model)
)
# Export FX graph to ONNX ModelProto.
return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value]

View File

@ -550,3 +550,39 @@ class PrependParamsAndBuffersAotAutogradInputStep(InputAdaptStep):
if model_kwargs:
return MergeKwargsIntoArgsInputStep().apply(updated_args, model_kwargs)
return updated_args, {}
class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep):
"""Prepend model's mutated buffers to the user output.
:func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they
must be added to the user output after the model is executed.
Args:
model: The PyTorch model with mutated buffers.
"""
def __init__(self, model: torch_export.ExportedProgram):
assert isinstance(
model, torch_export.ExportedProgram
), "'model' must be a torch.export.ExportedProgram."
self.model = model
def apply(self, model_outputs: Any) -> Sequence[Any]:
"""Flatten the model outputs and validate the `SpecTree` output.
Args:
model_outputs: The model outputs to flatten.
Returns:
flattened_outputs: The flattened model outputs.
"""
ordered_buffers = tuple(
self.model.state_dict[name]
for name in self.model.graph_signature.buffers_to_mutate.values()
)
# NOTE: calling convention is first mutated buffers, then outputs args as model returned them.
updated_outputs = (*ordered_buffers, *model_outputs)
return updated_outputs