mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155128 Approved by: https://github.com/desertfire ghstack dependencies: #155149
165 lines
5.9 KiB
Python
165 lines
5.9 KiB
Python
from ctypes import c_void_p
|
|
from typing import overload, Protocol
|
|
|
|
from torch import Tensor
|
|
|
|
# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp
|
|
|
|
# Tensor to AtenTensorHandle
|
|
# Stub: takes a list of Tensors and returns a list of c_void_p handles
# (the "Tensor to AtenTensorHandle" direction). NOTE(review): the "unsafe_alloc"
# prefix presumably means the caller becomes responsible for the allocated
# handles — confirm against torch/csrc/inductor/aoti_runner/pybind.cpp.
def unsafe_alloc_void_ptrs_from_tensors(
    tensors: list[Tensor],
) -> list[c_void_p]: ...
|
|
# Stub: single-Tensor form — takes one Tensor and returns one c_void_p handle.
# NOTE(review): handle ownership presumably passes to the caller; confirm in
# torch/csrc/inductor/aoti_runner/pybind.cpp.
def unsafe_alloc_void_ptr_from_tensor(
    tensor: Tensor,
) -> c_void_p: ...
|
|
|
|
# AtenTensorHandle to Tensor
|
|
# Stub: takes a list of c_void_p handles and returns a list of Tensors (the
# "AtenTensorHandle to Tensor" direction). NOTE(review): "stealing" suggests
# each handle's ownership moves into the returned Tensor, so the handles should
# not be reused afterwards — confirm in aoti_runner/pybind.cpp.
def alloc_tensors_by_stealing_from_void_ptrs(
    handles: list[c_void_p],
) -> list[Tensor]: ...
|
|
# Stub: single-handle form — takes one c_void_p handle and returns one Tensor.
# NOTE(review): as with the list variant, "stealing" presumably transfers
# ownership of the handle; confirm in aoti_runner/pybind.cpp.
def alloc_tensor_by_stealing_from_void_ptr(
    handle: c_void_p,
) -> Tensor: ...
|
|
|
|
class AOTIModelContainerRunner(Protocol):
    # Structural (duck-typed) interface: any object exposing these methods
    # satisfies it. The device-specific AOTIModelContainerRunner* stubs in this
    # module declare the same method set without subclassing this Protocol.

    # Execution: `stream_handle` is optional (stub default `...`).
    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]: ...

    # Metadata accessors.
    def get_call_spec(self) -> list[str]: ...
    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
    # Values are plain ints — presumably dtype enum codes; TODO confirm in
    # aoti_runner/pybind.cpp.
    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...

    # Constant-buffer management. NOTE(review): `use_inactive` presumably picks
    # the inactive half of an active/inactive buffer pair that
    # swap_constant_buffer exchanges — confirm against the C++ runner.
    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None: ...
    def swap_constant_buffer(self) -> None: ...
    def free_inactive_constant_buffer(self) -> None: ...
|
|
|
|
class AOTIModelContainerRunnerCpu:
    # CPU runner stub. Its method set matches the AOTIModelContainerRunner
    # Protocol, plus a constructor.
    # NOTE(review): num_models presumably sizes a pool of model instances
    # loaded from model_so_path — confirm in aoti_runner/pybind.cpp.
    def __init__(self, model_so_path: str, num_models: int) -> None: ...

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]: ...

    # Metadata accessors (dtype values are ints — presumably enum codes).
    def get_call_spec(self) -> list[str]: ...
    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...

    # Constant-buffer management.
    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None: ...
    def swap_constant_buffer(self) -> None: ...
    def free_inactive_constant_buffer(self) -> None: ...
|
|
|
|
class AOTIModelContainerRunnerCuda:
    # CUDA runner stub. Three constructor overloads: the device string and the
    # cubin directory are each optional, in that order.
    @overload
    def __init__(self, model_so_path: str, num_models: int) -> None: ...
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
        device_str: str,
    ) -> None: ...
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
        device_str: str,
        cubin_dir: str,
    ) -> None: ...

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]: ...

    # Metadata accessors (dtype values are ints — presumably enum codes).
    def get_call_spec(self) -> list[str]: ...
    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...

    # Constant-buffer management.
    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None: ...
    def swap_constant_buffer(self) -> None: ...
    def free_inactive_constant_buffer(self) -> None: ...
|
|
|
|
class AOTIModelContainerRunnerXpu:
    # XPU runner stub. Three constructor overloads: the device string and the
    # kernel-binary directory are each optional, in that order.
    @overload
    def __init__(self, model_so_path: str, num_models: int) -> None: ...
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
        device_str: str,
    ) -> None: ...
    @overload
    def __init__(
        self,
        model_so_path: str,
        num_models: int,
        device_str: str,
        kernel_bin_dir: str,
    ) -> None: ...

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]: ...

    # Metadata accessors (dtype values are ints — presumably enum codes).
    def get_call_spec(self) -> list[str]: ...
    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...

    # Constant-buffer management.
    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None: ...
    def swap_constant_buffer(self) -> None: ...
    def free_inactive_constant_buffer(self) -> None: ...
|
|
|
|
class AOTIModelContainerRunnerMps:
    # MPS runner stub. Like the CPU stub, it has a single two-argument
    # constructor and no device-string or binary-directory overloads.
    def __init__(self, model_so_path: str, num_models: int) -> None: ...

    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]: ...

    # Metadata accessors (dtype values are ints — presumably enum codes).
    def get_call_spec(self) -> list[str]: ...
    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...

    # Constant-buffer management.
    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None: ...
    def swap_constant_buffer(self) -> None: ...
    def free_inactive_constant_buffer(self) -> None: ...
|
|
|
|
# Defined in torch/csrc/inductor/aoti_package/pybind.cpp
|
|
class AOTIModelPackageLoader:
    # Stub for the AOTI model-package loader. All constructor arguments are
    # required. NOTE(review): num_runners / run_single_threaded presumably
    # control how many runner instances are created and whether they run
    # concurrently — confirm in aoti_package/pybind.cpp.
    def __init__(
        self,
        model_package_path: str,
        model_name: str,
        run_single_threaded: bool,
        num_runners: int,
        device_index: int,
    ) -> None: ...

    def get_metadata(self) -> dict[str, str]: ...

    # run and boxed_run have identical signatures. NOTE(review): "boxed"
    # presumably means the input list may be consumed by the call; confirm in C++.
    def run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]: ...
    def boxed_run(
        self,
        inputs: list[Tensor],
        stream_handle: c_void_p = ...,
    ) -> list[Tensor]: ...

    def get_call_spec(self) -> list[str]: ...
    def get_constant_fqns(self) -> list[str]: ...

    # The two constant-loading entry points name their validation flag
    # differently (check_full_update vs validate_full_updates). Both names are
    # kept as-is because callers may pass them by keyword.
    def load_constants(
        self,
        constants_map: dict[str, Tensor],
        use_inactive: bool,
        check_full_update: bool,
        user_managed: bool = ...,
    ) -> None: ...
    def update_constant_buffer(
        self,
        tensor_map: dict[str, Tensor],
        use_inactive: bool,
        validate_full_updates: bool,
        user_managed: bool = ...,
    ) -> None: ...
|