Allow fake models to run with ONNXProgram.__call__ (#122230)

In order for a fake model to run through the ONNXProgram.__call__
interface, we need to save the model to disk along with its external data
before executing it. This is what this PR implements.
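
As a rough sketch of the workflow this enables (the MLP module, tensor
shapes, and variable names below are illustrative only, not part of this PR):

    import torch

    class MLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.fc(x)

    # Export under fake mode: no real weights are materialized in memory.
    with torch.onnx.enable_fake_mode() as fake_context:
        fake_model = MLP()
        fake_input = torch.rand(1, 8)
        export_options = torch.onnx.ExportOptions(fake_context=fake_context)
        onnx_program = torch.onnx.dynamo_export(
            fake_model, fake_input, export_options=export_options
        )

    # With this PR, __call__ works on a fake-exported program, provided a
    # real model (or its state) is passed to source the initializers.
    real_model = MLP()
    outputs = onnx_program(torch.rand(1, 8), model_with_state_dict=real_model)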

An alternative would be for ONNXProgram.__call__ to detect that the model
was exported with fake mode and explicitly raise an exception when
ONNXProgram.__call__ is executed. The exception message would instruct
the user to call ONNXProgram.save and manually execute the model using
the ONNX runtime of their choice.
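
For comparison, the rejected alternative would have put these steps on the
user, roughly as follows (paths and names are again illustrative, and
onnxruntime stands in for the runtime of choice):

    import onnxruntime  # type: ignore[import]

    # Save the fake-exported program with real initializers, then run it manually.
    onnx_program.save("model.onnx", model_state=real_model.state_dict())
    onnx_input = onnx_program.adapt_torch_inputs_to_onnx(
        torch.rand(1, 8), model_with_state_dict=real_model
    )
    session = onnxruntime.InferenceSession(
        "model.onnx", providers=["CPUExecutionProvider"]
    )
    results = session.run(
        None,
        {arg.name: t.numpy() for arg, t in zip(session.get_inputs(), onnx_input)},
    )
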
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122230
Approved by: https://github.com/BowenBao
ghstack dependencies: #122196
Authored by Thiago Crepaldi on 2024-03-21 17:59:23 +00:00
Committed by PyTorch MergeBot
parent 4ba51bb2c4
commit c4486d3e88
2 changed files with 51 additions and 15 deletions

@@ -1095,7 +1095,14 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
             )
             assert len(ref_outputs) == len(ort_outputs)
             for ref_output, ort_output in zip(ref_outputs, ort_outputs):
                 torch.testing.assert_close(ref_output, torch.tensor(ort_output))
+            # Test ONNXProgram.__call__ interface
+            ort_outputs = onnx_program(
+                *args, model_with_state_dict=real_model, **kwargs
+            )
+            assert len(ref_outputs) == len(ort_outputs)
+            for ref_output, ort_output in zip(ref_outputs, ort_outputs):
+                torch.testing.assert_close(ref_output, torch.tensor(ort_output))

@@ -10,6 +10,7 @@ import io
 import logging
 import os
 import tempfile
+import warnings
 from collections import defaultdict
 from typing import (
@@ -705,17 +706,45 @@ class ONNXProgram:
         Returns:
             The model output as computed by ONNX Runtime
         """
-        import onnxruntime  # type: ignore[import]
+        # TODO: If ONNX used absolute paths on the initializers' external data files,
+        # users could call ONNXProgram.save and use ONNXProgram.__call__ without the internal save below
+        with contextlib.ExitStack() as stack:
+            # The model specified by the user takes precedence, when provided
+            model_with_state_dict = model_with_state_dict or self._model_torch
+            if self.fake_context:
+                tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
+                warnings.warn(
+                    "Cannot run model directly from `ONNXProgram` because"
+                    " the model was exported using `enable_fake_mode`."
+                    f" The model will be serialized to disk using a temporary folder ({tmpdir_path})"
+                    " to populate the model with initializers before execution."
+                )
+                # TODO: Revisit the need for `model_with_state_dict` to be a real model rather than just its state
+                onnx_model = os.path.join(tmpdir_path, "model.onnx")
+                if isinstance(model_with_state_dict, torch.nn.Module):
+                    model_state = model_with_state_dict.state_dict()
+                elif isinstance(model_with_state_dict, torch_export.ExportedProgram):
+                    model_state = model_with_state_dict.state_dict
+                else:
+                    model_state = None
+                self.save(
+                    onnx_model,
+                    model_state=model_state,
+                )
+            else:
+                onnx_model = self.model_proto.SerializeToString()  # type: ignore[assignment]
+            import onnxruntime  # type: ignore[import]
             onnx_input = self.adapt_torch_inputs_to_onnx(
                 *args, model_with_state_dict=model_with_state_dict, **kwargs
             )
             options = options or ONNXRuntimeOptions()
-        providers = options.execution_providers or onnxruntime.get_available_providers()
-        onnx_model = self.model_proto.SerializeToString()
+            providers = (
+                options.execution_providers or onnxruntime.get_available_providers()
+            )
             ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
             onnxruntime_input = {