[AOTInductor] Add class declarations to torch._C._aoti interface file (#155128)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155128
Approved by: https://github.com/desertfire
ghstack dependencies: #155149
Author: Benjamin Glass
Date: 2025-06-13 16:30:48 +00:00
Committed by: PyTorch MergeBot
parent 82fb904140
commit 4311aea5e7
5 changed files with 192 additions and 45 deletions
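
Illustrative sketch (not part of the diff): the new stubs declare AOTIModelContainerRunner as a typing.Protocol, so the backend-specific runners all type-check against a single annotation without sharing a base class. The helper name make_runner below is hypothetical; the constructor calls mirror aot_load in this commit.

    from typing import TYPE_CHECKING

    import torch

    if TYPE_CHECKING:
        from torch._C._aoti import AOTIModelContainerRunner

    # make_runner is a hypothetical helper, not part of this PR.
    def make_runner(so_path: str, device: str) -> "AOTIModelContainerRunner":
        # Each concrete runner structurally satisfies the Protocol.
        if device == "cpu":
            return torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
        return torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)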

View File

@@ -5,7 +5,7 @@ import os
 import shutil
 import tempfile
 import types
-from typing import Any, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union

 import torch
 import torch._export
@@ -21,6 +21,10 @@ from torch.testing._internal.inductor_utils import clone_preserve_strides_offset
 from torch.utils import _pytree as pytree

+if TYPE_CHECKING:
+    from torch._C._aoti import AOTIModelContainerRunner
+
+
 class WrapperModule(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -73,7 +77,7 @@ class AOTIRunnerUtil:
         return so_path

     @staticmethod
-    def legacy_load_runner(device, so_path):
+    def legacy_load_runner(device, so_path: str) -> "AOTIModelContainerRunner":
         if IS_FBCODE:
             from .fb import test_aot_inductor_model_runner_pybind  # @manual
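
Aside (not part of the diff): the if TYPE_CHECKING: guard is needed because AOTIModelContainerRunner exists only in the .pyi stub, not at runtime, and the quoted return annotation is evaluated only by the type checker. A minimal sketch of the pattern, with a placeholder function name:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from torch._C._aoti import AOTIModelContainerRunner  # type-only import

    def load(so_path: str) -> "AOTIModelContainerRunner":  # quoted: no runtime lookup
        ...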

View File

@@ -1,4 +1,5 @@
 from ctypes import c_void_p
+from typing import overload, Protocol

 from torch import Tensor
@@ -16,10 +17,148 @@ def alloc_tensor_by_stealing_from_void_ptr(
     handle: c_void_p,
 ) -> Tensor: ...

-class AOTIModelContainerRunnerCpu: ...
-class AOTIModelContainerRunnerCuda: ...
-class AOTIModelContainerRunnerXpu: ...
-class AOTIModelContainerRunnerMps: ...
+class AOTIModelContainerRunner(Protocol):
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
+
+class AOTIModelContainerRunnerCpu:
+    def __init__(self, model_so_path: str, num_models: int) -> None: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
+
+class AOTIModelContainerRunnerCuda:
+    @overload
+    def __init__(self, model_so_path: str, num_models: int) -> None: ...
+    @overload
+    def __init__(
+        self, model_so_path: str, num_models: int, device_str: str
+    ) -> None: ...
+    @overload
+    def __init__(
+        self, model_so_path: str, num_models: int, device_str: str, cubin_dir: str
+    ) -> None: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
+
+class AOTIModelContainerRunnerXpu:
+    @overload
+    def __init__(self, model_so_path: str, num_models: int) -> None: ...
+    @overload
+    def __init__(
+        self, model_so_path: str, num_models: int, device_str: str
+    ) -> None: ...
+    @overload
+    def __init__(
+        self, model_so_path: str, num_models: int, device_str: str, kernel_bin_dir: str
+    ) -> None: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
+
+class AOTIModelContainerRunnerMps:
+    def __init__(self, model_so_path: str, num_models: int) -> None: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...

 # Defined in torch/csrc/inductor/aoti_package/pybind.cpp
-class AOTIModelPackageLoader: ...
+class AOTIModelPackageLoader:
+    def __init__(
+        self,
+        model_package_path: str,
+        model_name: str,
+        run_single_threaded: bool,
+        num_runners: int,
+        device_index: int,
+    ) -> None: ...
+    def get_metadata(self) -> dict[str, str]: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def boxed_run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_fqns(self) -> list[str]: ...
+    def load_constants(
+        self,
+        constants_map: dict[str, Tensor],
+        use_inactive: bool,
+        check_full_update: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
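
Aside (not part of the diff): Protocol typing is structural, so the concrete runner classes above satisfy AOTIModelContainerRunner without inheriting from it; the Protocol lives only in this stub, hence the type-only import. A minimal sketch with a hypothetical function name:

    from typing import TYPE_CHECKING

    from torch import Tensor

    if TYPE_CHECKING:
        from torch._C._aoti import AOTIModelContainerRunner

    def run_flat(runner: "AOTIModelContainerRunner", inputs: list[Tensor]) -> list[Tensor]:
        # Structural check: any object with matching method signatures passes,
        # e.g. AOTIModelContainerRunnerCpu, with no shared base class.
        return runner.run(inputs)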

View File

@@ -16,7 +16,7 @@ from collections import OrderedDict
 from contextlib import contextmanager
 from functools import lru_cache
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 from unittest.mock import patch

 import torch
@@ -48,6 +48,9 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from .wrappers import _wrap_submodules
 from .utils import _materialize_cpp_cia_ops

+if TYPE_CHECKING:
+    from torch._C._aoti import AOTIModelContainerRunner
+
 log = logging.getLogger(__name__)

 @dataclasses.dataclass
@@ -160,23 +163,23 @@ def aot_load(so_path: str, device: str) -> Callable:
     aot_compile_warning()

     if device == "cpu":
-        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
+        runner: AOTIModelContainerRunner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
     elif device == "cuda" or device.startswith("cuda:"):
-        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
+        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)
     elif device == "xpu" or device.startswith("xpu:"):
-        runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device)  # type: ignore[assignment, call-arg]
+        runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device)
     elif device == "mps" or device.startswith("mps:"):
-        runner = torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1)  # type: ignore[assignment, call-arg]
+        runner = torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1)
     else:
         raise RuntimeError("Unsupported device " + device)

     def optimized(*args, **kwargs):
-        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
+        call_spec = runner.get_call_spec()
         in_spec = pytree.treespec_loads(call_spec[0])
         out_spec = pytree.treespec_loads(call_spec[1])
         flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
         flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
-        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
+        flat_outputs = runner.run(flat_inputs)
         return pytree.tree_unflatten(flat_outputs, out_spec)

     return optimized
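
For context, a usage sketch of aot_load (the .so path and tensor shape are placeholders; assumes a library produced by torch._export.aot_compile):

    import torch
    import torch._export

    compiled = torch._export.aot_load("/path/to/model.so", "cuda")  # placeholder path
    # optimized() flattens args/kwargs per the stored call spec, runs the
    # container, and unflattens the outputs.
    y = compiled(torch.randn(4, 8, device="cuda"))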

View File

@@ -120,11 +120,11 @@ def load_package(
             log.debug("Writing buffer to tmp file located at %s.", f.name)
             loader = torch._C._aoti.AOTIModelPackageLoader(
                 f.name, model_name, run_single_threaded, num_runners, device_index
-            )  # type: ignore[call-arg]
+            )
             return AOTICompiledModel(loader)

     path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
     loader = torch._C._aoti.AOTIModelPackageLoader(
         path, model_name, run_single_threaded, num_runners, device_index
-    )  # type: ignore[call-arg]
+    )
     return AOTICompiledModel(loader)
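
Aside (not part of the diff): with the full __init__ stub in place, both call sites type-check and the # type: ignore[call-arg] suppressions can go. A sketch of the constructor call with illustrative argument values (path and values are placeholders, not defaults taken from the source):

    loader = torch._C._aoti.AOTIModelPackageLoader(
        "/path/to/model.pt2",  # model_package_path (placeholder)
        "model",               # model_name
        False,                 # run_single_threaded
        1,                     # num_runners
        -1,                    # device_index (illustrative value)
    )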

View File

@@ -8,7 +8,6 @@ from dataclasses import dataclass
 from typing import Any, IO, Optional, Union

 import torch
-import torch._inductor
 import torch.utils._pytree as pytree
 from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact
 from torch.export._tree_utils import reorder_kwargs
@@ -366,16 +365,16 @@ class AOTICompiledModel:
         self.loader = loader

     def __call__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
-        call_spec = self.loader.get_call_spec()  # type: ignore[attr-defined]
+        call_spec = self.loader.get_call_spec()
         in_spec = pytree.treespec_loads(call_spec[0])
         out_spec = pytree.treespec_loads(call_spec[1])
         flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
         flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
-        flat_outputs = self.loader.boxed_run(flat_inputs)  # type: ignore[attr-defined]
+        flat_outputs = self.loader.boxed_run(flat_inputs)
         return pytree.tree_unflatten(flat_outputs, out_spec)

     def get_metadata(self) -> dict[str, str]:
-        return self.loader.get_metadata()  # type: ignore[attr-defined]
+        return self.loader.get_metadata()

     def load_constants(
         self,
@@ -394,18 +393,18 @@ class AOTICompiledModel:
             check_full_update: Whether to add check to see if all the constants
                 are updated and have values.
         """
-        self.loader.load_constants(  # type: ignore[attr-defined]
+        self.loader.load_constants(
             constants_map, False, check_full_update, user_managed
         )

     def get_constant_fqns(self) -> list[str]:
-        return self.loader.get_constant_fqns()  # type: ignore[attr-defined]
+        return self.loader.get_constant_fqns()

     def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel":
         logger.warning(
             "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied."
         )
-        return AOTICompiledModel(self.loader)  # type: ignore[attr-defined]
+        return AOTICompiledModel(self.loader)

 @dataclass
@@ -531,7 +530,7 @@ def load_pt2(
         extra_files = _load_extra_files(archive_reader, file_names)

         # Get a list of AOTI model names
-        aoti_model_names = set()
+        aoti_model_names: set[str] = set()
         for file in file_names:
             if file.startswith(AOTINDUCTOR_DIR):
                 file = file[len(AOTINDUCTOR_DIR) :]  # remove data/aotinductor/ prefix
@@ -540,33 +539,35 @@ def load_pt2(
                 ]  # split "model_name/...cpp" into "model_name"
                 aoti_model_names.add(model_name)

-    if isinstance(f, (io.IOBase, IO)) and len(aoti_model_names) > 0:
-        # Workaround for AOTIModelPackageLoader not reading buffers
-        with tempfile.NamedTemporaryFile(suffix=".pt2") as tf:
-            f.seek(0)
-            tf.write(f.read())
-            f.seek(0)
-            logger.debug("Writing buffer to tmp file located at %s.", tf.name)
-            aoti_runners = {
-                model_name: AOTICompiledModel(
-                    torch._C._aoti.AOTIModelPackageLoader(
-                        tf.name,
-                        model_name,
-                        run_single_threaded,
-                        num_runners,
-                        device_index,
-                    )  # type: ignore[call-arg]
-                )
-                for model_name in aoti_model_names
-            }
+    if isinstance(f, (io.IOBase, IO)):
+        if len(aoti_model_names) > 0:
+            # Workaround for AOTIModelPackageLoader not reading buffers
+            with tempfile.NamedTemporaryFile(suffix=".pt2") as tf:
+                f.seek(0)
+                tf.write(f.read())
+                f.seek(0)
+                logger.debug("Writing buffer to tmp file located at %s.", tf.name)
+                aoti_runners = {
+                    model_name: AOTICompiledModel(
+                        torch._C._aoti.AOTIModelPackageLoader(
+                            tf.name,
+                            model_name,
+                            run_single_threaded,
+                            num_runners,
+                            device_index,
+                        )
+                    )
+                    for model_name in aoti_model_names
+                }
+        else:
+            aoti_runners = {}
     else:
         aoti_runners = {
             model_name: AOTICompiledModel(
                 torch._C._aoti.AOTIModelPackageLoader(
                     f, model_name, run_single_threaded, num_runners, device_index
-                )  # type: ignore[call-arg]
+                )
             )
             for model_name in aoti_model_names
         }
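
Aside (not part of the diff): the restructured branch makes the buffer case explicit, so a file-like f that contains no AOTI models now yields an empty aoti_runners dict instead of falling into the path branch. A usage sketch of the buffer path (the archive path is a placeholder; assumes load_pt2 is in scope as defined in this file):

    import io

    with open("/path/to/archive.pt2", "rb") as fh:  # placeholder archive
        buf = io.BytesIO(fh.read())
    # Buffers are spooled to a temporary .pt2 file before AOTIModelPackageLoader
    # reads them, per the workaround above.
    contents = load_pt2(buf)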