beartype has served us well in identifying type errors and ensuring we call internal functions with the correct arguments (thanks!). However, the value of having beartype is diminished because of the following:

1. When beartype improved its support for Dict[] type checking, it discovered typing mistakes in some functions that were previously uncaught. This caused the exporter to fail with newer versions of beartype where it used to succeed. Since we cannot fix PyTorch and release a new version just because of this, it creates confusion for users who have beartype in their environment and use torch.onnx.
2. beartype adds an additional call line to the traceback, which makes the already deep dynamo stack even larger, hurting readability when users diagnose errors from the traceback.
3. Since the typing annotations need to be evaluated at runtime, we cannot use new syntaxes like `|` while we maintain compatibility with Python 3.8, and we don't want to wait for PyTorch to adopt py310 as the lowest supported Python before using the new typing syntaxes (a short example follows this note).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130484
Approved by: https://github.com/titaiwangms
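A minimal sketch of point 3 (not part of the original commit message): on Python 3.8, a PEP 604 union such as `int | None` can only be written as a string annotation, and any tool that evaluates annotations at runtime (beartype, `typing.get_type_hints`) raises a TypeError on it:

    import typing

    def f(x: "int | None") -> int:  # the definition succeeds because the annotation is a string
        return x or 0

    # On Python 3.8 the call below raises
    # "TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'";
    # on Python 3.10+ the same call succeeds.
    typing.get_type_hints(f)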
# mypy: allow-untyped-defs
import dataclasses
import importlib
import logging
import os

from typing import (
    Any,
    Callable,
    Dict,
    Final,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeAlias

import torch
import torch._C
import torch._ops
import torch._prims.executor
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx._compatibility import compatibility
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree

if TYPE_CHECKING:
    import onnx

try:
    # Use try-except to initialize package-dependent global variables.
    import onnxruntime  # type: ignore[import]
    from onnxruntime.capi import _pybind_state as ORTC  # type: ignore[import]

    # This is not used directly in DORT but is needed by the underlying exporter,
    # so we still need to check if it exists.
    importlib.import_module("onnxscript")

    import torch.onnx
    import torch.onnx._internal
    import torch.onnx._internal.diagnostics
    import torch.onnx._internal.exporter
    import torch.onnx._internal.fx.decomposition_table
    import torch.onnx._internal.fx.passes
    from torch.onnx._internal.fx import fx_onnx_interpreter
    from torch.onnx._internal.fx.type_utils import (
        _TORCH_DTYPE_TO_NUMPY_DTYPE,
        _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
        from_python_type_to_onnx_tensor_element_type,
    )

    _SUPPORT_ONNXRT = True
except ImportError:
    _SUPPORT_ONNXRT = False

__all__ = [
    "is_onnxrt_backend_supported",
    "torch_compile_backend",
    "OrtExecutionProvider",
    "OrtBackendOptions",
    "OrtBackend",
]


def is_onnxrt_backend_supported() -> bool:
    """Returns ``True`` if ONNX Runtime dependencies are installed and usable
    to support TorchDynamo backend integration; ``False`` otherwise.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> if torch.onnx.is_onnxrt_backend_supported():
        ...     @torch.compile(backend="onnxrt")
        ...     def f(x):
        ...             return x * x
        ...     print(f(torch.randn(10)))
        ... else:
        ...     print("pip install onnx onnxscript onnxruntime")
        ...
    """
    return _SUPPORT_ONNXRT


_dumped_onnx_model: Dict[str, int] = {}


def _dump_onnx_model(
    model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None
) -> str:
    """Stores the onnx model into a file.
    The name is "{ONNXRT_DUMP_PATH}{N}.onnx"
    where *N* is the number of files already stored with
    this prefix.
    If graph_module is not None, the graph is stored as a string with
    the same filename except the extension (.txt).
    """
    prefix = os.environ.get("ONNXRT_DUMP_PATH", None)
    if not prefix:
        return ""
    n = _dumped_onnx_model.get(prefix, -1) + 1
    filename = f"{prefix}{n}.onnx"
    with open(filename, "wb") as f:
        f.write(model_string)
    _dumped_onnx_model[prefix] = n
    if graph_module is not None:
        filename_txt = f"{prefix}{n}.txt"
        with open(filename_txt, "w", encoding="utf-8") as f:
            f.write(str(graph_module.graph))
    return filename
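
# A minimal usage sketch (not part of the upstream file; `model_proto` and `gm`
# are hypothetical placeholders for an onnx.ModelProto and a torch.fx.GraphModule):
# with ONNXRT_DUMP_PATH="dumped/model_", the first call writes "dumped/model_0.onnx"
# (plus "dumped/model_0.txt" because a graph_module is given), the second call
# writes "dumped/model_1.onnx", and so on.
#
#     >>> os.environ["ONNXRT_DUMP_PATH"] = "dumped/model_"
#     >>> _dump_onnx_model(model_proto.SerializeToString(), graph_module=gm)
#     'dumped/model_0.onnx'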


def _infer_default_eps() -> Sequence[str]:
    # TODO: select a good default based on the capabilities of the host
    # e.g. DML on Windows, etc.
    return ["CPUExecutionProvider"]


def _nvtx_range_push(name: str):
    """If PyTorch is installed with CUDA support, this starts NVTX range.

    Check torch.cuda.nvtx.range_push's document for more details.
    """
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_push(name)


def _nvtx_range_pop():
    """If PyTorch is installed with CUDA support, this terminates NVTX range.

    Check torch.cuda.nvtx.range_pop's document for more details.
    """
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_pop()


def _get_ort_device_type(device_type: str):
    if device_type == "cuda":
        return ORTC.OrtDevice.cuda()
    if device_type == "cpu":
        return ORTC.OrtDevice.cpu()
    # ort pytorch device is mapped to NPU OrtDevice type
    if device_type == "maia":
        return ORTC.OrtDevice.npu()
    raise ValueError("Unsupported device type: " + device_type)


logger = logging.getLogger(__name__)
# Uncomment the following lines to print out development info.
# logging.basicConfig(level=logging.WARNING)
# logger.setLevel(logging.WARNING)


class OrtOperatorSupport(OperatorSupport):
    """Operator support for ONNXRuntime backend.

    It has two levels of support decision. One is via support_dict and the other one
    is via extra_support_dict. The logic of using support_dict is implemented in
    OrtOperatorSupport and extra_support_dict is used by OperatorSupport.is_node_supported.
    """

    def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
        # Use extra_support_dict[op_name] = None to indicate
        # we support op_name with all input types. Otherwise,
        # see support_dict (type: SupportDict) in operator_support.py
        # for specifying supported types.
        super().__init__(extra_support_dict)
        self._onnx_support_dict = support_dict

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        # OperatorSupport.is_node_supported returns True for non-callable nodes.
        # Since ORT can't execute them, we return False here to override the base
        # behavior.
        if node.op not in CALLABLE_NODE_OPS:
            return False
        # This is the only place to decide if an aten op is supported.
        if node.op == "call_function" and node.target in self._onnx_support_dict:
            logger.info(
                "support_dict supports node.target: %s (type: %s)",
                node.target,
                type(node.target),
            )
            return True
        # If node.target is not in support_dict, we still want to check if torch.jit.script
        # can convert it to an ONNX equivalent. Let's use the base mechanism to do this.
        # See extra_support_dict for supported ops.
        if super().is_node_supported(submodules, node):
            logger.info(
                "extra_support_dict supports node.target: %s (type: %s)",
                node.target,
                type(node.target),
            )
            return True
        logger.warning(
            "support_dict and extra_support_dict don't support node.target: %s (type: %s)",
            node.target,
            type(node.target),
        )
        return False


def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
    """
    In torch.fx.Graph, placeholder is a special assignment node. If it's not
    executed at the beginning, it could overwrite values computed by upstream
    nodes.
    """

    graph = graph_module.graph
    placeholders = []
    first_not_placeholder = None
    for node in graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)
        if first_not_placeholder is None and node.op != "placeholder":
            first_not_placeholder = node
    if first_not_placeholder is None:
        return
    for placeholder in placeholders:
        first_not_placeholder.prepend(placeholder)
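
# Illustrative sketch (not from the upstream file): if a graph's node order is
# [add, placeholder_x, output], this pass prepends the placeholder so the order
# becomes [placeholder_x, add, output]; a graph that already starts with all of
# its placeholders is left unchanged, e.g.
#
#     >>> gm = torch.fx.symbolic_trace(lambda x: x + 1)
#     >>> _move_placeholder_to_front(gm)  # no-op: the placeholder is already first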


def _infer_ep_from_device(*args) -> Tuple[str, ...]:
    """Return execution providers inferred from the device (i.e., GPU or CPU) of each argument."""
    eps = []
    for arg in args:
        if hasattr(arg, "device"):
            device = arg.device
            if device.type == "cuda":
                eps.append("CUDAExecutionProvider")
            elif device.type == "cpu":
                eps.append("CPUExecutionProvider")
    return tuple(eps)
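
# Illustrative sketch (assumes a CUDA-enabled build; not part of the upstream
# file): arguments without a ``.device`` attribute (e.g., Python ints) are
# skipped, so
#
#     >>> _infer_ep_from_device(torch.randn(2, device="cuda"), 3, torch.randn(2))
#     ('CUDAExecutionProvider', 'CPUExecutionProvider')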


def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]:
    placeholders = []
    for node in graph_module.graph.nodes:
        if node.op == "placeholder":
            if hasattr(node, "meta") and "val" in node.meta:
                assert isinstance(node.meta["val"], torch.Tensor)
            placeholders.append(node)
    return tuple(placeholders)


def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any:
    """Collect "val" fields from outputs metadata in this torch.fx.GraphModule."""
    for node in graph_module.graph.nodes:
        if node.op == "output":
            # Output node is unique. Let's retrieve output values from
            # this node's input list. And then just return.
            return node.args[0]
    raise ValueError("No output node found in this torch.fx.GraphModule.")


def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
    """Return execution providers inferred from the devices (i.e., GPU or CPU) of the outputs of this torch.fx.GraphModule."""
    flattened_output_args, _ = _pytree.tree_flatten(
        _extract_graph_module_outputs(graph_module)
    )
    # Output arguments with example value (type: torch.Tensor) in the `graph_module`.
    selected_output_args = [
        output_arg.meta["val"]
        for output_arg in flattened_output_args
        # output_arg must have a tensor for its device information.
        # Otherwise, skip it.
        if (hasattr(output_arg, "meta") and "val" in output_arg.meta)
    ]
    return _infer_ep_from_device(*selected_output_args)


def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
    """Sort execution providers in eps based on pre-set priority."""

    def get_execution_provider_priority(ep: str) -> int:
        if ep == "CPUExecutionProvider":
            # Lowest priority.
            return 2
        if ep == "CUDAExecutionProvider":
            # Higher priority than CPU but lower than
            # other specialized EPs.
            return 1
        # Highest priority.
        return 0

    unique_eps = set(eps)
    return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))


def _get_onnx_devices(
    values: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ]
) -> Tuple["ORTC.OrtDevice", ...]:
    def _device_id_or_zero(device_id: int) -> int:
        return device_id or 0

    def _map_tensor_or_sym_to_device(
        value: Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
    ) -> "ORTC.OrtDevice":
        if isinstance(value, torch.Tensor):
            return ORTC.OrtDevice(
                _get_ort_device_type(value.device.type),
                ORTC.OrtDevice.default_memory(),
                _device_id_or_zero(value.device.index),
            )
        elif isinstance(
            value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool)
        ):
            return ORTC.OrtDevice(
                _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0
            )
        else:
            raise ValueError("Unsupported value type: " + str(type(value)))

    if len(values) > 0:
        ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values)
        return ort_devices
    else:
        return (_map_tensor_or_sym_to_device(1),)
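
# Illustrative sketch (assumes a CUDA-enabled build; not part of the upstream
# file): tensors map to OrtDevice objects matching their torch device, while
# Python scalars and SymInt/SymFloat/SymBool values are placed on CPU device 0;
# an empty tuple falls back to a single CPU device.
#
#     >>> devices = _get_onnx_devices((torch.randn(2, device="cuda:1"), 3))
#     >>> # devices[0] is a CUDA OrtDevice with index 1; devices[1] is a CPU OrtDevice with index 0.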


def _get_ortvalues_from_torch_tensors(
    tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
) -> Tuple[torch.Tensor, ...]:
    ortvalues = ORTC.OrtValueVector()
    ortvalues.reserve(len(tensors))
    dtypes = []
    shapes = []
    data_ptrs = []

    for tensor in tensors:
        dtypes.append(_TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
        shapes.append(tensor.size())
        data_ptrs.append(tensor.data_ptr())
    ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices)
    return ortvalues


def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor:
    if tensor.is_sparse:
        raise ValueError("sparse tensor is not yet supported.")
    out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device)
    return out


def _adjust_scalar_from_fx_to_onnx(
    dynamo_value: Union[
        torch.Tensor,
        int,
        float,
        bool,
    ],
    value_info: "onnx.ValueInfoProto",  # type: ignore[name-defined]
) -> torch.Tensor:
    """Helper function to wrap PyTorch variables as torch.Tensor"""
    if (
        isinstance(dynamo_value, torch.Tensor)
        and len(value_info.type.tensor_type.shape.dim) == 0
        and dynamo_value.shape == (1,)
    ):
        # ONNX expects a scalar with empty shape.
        # In contrast, PyTorch usually allows implicit
        # conversion between shape=() and shape=(1,).
        #
        # Below, PyTorch's shape (1,) is reshaped to ().
        return torch.squeeze(dynamo_value)
    elif isinstance(dynamo_value, int):
        return torch.tensor(dynamo_value, dtype=torch.int64)
    elif isinstance(dynamo_value, float):
        return torch.tensor(dynamo_value, dtype=torch.float32)
    elif isinstance(dynamo_value, bool):
        return torch.tensor(dynamo_value, dtype=torch.bool)
    else:
        assert isinstance(dynamo_value, torch.Tensor)
        return dynamo_value.contiguous()
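
# Illustrative sketch (not from the upstream file; ``value_info_scalar`` is a
# hypothetical onnx.ValueInfoProto whose tensor type has a rank-0 shape):
#
#     >>> _adjust_scalar_from_fx_to_onnx(torch.tensor([1.0]), value_info_scalar).shape
#     torch.Size([])
#     >>> _adjust_scalar_from_fx_to_onnx(3, value_info_scalar)
#     tensor(3)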


def _adjust_scalar_from_onnx_to_fx(
    tensor: torch.Tensor,
    prim_value: Union[
        torch.Tensor,
        torch.SymInt,
        int,
        torch.SymFloat,
        float,
        torch.SymBool,
        bool,
    ],
) -> Union[torch.Tensor, int, float, bool,]:
    """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
    assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
    if isinstance(
        prim_value,
        (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool),
    ):
        # Convert tensor back to scalar to match Dynamo's expectation.
        return tensor.item()
    return tensor


def _run_onnx_session_with_ortvaluevector(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    _nvtx_range_push("contiguous")
    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    _nvtx_range_pop()

    _nvtx_range_push("push_back_batch")
    ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices)

    # Pre-allocate output PyTorch tensors and use buffers affined to the torch device
    # for the output OrtValues. Because those output OrtValues are not allocated and
    # owned by ORT, there is no need to convert them back to torch tensors and
    # transfer ownership.
    if preallocate_output:
        pth_outputs = tuple(
            _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs
        )
        ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices)
    else:
        ort_outputs = ORTC.OrtValueVector()
    _nvtx_range_pop()

    _nvtx_range_push("run_with_ortvaluevector")
    run_options = onnxruntime.RunOptions()
    run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
    sess.run_with_ortvaluevector(
        run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices
    )
    _nvtx_range_pop()

    # Post-processing step:
    #  wrap ORT's outputs to the schema represented by
    #  `prim_output` (obtained by running the original
    #  torch.fx.GraphModule).
    if preallocate_output:
        # Profile the ORT-to-PyTorch type cast below
        _nvtx_range_push("after run_with_ortvaluevector")
        # Outputs are stored on pre-allocated torch.Tensors' memory,
        # so this case doesn't need to convert ORTValue to torch.Tensor.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs
    else:
        # Profile the two ORT-to-PyTorch type casts below
        _nvtx_range_push("after run_with_ortvaluevector")
        # Map ORTValue to torch.Tensor.
        pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(
            ort_outputs
        )
        # Change some torch.Tensor to int, float, bool.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs


def _run_onnx_session_with_fetch(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    feed = {
        name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy())
        for name, tensor in zip(input_names, inputs)
    }
    ort_outputs = sess.run(output_names, feed)
    pth_outputs = tuple(
        _adjust_scalar_from_onnx_to_fx(
            torch.from_numpy(value),
            prim_output,
        )
        for value, prim_output in zip(ort_outputs, normalized_prim_outputs)
    )
    return pth_outputs


class OrtExecutionInfoPerSession:
    """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession"""

    def __init__(
        self,
        session: "onnxruntime.InferenceSession",
        input_names: Tuple[str, ...],
        input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        output_names: Tuple[str, ...],
        output_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        input_devices: Tuple["ORTC.OrtDevice", ...],
        output_devices: Tuple["ORTC.OrtDevice", ...],
        example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
    ):
        # Carrier of ONNX model and its executor.
        self.session: onnxruntime.InferenceSession = session
        # For the ONNX model stored in self.session, self.input_names[i] is the
        # name of the i-th positional input.
        self.input_names: Tuple[str, ...] = input_names
        # self.input_name[i]'s type information is stored in self.input_value_infos[i].
        self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos  # type: ignore[name-defined]
        # Similar to self.input_names, but for outputs.
        self.output_names: Tuple[str, ...] = output_names
        # Similar to self.input_value_infos but for outputs.
        self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos  # type: ignore[name-defined]
        # For the ONNX model stored in self.session, self.input_devices[i] is the
        # i-th positional input's device.
        self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices
        # Similar to self.input_devices, but for outputs.
        self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
        # These are the outputs of executing the original torch.fx.GraphModule with example inputs
        # (i.e., args passed into OrtBackend._ort_acclerated_call).
        self.example_outputs: Union[
            Tuple[torch.Tensor, ...], torch.Tensor
        ] = example_outputs

    def is_supported(self, *args):
        # Compare the args and the input schema in ONNX model and
        # return the first match.
        if len(args) != len(self.input_value_infos):
            return False
        for arg, value_info in zip(args, self.input_value_infos):
            if not isinstance(arg, (torch.Tensor, float, int)):
                return False

            # Check Python scalars such as int, float, and bool.
            if isinstance(arg, (int, float, bool)):
                # Map, e.g., float to onnx.TensorProto.FLOAT.
                onnx_dtype = from_python_type_to_onnx_tensor_element_type(type(arg))
                if onnx_dtype != value_info.type.tensor_type.elem_type:
                    return False
                if len(value_info.type.tensor_type.shape.dim) != 0:
                    return False
                continue

            # Check tensor.
            onnx_dtype = _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[arg.dtype]
            if onnx_dtype != value_info.type.tensor_type.elem_type:
                return False
            for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim):
                if isinstance(dim, int) and (
                    onnx_dim.dim_value == dim or onnx_dim.dim_param
                ):
                    continue
                elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
                    continue
                else:
                    return False
        return True
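
# Illustrative sketch (not from the upstream file; ``info`` is a hypothetical
# OrtExecutionInfoPerSession for a model with one input of ONNX shape
# ("batch", 8), i.e. a symbolic dim_param followed by a static dimension):
#
#     >>> info.is_supported(torch.randn(5, 8))  # dim_param matches any size -> True
#     >>> info.is_supported(torch.randn(5, 9))  # static dim mismatch -> False
#     >>> info.is_supported(torch.randn(5, 8), torch.randn(1))  # arity mismatch -> False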


@dataclasses.dataclass
class OrtExecutionInfoForAllGraphModules:
    def __init__(self):
        # All sessions (and their related information) created by exporting the same GraphModule
        # with different inputs.
        self.execution_info_per_graph_module: Dict[
            torch.fx.GraphModule, List[OrtExecutionInfoPerSession]
        ] = {}

    def search_reusable_session_execution_info(
        self, graph_module: torch.fx.GraphModule, *args
    ):
        if graph_module not in self.execution_info_per_graph_module:
            return None
        # All execution information for ONNX models exported from the same `graph_module`
        # with different inputs.
        candidates = self.execution_info_per_graph_module[graph_module]

        for candidate in candidates:
            if candidate.is_supported(*args):
                # Returns the first session that accepts this input schema.
                return candidate
        # No reusable session found.
        return None

    def cache_session_execution_info(
        self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession
    ):
        if graph_module not in self.execution_info_per_graph_module:
            self.execution_info_per_graph_module[graph_module] = [info]
        else:
            self.execution_info_per_graph_module[graph_module].append(info)


OrtExecutionProvider: TypeAlias = Union[str, Tuple[str, Mapping[str, Any]]]
"""Either the name of an ONNX Runtime execution provider as a string or
a 2-tuple of the name and a dictionary of execution provider options.

Examples::

    >>> "CPUExecutionProvider"

    >>> ("CUDAExecutionProvider", {"device_id": 3})

"""


@dataclasses.dataclass(frozen=True)
@compatibility(is_backward_compatible=False)
class OrtBackendOptions:
    """Options for constructing an ``OrtBackend``, the ONNX Runtime
    backend (``"onnxrt"``) for ``torch.compile``.

    Example::

        >>> @torch.compile(
        ...     backend="onnxrt",
        ...     options=torch.onnx._OrtBackendOptions(...),
        ... )
        ... def ort_function(x):
        ...     return x ** x
    """

    preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """An optional sequence of execution providers to be prioritized ahead of any
    execution providers that may be inferred (see ``infer_execution_providers``).
    """

    infer_execution_providers: bool = True
    """Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph."""

    default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """The default fallback execution providers. If not specified, one will be
    selected based on the host environment (most likely ``"CPUExecutionProvider"``).
    """

    # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession
    # in order to avoid internal allocation of output buffers in InferenceSession.
    # If the output ortvalue returned from InferenceSession is allocated internally,
    # it needs to be converted to a torch Tensor for return, and the torch Tensor should hold the ownership.
    # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor
    # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device.
    # This can be avoided by pre-allocating output buffers with the custom aten allocator
    # and letting InferenceSession use those preallocated buffers without taking ownership of them.
    # TODO(wschin): Make this an inference-session-level flag.
    # See https://github.com/pytorch/pytorch/issues/106869.
    preallocate_output: bool = False
    """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side."""

    use_aot_autograd: bool = True
    """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend
    to support training (i.e., backward graphs are also sent to ``OrtBackend``).

    Symbolic execution is used to capture the forward pass and backward passes as a single graph.
    Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used
    to split the entire graph into forward sub-graph and backward sub-graph. Finally, both
    sub-graphs are compiled by ``OrtBackend``.
    """

    export_options: Optional["torch.onnx.ExportOptions"] = None
    """Options for the TorchDynamo-based ONNX exporter used by the ``OrtBackend``."""

    ort_session_options: Optional["onnxruntime.SessionOptions"] = None
    """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``."""

    pre_ort_model_transforms: Optional[  # type: ignore[name-defined]
        Sequence[Callable[["onnx.ModelProto"], None]]
    ] = None
    """A list of graph transforms to be applied to the ONNX model before it
    is fed to ONNXRuntime's InferenceSession."""
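
# Illustrative sketch of wiring these options through ``torch.compile`` (not
# part of the upstream file; ``my_model`` is a hypothetical torch.nn.Module and
# the field values below are arbitrary examples, not recommended defaults):
#
#     >>> opts = OrtBackendOptions(
#     ...     preferred_execution_providers=[("CUDAExecutionProvider", {"device_id": 0})],
#     ...     use_aot_autograd=True,
#     ... )
#     >>> compiled = torch.compile(my_model, backend="onnxrt", options=opts)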


@compatibility(is_backward_compatible=False)
class OrtBackend:
    """A backend that compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls.

    The compiler entry point is OrtBackend.compile, which
        1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported
           sub-graphs.
        2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call.
        3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
    """

    def __init__(self, options: Optional[OrtBackendOptions] = None):
        self._options: Final = OrtBackendOptions() if options is None else options

        # options.export_options contains information shared between exporter and DORT.
        # For example, they should use the same decomposition table when
        #  1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py)
        #  2. calling exporter's API to convert `torch.fx.GraphModule` to ONNX model
        #     (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
        #
        # Convert the user-facing option to the internal option used by the ONNX exporter
        # to access required information.
        # Some useful fields:
        # - Decomposition table for decomposing FX operators in exporter is
        #   self._resolved_onnx_exporter_options.decomposition_table.
        # - self._resolved_onnx_exporter_options.onnx_registry records what
        #   aten/prim ops are supported by exporter and their exporters (type: callable).
        self._resolved_onnx_exporter_options = (
            torch.onnx._internal.exporter.ResolvedExportOptions(
                torch.onnx.ExportOptions()
                if self._options.export_options is None
                else self._options.export_options
            )
        )

        #  Given DORT's computation flow:
        #   1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators
        #      and send them to DORT.
        #   2. Then, DORT exports the selected sub-graphs into ONNX.
        #   3. Finally DORT calls ORT to do the computation.
        #  OrtOperatorSupport and create_onnx_friendly_decomposition_table(...)
        #  must use the same support_dict. If the support_dict here contains something not
        #  supported by the exporter, the exporter will fail in step 2 since the selected graphs may
        #  contain unsupported operators such as aten::_who_you_are.
        #  This restriction is automatically satisfied since DORT and the exporter share the same
        #  self._resolved_onnx_exporter_options.
        support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
            self._resolved_onnx_exporter_options.onnx_registry
        )

        extra_support_dict: Dict[str, Any] = {
            "getattr": None,
            # To send operator.getitem to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.getitem": None,
            # To send operator.mul to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.mul": None,
            "_operator.add": None,
            "_operator.sub": None,
        }

        self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
        # TODO(wschin): this is a naive implementation of cache without proper guard.
        # See https://github.com/pytorch/pytorch/issues/106868.
        self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
        # Conceptually, this field is a 2-layer dictionary
        #   GraphModule 0
        #     ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 1
        #     ...
        #   GraphModule 1
        #     ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 3
        #     ...
        #   ...
        # , which caches all previous compilation results so that we can reuse them.
        # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs
        # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different
        # graphs captured by Dynamo and sent to OrtBackend.compile.
        self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()

        self._assert_allclose_to_baseline = False

        self.execution_count = 0

        # Function which invokes ORT to do the real computation.
        self.run = (
            _run_onnx_session_with_ortvaluevector
            if hasattr(ORTC.OrtValueVector, "push_back_batch")
            else _run_onnx_session_with_fetch
        )

    def _select_eps(
        self, graph_module: torch.fx.GraphModule, *args
    ) -> Sequence[Tuple[str, Mapping[str, Any]]]:
        inferred_eps: Tuple[str, ...] = ()
        if self._options.infer_execution_providers:
            if eps_from_args := _infer_ep_from_device(*args):
                # If the user feeds a CUDA tensor as an input argument,
                # we want to use the CUDA EP.
                # Thus, `eps_from_args` (deduced from input arguments)
                # has the highest priority.
                inferred_eps = eps_from_args
            elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module):
                # If there is no EP in input arguments, we deduce EP from
                # graph_module's outputs. Those outputs may come from
                # FakeTensorProp or Dynamo's built-in symbolic shape inference.
                inferred_eps = eps_from_graph_module

        selected_eps = []

        for ep in (
            *(self._options.preferred_execution_providers or []),
            *_sort_eps(inferred_eps),
            *(self._options.default_execution_providers or _infer_default_eps()),
        ):
            if isinstance(ep, str):
                ep = (ep, {})
            elif isinstance(ep, tuple) and ep[1] is None:
                ep = (ep[0], {})
            if ep is not None and ep not in selected_eps:
                selected_eps.append(ep)

        return selected_eps
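
    # Illustrative sketch of the resulting priority order (not from the upstream
    # file; ``backend`` and ``graph_module`` are hypothetical): with no preferred
    # or default EPs configured and a CUDA input tensor, the inferred CUDA EP
    # comes first, followed by the fallback CPU EP, with duplicates removed:
    #
    #     >>> backend._select_eps(graph_module, torch.randn(2, device="cuda"))
    #     [('CUDAExecutionProvider', {}), ('CPUExecutionProvider', {})]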
 | 
						|
 | 
						|
    def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
 | 
						|
        """This function replaces GraphModule._wrapped_call in compiled model.
 | 
						|
 | 
						|
        The _wrapped_call is the underlying implementation of forward method. Replacing
 | 
						|
        it means we delegate the computation to _ort_acclerated_call and therefore
 | 
						|
        onnxruntime.InferenceSession.
 | 
						|
        """
 | 
						|
        cached_execution_info_per_session = (
 | 
						|
            self._all_ort_execution_info.search_reusable_session_execution_info(
 | 
						|
                graph_module, *args
 | 
						|
            )
 | 
						|
        )
 | 
						|
        if cached_execution_info_per_session:
 | 
						|
            onnx_session = cached_execution_info_per_session.session
 | 
						|
            input_names = cached_execution_info_per_session.input_names
 | 
						|
            output_names = cached_execution_info_per_session.output_names
 | 
						|
            input_value_infos = cached_execution_info_per_session.input_value_infos
 | 
						|
            output_value_infos = cached_execution_info_per_session.output_value_infos
 | 
						|
            input_devices = cached_execution_info_per_session.input_devices
 | 
						|
            output_devices = cached_execution_info_per_session.output_devices
 | 
						|
            prim_outputs = cached_execution_info_per_session.example_outputs
 | 
						|
        else:
 | 
						|
            # It's first time seeing such as graph. Let's make a new session
 | 
						|
            # (type: onnxruntime.InferenceSession) for it.
 | 
						|
 | 
						|
            graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
 | 
						|
                self._resolved_onnx_exporter_options.diagnostic_context,
 | 
						|
                graph_module,
 | 
						|
            ).run()
 | 
						|
            # Generate reference outputs. They are used to indicate output
 | 
						|
            # tensors' types and devices when calling ORT.
 | 
						|
            #
 | 
						|
            # WARNING: The downstream code should not change prim_outputs and
 | 
						|
            # this backend should always produces output with schema identical to prim_outputs'.
 | 
						|
 | 
						|
            if self._resolved_onnx_exporter_options.dynamic_shapes:
 | 
						|
                # No pre-allocation when dynamic shape is enabled.
 | 
						|
                self.preallocate_output = False
 | 
						|
                extracted_outputs = _extract_graph_module_outputs(graph_module)
 | 
						|
 | 
						|
                def maybe_map_to_meta_val(value):
 | 
						|
                    if hasattr(value, "meta") and "val" in value.meta:
 | 
						|
                        # Select outputs with "val" information. Without "val",
 | 
						|
                        # it's not possible access output_arg.meta["val"].device.
 | 
						|
                        return value.meta["val"]
 | 
						|
                    else:
 | 
						|
                        return value
 | 
						|
 | 
						|
                prim_outputs = _pytree.tree_map(
 | 
						|
                    maybe_map_to_meta_val, extracted_outputs
 | 
						|
                )
 | 
						|
            else:
 | 
						|
                try:
 | 
						|
                    prim_outputs = FakeTensorProp(graph_module).propagate(
 | 
						|
                        *args, **kwargs
 | 
						|
                    )
 | 
						|
                except Exception:
 | 
						|
                    logger.warning("FakeTensorProb failed for %s", graph_module)
 | 
						|
                    # When FakeTensorProp fails, it is not possible to preallocate output buffers
 | 
						|
                    # because the output shapes are not inferred.
 | 
						|
                    self.preallocate_output = False
 | 
						|
 | 
						|
                    # rethrow FakeTensorProb failure because it is not yet currently handled.
 | 
						|
                    raise
 | 
						|
 | 
						|
            # Create the object to iterate through the nodes in graph one-by-one
 | 
						|
            # and calls the corresponding ONNX exporter for each node.
 | 
						|
            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
 | 
						|
                diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
 | 
						|
            )
 | 
						|
            # Cast FX variables if they will result schema-mismatch when searching
 | 
						|
            # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
 | 
						|
            # but ONNX expects add(double_tensor, double_tensor).
 | 
						|
            graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
 | 
						|
                self._resolved_onnx_exporter_options.diagnostic_context, graph_module
 | 
						|
            ).run()
 | 
						|
            # Start the per-node exporting process. It's conceptually a for loop
 | 
						|
            # scanning through the nodes in the graph.
 | 
						|
            exported = fx_interpreter.run(
 | 
						|
                fx_graph_module=graph_module,
 | 
						|
                onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
 | 
						|
                op_level_debug=self._resolved_onnx_exporter_options.op_level_debug,
 | 
						|
            )
 | 
						|
            # Convert the exported result to ONNX ModelProto.
 | 
						|
            onnx_model = exported.to_model_proto(
 | 
						|
                opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
 | 
						|
            )
 | 
						|
 | 
						|
            try:
 | 
						|
                from onnxscript import optimizer  # type: ignore[import]
 | 
						|
                from onnxscript.rewriter import (  # type: ignore[import]
 | 
						|
                    onnxruntime as ort_rewriter,  # type: ignore[import]
 | 
						|
                )
 | 
						|
 | 
						|
                onnx_model = optimizer.optimize(onnx_model)
 | 
						|
                onnx_model = ort_rewriter.rewrite(onnx_model)
 | 
						|
            except ImportError:
 | 
						|
                logger.warning(
 | 
						|
                    "ONNXScript optimizer is not available. Skipping optimization. "
 | 
						|
                    "Please `pip install onnxscript -U` to enable post-export optimization."
 | 
						|
                )
 | 
						|
 | 
						|
            # Modify ONNX model using pre-registered graph transforms.
 | 
						|
            # They are in-place modifications for avoiding unnecessary
 | 
						|
            # copy of ONNX initializers.
 | 
						|
            if self._options.pre_ort_model_transforms:
 | 
						|
                for transform in self._options.pre_ort_model_transforms:
 | 
						|
                    transform(onnx_model)
 | 
						|
 | 
						|
            onnx_model_bytes = onnx_model.SerializeToString()
 | 
						|
            if os.environ.get("ONNXRT_DUMP_PATH", None):
 | 
						|
                # If not empty, environment variable ONNXRT_DUMP_PATH defined the path
 | 
						|
                # where generated onnx files should be stored.
 | 
						|
                # This module keeps a global variables keeping track of the
 | 
						|
                # stored models.
 | 
						|
                # If ONNXRT_DUMP_PATH="dumped/dumped_model_"
 | 
						|
                # The first file name will be 'dumped/dumped_model_0.onnx'.
 | 
						|
                # For every dumped model, a text file 'dumped/dumped_model_0.txt'
 | 
						|
                # is created as well to contain the string representing the graph_module.
 | 
						|
                _dump_onnx_model(onnx_model_bytes, graph_module=graph_module)
 | 
						|
 | 
						|
            # Initialize a ORT session to execute this ONNX model.
 | 
						|
            # Note that TorchDynamo assumes all inputs/outputs are on the
 | 
						|
            # same device, but it's subject to change (very likely with
 | 
						|
            # dynamic shape support), so we add execution providers
 | 
						|
            # based on the logic in _select_eps: (explicitly preferred EPs,
 | 
						|
            # EPs inferred from inputs or graph, and the fallback default EP)/
 | 
						|
            #
 | 
						|
            # TODO(wschin): enable external allocators.
 | 
						|
            # See https://github.com/pytorch/pytorch/issues/106867
            onnx_session = onnxruntime.InferenceSession(
                path_or_bytes=onnx_model_bytes,
                sess_options=self._options.ort_session_options,
                providers=self._select_eps(graph_module, *args),
            )

            # Cache ORT session. It's reused for the same "graph_module".
            # Extract the generated ONNX model's input and output names.
            input_names = tuple(input.name for input in onnx_model.graph.input)
            output_names = tuple(output.name for output in onnx_model.graph.output)
            input_devices = _get_onnx_devices(args)
            # Cache devices for inputs and outputs. They are used to invoke the
            # ORT session. Output devices indicate where (e.g., GPU or CPU)
            # to store outputs.
            if isinstance(prim_outputs, tuple):
                output_devices = _get_onnx_devices(prim_outputs)
            else:
                output_devices = _get_onnx_devices((prim_outputs,))

            input_value_infos = tuple(input for input in onnx_model.graph.input)
            output_value_infos = tuple(output for output in onnx_model.graph.output)

            execution_info_per_session = OrtExecutionInfoPerSession(
                session=onnx_session,
                input_names=input_names,
                input_value_infos=input_value_infos,
                output_names=output_names,
                output_value_infos=output_value_infos,
                input_devices=input_devices,
                output_devices=output_devices,
                example_outputs=prim_outputs,
            )

            self._all_ort_execution_info.cache_session_execution_info(
                graph_module, execution_info_per_session
            )

        self.execution_count += 1

        # ORT always returns a tuple of outputs. If the original output is a tensor,
        # ORT output's first element must be extracted and returned. Otherwise, type
        # mismatch may happen in downstream computation.
        is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
        normalized_prim_outputs = (
            (prim_outputs,) if is_single_tensor_output else prim_outputs
        )
        assert isinstance(normalized_prim_outputs, tuple)
        assert all(
            isinstance(elem, (torch.Tensor, torch.SymInt, int))
            for elem in normalized_prim_outputs
        )

        _nvtx_range_push("run_onnx_session_with_ortvaluevector")
        onnx_outputs = self.run(
            onnx_session,
            input_names,
            args,
            input_devices,
            output_names,
            normalized_prim_outputs,
            output_devices,
            self._options.preallocate_output,
            input_value_infos,
            normalized_prim_outputs,
        )
        _nvtx_range_pop()

        if self._assert_allclose_to_baseline:
            # Compute baseline.
            baseline_outputs = torch._prims.executor.execute(
                graph_module, *args, executor="aten"
            )
            normalized_baseline_outputs = (
                (baseline_outputs,) if is_single_tensor_output else baseline_outputs
            )
            # Ensure every output tensor is close to the corresponding baseline.
            for onnx_output, baseline_output in zip(
                onnx_outputs, normalized_baseline_outputs
            ):
                torch.testing.assert_close(onnx_output, baseline_output)
        return onnx_outputs[0] if is_single_tensor_output else onnx_outputs
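
    # Illustration of the output handling above: when the partition's FX graph
    # returns a single tensor, ``prim_outputs`` is a ``torch.Tensor`` while ORT
    # always hands back a tuple such as ``(tensor,)``, so the first element is
    # unwrapped to keep the return type consistent with eager execution;
    # multi-output graphs are returned as tuples unchanged.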

    def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
        # Deferred import since CapabilityBasedPartitioner is not decorated with
        # @compatibility; importing it at the module level would make the test
        # `pytest test/test_fx.py -k test_public_api_surface` fail,
        # because this module is imported into torch.onnx.
        from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

        # FX graph based partitioning based on ONNX supported ops.
        # Given a graph module
        #  GraphModule0
        #   node_0
        #   node_1
        #   node_2
        #   node_3
        #   node_4
        # If only node_2 is not supported by ONNX, this graph module will be partitioned into
        #  GraphModule0
        #   GraphModule1
        #    node_0
        #    node_1
        #   node_2
        #   GraphModule2
        #    node_3
        #    node_4
        # by calling CapabilityBasedPartitioner.partition_and_fuse.
        # Then, GraphModule1's and GraphModule2's forward methods (GraphModule._wrapped_call)
        # will be replaced by OrtBackend._ort_acclerated_call to delegate computation to ORT.
        if graph_module in self._partitioner_cache:
            partitioned_prim_graph_module = self._partitioner_cache[graph_module]
        else:
            prim_graph_module = graph_module
            partitioner = CapabilityBasedPartitioner(
                prim_graph_module,
                self._supported_ops,
                allows_single_node_partition=True,
            )
            partitioned_prim_graph_module = partitioner.partition_and_fuse()
            self._partitioner_cache[graph_module] = partitioned_prim_graph_module

            # Override each fused module's __call__() with _ort_acclerated_call().
            # This loop goes through all graph partitions (each of them is an
            # ONNX-representable graph) and overrides their _wrapped_call function
            # with _ort_acclerated_call. Inside _ort_acclerated_call, the partition's
            # graph is exported into ONNX and executed by ORT.
            for node in partitioned_prim_graph_module.graph.nodes:
                # TODO(wschin): use a better way to identify fused submodules.
                # See https://github.com/pytorch/pytorch/issues/106872.
                if node.op == "call_module" and "fused_" in node.name:
                    fused_module = getattr(partitioned_prim_graph_module, node.name)
                    # self._ort_acclerated_call is responsible for exporting the graph
                    # to ONNX, creating the ORT session, and running the ORT session.
                    fused_module._wrapped_call = self._ort_acclerated_call

        return partitioned_prim_graph_module
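
    # A hedged, illustrative sketch (not executed) of inspecting what ``compile``
    # produces; ``backend``, ``gm``, and ``example_inputs`` are placeholder names:
    #
    #   partitioned = backend.compile(gm, example_inputs)
    #   for node in partitioned.graph.nodes:
    #       if node.op == "call_module" and "fused_" in node.name:
    #           # Each "fused_*" submodule's _wrapped_call now points at
    #           # OrtBackend._ort_acclerated_call.
    #           print(node.name)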

    def __call__(
        self, graph_module: torch.fx.GraphModule, args
    ) -> torch.fx.GraphModule:
        """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the ``aot_autograd`` compiler
        will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise,
        the ``compile`` method is invoked directly."""
        if self._options.use_aot_autograd:
            from functorch.compile import min_cut_rematerialization_partition

            from torch._dynamo.backends.common import aot_autograd

            return aot_autograd(
                fw_compiler=self.compile,
                partition_fn=min_cut_rematerialization_partition,
                decompositions=self._resolved_onnx_exporter_options.decomposition_table,
            )(graph_module, args)

        return self.compile(graph_module, args)

    __instance_cache_max_count: Final = 8
    __instance_cache: Final[List["OrtBackend"]] = []

    @staticmethod
    def get_cached_instance_for_options(
        options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
    ) -> "OrtBackend":
        """Returns a possibly cached instance of an ``OrtBackend``. If an existing
        backend was created previously through this function with the same options,
        it will be returned. Otherwise a new backend will be created, cached, and
        returned.

        Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend``
        will always be returned, since ``onnxruntime.SessionOptions`` cannot
        participate in caching."""

        def reusable(a: OrtBackendOptions, b: OrtBackendOptions):
            if (
                a.preferred_execution_providers != b.preferred_execution_providers
                or a.infer_execution_providers != b.infer_execution_providers
                or a.default_execution_providers != b.default_execution_providers
                or a.preallocate_output != b.preallocate_output
                or a.use_aot_autograd != b.use_aot_autograd
                or a.pre_ort_model_transforms != b.pre_ort_model_transforms
            ):
                return False

            # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled,
            # and holds too much potential state to reasonably check manually; if
            # ort_session_options is provided at all, the backend does not participate
            # in caching.
            if a.ort_session_options is not None or b.ort_session_options is not None:
                return False

            if a.export_options is b.export_options:
                return True

            # Similarly, some objects in ExportOptions are too stateful to use for
            # caching. We should revisit this.
            if a.export_options is not None and b.export_options is not None:
                return (
                    a.export_options.dynamic_shapes == b.export_options.dynamic_shapes
                    and a.export_options.op_level_debug
                    == b.export_options.op_level_debug
                    and a.export_options.diagnostic_options
                    == b.export_options.diagnostic_options
                    and a.export_options.onnx_registry is b.export_options.onnx_registry
                    and a.export_options.fake_context is b.export_options.fake_context
                )

            # We can't account for how the two option sets may differ, so it's not safe to reuse.
            return False

        if not isinstance(options, OrtBackendOptions):
            options = OrtBackendOptions(**(options or {}))

        backend = next(
            (b for b in OrtBackend.__instance_cache if reusable(b._options, options)),
            None,
        )

        if backend is None:
            assert (
                len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count
            ), (
                f"No more than {OrtBackend.__instance_cache_max_count} instances of "
                f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly "
                "to pass to `torch.compile`. "
                "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 "
                "for discussion."
            )
            OrtBackend.__instance_cache.append(backend := OrtBackend(options))

        return backend
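
    # Hedged usage sketch of the caching behavior documented above; the option
    # values are illustrative, and every field shown appears in OrtBackendOptions:
    #
    #   b1 = OrtBackend.get_cached_instance_for_options({"use_aot_autograd": True})
    #   b2 = OrtBackend.get_cached_instance_for_options({"use_aot_autograd": True})
    #   assert b1 is b2  # identical options reuse the cached backend
    #
    #   opts = OrtBackendOptions(ort_session_options=onnxruntime.SessionOptions())
    #   b3 = OrtBackend.get_cached_instance_for_options(opts)
    #   assert b3 is not b1  # SessionOptions always opts a backend out of caching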

    @staticmethod
    def clear_cached_instances():
        OrtBackend.__instance_cache.clear()

    @staticmethod
    def get_cached_instances():
        return tuple(OrtBackend.__instance_cache)


@compatibility(is_backward_compatible=False)
def torch_compile_backend(
    graph_module: torch.fx.GraphModule,
    args,
    *,
    options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
):
    return OrtBackend.get_cached_instance_for_options(options)(graph_module, args)
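
# Hedged usage sketch: with onnxruntime installed, ``torch_compile_backend`` can be
# handed to ``torch.compile`` directly, or bound with ``functools.partial`` to pass
# ``options``. PyTorch also registers a string-named alias for this backend elsewhere
# in torch._dynamo (commonly "onnxrt"); that name is an assumption here, not defined
# in this file.
#
#   import functools
#   import torch
#
#   def f(x):
#       return torch.sin(x) + x
#
#   compiled = torch.compile(
#       f,
#       backend=functools.partial(
#           torch_compile_backend,
#           options=OrtBackendOptions(use_aot_autograd=True),
#       ),
#   )
#   compiled(torch.randn(8))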