Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

Commit: This hits multi-line logging strings
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98700
Approved by: https://github.com/voznesenskym

129 lines · 3.6 KiB · Python
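"""TorchDynamo backend that runs a captured FX graph through ONNX Runtime.

The graph module is exported to an ONNX file via torch.onnx.export, loaded into
an onnxruntime.InferenceSession, and executed through I/O binding so inputs and
outputs stay in preallocated torch tensors.
"""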
import importlib
import logging
import os
import tempfile

import torch

from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend

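# Map torch dtypes to the numpy dtypes that ONNX Runtime's I/O binding expects.
# If numpy is not installed, _np_dtype stays None and the backend asserts later.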
try:
    import numpy as np

    _np_dtype = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.float64: np.float64,
        torch.uint8: np.uint8,
        torch.int8: np.int8,
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.longlong,
        torch.bool: np.bool_,
    }

except ImportError:
    _np_dtype = None

log = logging.getLogger(__name__)

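# Choose an ONNX Runtime execution provider for the inputs' device, unless the
# ONNXRT_PROVIDER environment variable overrides the choice.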
def default_provider(device_type):
    if "ONNXRT_PROVIDER" in os.environ:
        return os.environ["ONNXRT_PROVIDER"]
    return {
        "cpu": "CPUExecutionProvider",
        "cuda": "CUDAExecutionProvider",
        # "TensorrtExecutionProvider" is another option
    }[device_type]

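# Report whether the onnxruntime package can be imported in this environment.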
def has_onnxruntime():
    try:
        importlib.import_module("onnxruntime")
        return True
    except ImportError:
        return False

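# Backend entry point registered with TorchDynamo under the name "onnxrt".
# Fake tensors are unsupported because the graph is run on real example inputs
# below to record output shapes, dtypes, and devices.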
@register_backend
@fake_tensor_unsupported
def onnxrt(gm, example_inputs, *, filename=None, provider=None):
    if filename is None:
        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
            return onnxrt(gm, example_inputs, filename=tmp.name)

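    # From here on a concrete filename is available: export the graph to it and
    # build an ONNX Runtime session around the exported model.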
    import onnxruntime  # type: ignore[import]

    assert _np_dtype, "requires numpy"

    device_type = device_from_inputs(example_inputs).type
    example_outputs = gm(*example_inputs)
    output_spec = [
        (o.shape, o.dtype, o.layout, o.device, o.requires_grad) for o in example_outputs
    ]
    input_names = [f"i{i}" for i in range(len(example_inputs))]
    output_names = [f"o{x}" for x in range(len(example_outputs))]

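    # Export the scripted graph module to an ONNX file on disk so that ONNX
    # Runtime can load it.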
    torch.onnx.export(
        torch.jit.script(gm),
        example_inputs,
        filename,
        input_names=input_names,
        output_names=output_names,
    )
    del example_inputs, example_outputs

    if provider is None:
        provider = default_provider(device_type)
    assert provider in onnxruntime.get_available_providers()
    session = onnxruntime.InferenceSession(filename, providers=[provider])

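    # The callable handed back to TorchDynamo: binds inputs and preallocated
    # outputs to the session and runs it with zero-copy I/O binding.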
    def _call(*initial_args):
        binding = session.io_binding()
        active_inputs = {inp.name for inp in session.get_inputs()}
        args = [a.contiguous() for a in initial_args]
        for name, value in zip(input_names, args):
            if name not in active_inputs:
                log.warning(
                    "input %s skipped as not found in onnx inference session", name
                )
                continue
            dev = value.device
            binding.bind_input(
                name,
                dev.type,
                dev.index or 0,
                _np_dtype[value.dtype],
                value.size(),
                value.data_ptr(),
            )
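        # Preallocate output tensors with the shapes/dtypes recorded from the
        # example run; ONNX Runtime writes results directly into their storage.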
        outputs = [
            torch.empty(
                shape,
                dtype=dtype,
                layout=layout,
                device=device,
                requires_grad=requires_grad,
            )
            for shape, dtype, layout, device, requires_grad in output_spec
        ]

        for name, value in zip(output_names, outputs):
            dev = value.device
            binding.bind_output(
                name,
                dev.type,
                dev.index or 0,
                _np_dtype[value.dtype],
                value.size(),
                value.data_ptr(),
            )
        session.run_with_iobinding(binding)
        if device_type == "cpu":
            binding.copy_outputs_to_cpu()
        return outputs

    return _call
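# Example usage (illustrative sketch, not part of the backend itself): once this
# module is imported and onnxruntime plus numpy are installed, the backend can be
# selected by name through torch.compile.
#
#   import torch
#
#   def fn(x, y):
#       return torch.relu(x + y)
#
#   compiled = torch.compile(fn, backend="onnxrt")
#   out = compiled(torch.randn(4, 8), torch.randn(4, 8))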