mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add support for models with mutated buffer on torch.onnx.dynamo_export (#112272)
This PR adds a unit test that leverages `torch.export.ExportedProgram` models that mutates registered buffers. Although the exporter already works out of the box in such scenario, the GraphModule and the exported ONNX model have extra outputs containing the mutated buffers. On future runs of the ONNX model, the mutated buffers are used as input to the model. The aforementioned extra inputs and outputs are by design and the `ONNXProgram.model_signature` can be used to fetch detailed input/output schema for the exported model. However, when we want to compare pytorch output to ONNX's, there is a mismatch between the schema because pytorch output does not include the mutated buffers present on the ONNX output. This PR extends `onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)` so that the mutated buffers are prepended to the Pytorch output, matching the ONNX schema. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112272 Approved by: https://github.com/titaiwangms, https://github.com/BowenBao
This commit is contained in:
committed by
PyTorch MergeBot
parent
7daeb6509f
commit
a76bb5d84d
@ -436,16 +436,18 @@ def _compare_pytorch_onnx_with_ort(
|
||||
ref_input_args = input_args
|
||||
ref_input_kwargs = input_kwargs
|
||||
|
||||
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(
|
||||
ref_model(*ref_input_args, **ref_input_kwargs)
|
||||
)
|
||||
|
||||
# ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
|
||||
# Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
|
||||
# Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
|
||||
ort_outputs = onnx_program(*input_args, **input_kwargs)
|
||||
ref_outputs = ref_model(*ref_input_args, **ref_input_kwargs)
|
||||
ref_outputs = onnx_program.adapt_torch_outputs_to_onnx(ref_outputs)
|
||||
|
||||
if len(ref_outputs) != len(ort_outputs):
|
||||
raise AssertionError(
|
||||
f"Expected {len(ref_outputs)} outputs, got {len(ort_outputs)}"
|
||||
)
|
||||
|
||||
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
|
||||
torch.testing.assert_close(
|
||||
ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
|
||||
|
@ -942,6 +942,37 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
loaded_exported_program, (x,), skip_dynamic_shapes_check=True
|
||||
)
|
||||
|
||||
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
|
||||
"Unsupported FX nodes: {'call_function': ['aten.add_.Tensor']}. "
|
||||
"github issue: https://github.com/pytorch/pytorch/issues/114406"
|
||||
)
|
||||
def test_exported_program_as_input_lifting_buffers_mutation(self):
|
||||
for persistent in (True, False):
|
||||
|
||||
class CustomModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"my_buffer", torch.tensor(4.0), persistent=persistent
|
||||
)
|
||||
|
||||
def forward(self, x, b):
|
||||
output = x + b
|
||||
(
|
||||
self.my_buffer.add_(1.0) + 3.0
|
||||
) # Mutate buffer through in-place addition
|
||||
return output
|
||||
|
||||
inputs = (torch.rand((3, 3), dtype=torch.float32), torch.randn(3, 3))
|
||||
model = CustomModule()
|
||||
self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
|
||||
model, inputs, skip_dynamic_shapes_check=True
|
||||
)
|
||||
# Buffer will be mutated after the first iteration
|
||||
self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
|
||||
model, inputs, skip_dynamic_shapes_check=True
|
||||
)
|
||||
|
||||
|
||||
def _parameterized_class_attrs_and_values_with_fake_options():
|
||||
input_values = []
|
||||
|
@ -63,6 +63,10 @@ class TorchExport(exporter.FXGraphExtractor):
|
||||
# tensor, etc), we flatten the collection and register each element as output.
|
||||
options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())
|
||||
|
||||
options.fx_tracer.output_adapter.append_step(
|
||||
io_adapter.PrependParamsAndBuffersAotAutogradOutputStep(model)
|
||||
)
|
||||
|
||||
# Export FX graph to ONNX ModelProto.
|
||||
return self.pre_export_passes(options, model, model.graph_module, updated_model_args) # type: ignore[return-value]
|
||||
|
||||
|
@ -550,3 +550,39 @@ class PrependParamsAndBuffersAotAutogradInputStep(InputAdaptStep):
|
||||
if model_kwargs:
|
||||
return MergeKwargsIntoArgsInputStep().apply(updated_args, model_kwargs)
|
||||
return updated_args, {}
|
||||
|
||||
|
||||
class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep):
|
||||
"""Prepend model's mutated buffers to the user output.
|
||||
|
||||
:func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they
|
||||
must be added to the user output after the model is executed.
|
||||
|
||||
Args:
|
||||
model: The PyTorch model with mutated buffers.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch_export.ExportedProgram):
|
||||
assert isinstance(
|
||||
model, torch_export.ExportedProgram
|
||||
), "'model' must be a torch.export.ExportedProgram."
|
||||
self.model = model
|
||||
|
||||
def apply(self, model_outputs: Any) -> Sequence[Any]:
|
||||
"""Flatten the model outputs and validate the `SpecTree` output.
|
||||
|
||||
Args:
|
||||
model_outputs: The model outputs to flatten.
|
||||
|
||||
Returns:
|
||||
flattened_outputs: The flattened model outputs.
|
||||
"""
|
||||
|
||||
ordered_buffers = tuple(
|
||||
self.model.state_dict[name]
|
||||
for name in self.model.graph_signature.buffers_to_mutate.values()
|
||||
)
|
||||
|
||||
# NOTE: calling convention is first mutated buffers, then outputs args as model returned them.
|
||||
updated_outputs = (*ordered_buffers, *model_outputs)
|
||||
return updated_outputs
|
||||
|
Reference in New Issue
Block a user