Allow fake models to run with ONNXProgram.__call__ (#122230)
In order for a fake model to run through the ONNXProgram.__call__ interface, the model must be saved to disk, along with its external data, before it is executed. That is what this PR implements. An alternative would be for ONNXProgram.__call__ to detect that the model was exported with fake mode and explicitly raise an exception when it is called; the exception message would instruct the user to call ONNXProgram.save and manually execute the model using the ONNX runtime of their choice.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122230
Approved by: https://github.com/BowenBao
ghstack dependencies: #122196
Committed by: PyTorch MergeBot
Parent: 4ba51bb2c4
Commit: c4486d3e88
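To illustrate the workflow this change enables, here is a minimal sketch of exporting under fake mode and then running the result through ONNXProgram.__call__. The toy MLP module, its dimensions, and the variable names are hypothetical; torch.onnx.enable_fake_mode, torch.onnx.ExportOptions, and torch.onnx.dynamo_export are the existing fake-export entry points this PR builds on.

import torch

class MLP(torch.nn.Module):  # hypothetical toy model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Export under fake mode: parameters are fake tensors, so no real weights
# need to be materialized in memory during export.
with torch.onnx.enable_fake_mode() as fake_context:
    fake_model = MLP()
    fake_input = torch.randn(1, 4)
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
onnx_program = torch.onnx.dynamo_export(
    fake_model, fake_input, export_options=export_options
)

# With this PR, __call__ works on a fake-exported program: it first serializes
# the model (with initializers taken from the real model) to a temporary
# folder, then creates the ONNX Runtime session from that file.
real_model = MLP()
real_input = torch.randn(1, 4)
ort_outputs = onnx_program(real_input, model_with_state_dict=real_model)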
@@ -1095,7 +1095,14 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         )
         assert len(ref_outputs) == len(ort_outputs)
         for ref_output, ort_output in zip(ref_outputs, ort_outputs):
             torch.testing.assert_close(ref_output, torch.tensor(ort_output))
+
+        # Test ONNXProgram.__call__ interface
+        ort_outputs = onnx_program(
+            *args, model_with_state_dict=real_model, **kwargs
+        )
+        assert len(ref_outputs) == len(ort_outputs)
+        for ref_output, ort_output in zip(ref_outputs, ort_outputs):
+            torch.testing.assert_close(ref_output, torch.tensor(ort_output))
@@ -10,6 +10,7 @@ import io
 import logging
 import os
 
+import tempfile
 import warnings
 from collections import defaultdict
 from typing import (
@@ -705,25 +706,53 @@ class ONNXProgram:
         Returns:
             The model output as computed by ONNX Runtime
         """
-        import onnxruntime  # type: ignore[import]
-
-        # model specified by the user has precedence, when specified
-        model_with_state_dict = model_with_state_dict or self._model_torch
-
-        onnx_input = self.adapt_torch_inputs_to_onnx(
-            *args, model_with_state_dict=model_with_state_dict, **kwargs
-        )
-        options = options or ONNXRuntimeOptions()
-        providers = options.execution_providers or onnxruntime.get_available_providers()
-        onnx_model = self.model_proto.SerializeToString()
-        ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
-
-        onnxruntime_input = {
-            k.name: v.numpy(force=True)
-            for k, v in zip(ort_session.get_inputs(), onnx_input)
-        }
-
-        return ort_session.run(None, onnxruntime_input)
+        # TODO: If ONNX used absolute paths on the initializers' external data files,
+        # users could call ONNXProgram.save and use ONNXProgram.__call__ without the internal save below
+        with contextlib.ExitStack() as stack:
+            # model specified by the user has precedence, when specified
+            model_with_state_dict = model_with_state_dict or self._model_torch
+
+            if self.fake_context:
+                tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
+                warnings.warn(
+                    "Cannot run model directly from `ONNXProgram` because"
+                    " the model was exported using `enable_fake_mode`."
+                    f" The model will be serialized to disk using a temporary folder ({tmpdir_path})"
+                    " to populate the model with initializers before being executed."
+                )
+                # TODO: Revisit the need of `model_with_state_dict` being a real model and not just its state
+                onnx_model = os.path.join(tmpdir_path, "model.onnx")
+                if isinstance(model_with_state_dict, torch.nn.Module):
+                    model_state = model_with_state_dict.state_dict()
+                elif isinstance(model_with_state_dict, torch_export.ExportedProgram):
+                    model_state = model_with_state_dict.state_dict
+                else:
+                    model_state = None
+                self.save(
+                    onnx_model,
+                    model_state=model_state,
+                )
+            else:
+                onnx_model = self.model_proto.SerializeToString()  # type: ignore[assignment]
+
+            import onnxruntime  # type: ignore[import]
+
+            onnx_input = self.adapt_torch_inputs_to_onnx(
+                *args, model_with_state_dict=model_with_state_dict, **kwargs
+            )
+            options = options or ONNXRuntimeOptions()
+            providers = (
+                options.execution_providers or onnxruntime.get_available_providers()
+            )
+            ort_session = onnxruntime.InferenceSession(onnx_model, providers=providers)
+
+            onnxruntime_input = {
+                k.name: v.numpy(force=True)
+                for k, v in zip(ort_session.get_inputs(), onnx_input)
+            }
+
+            return ort_session.run(None, onnxruntime_input)
 
     @property
     def model_proto(self) -> onnx.ModelProto:  # type: ignore[name-defined]
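For comparison, the manual route mentioned in the commit message (calling ONNXProgram.save yourself and then running the model with the ONNX runtime of your choice) remains available. A minimal sketch, assuming a fake-exported onnx_program, a real model to supply the state dict, and a sample input x; the file name and variable names are invented here, while save, adapt_torch_inputs_to_onnx, and the ONNX Runtime calls are the APIs used in the diff above:

import onnxruntime  # type: ignore[import]

# Persist the fake-exported program together with real initializers; the
# external data is written alongside the .onnx file.
onnx_program.save("model.onnx", model_state=model.state_dict())

# Run the saved model manually with ONNX Runtime.
ort_session = onnxruntime.InferenceSession(
    "model.onnx", providers=onnxruntime.get_available_providers()
)
onnx_input = onnx_program.adapt_torch_inputs_to_onnx(x, model_with_state_dict=model)
onnxruntime_input = {
    k.name: v.numpy(force=True)
    for k, v in zip(ort_session.get_inputs(), onnx_input)
}
ort_outputs = ort_session.run(None, onnxruntime_input)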