[AOTInductor] Add class declarations to torch._C._aoti interface file (#155128)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155128
Approved by: https://github.com/desertfire
ghstack dependencies: #155149

Committed by: PyTorch MergeBot
Parent: 82fb904140
Commit: 4311aea5e7
test/inductor/test_aot_inductor_utils.py

@@ -5,7 +5,7 @@ import os
 import shutil
 import tempfile
 import types
-from typing import Any, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch._export
@@ -21,6 +21,10 @@ from torch.testing._internal.inductor_utils import clone_preserve_strides_offset
 from torch.utils import _pytree as pytree
 
 
+if TYPE_CHECKING:
+    from torch._C._aoti import AOTIModelContainerRunner
+
+
 class WrapperModule(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -73,7 +77,7 @@ class AOTIRunnerUtil:
         return so_path
 
     @staticmethod
-    def legacy_load_runner(device, so_path):
+    def legacy_load_runner(device, so_path: str) -> "AOTIModelContainerRunner":
         if IS_FBCODE:
             from .fb import test_aot_inductor_model_runner_pybind  # @manual
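The TYPE_CHECKING guard above imports the runner type for annotations only, so the compiled torch._C._aoti extension is never loaded just to satisfy a signature; the annotation is written as a string, so it need not resolve at runtime. A minimal sketch of the same pattern (the helper function is illustrative, not part of this commit):

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never executed at runtime.
    from torch._C._aoti import AOTIModelContainerRunner

def load_cpu_runner(so_path: str) -> "AOTIModelContainerRunner":
    # The concrete CPU runner structurally satisfies the Protocol
    # declared in torch/_C/_aoti.pyi below.
    return torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)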
torch/_C/_aoti.pyi

@@ -1,4 +1,5 @@
 from ctypes import c_void_p
+from typing import overload, Protocol
 
 from torch import Tensor
 
@@ -16,10 +17,148 @@ def alloc_tensor_by_stealing_from_void_ptr(
     handle: c_void_p,
 ) -> Tensor: ...
 
-class AOTIModelContainerRunnerCpu: ...
-class AOTIModelContainerRunnerCuda: ...
-class AOTIModelContainerRunnerXpu: ...
-class AOTIModelContainerRunnerMps: ...
+class AOTIModelContainerRunner(Protocol):
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
+
+class AOTIModelContainerRunnerCpu:
+    def __init__(self, model_so_path: str, num_models: int) -> None: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
+
+class AOTIModelContainerRunnerCuda:
+    @overload
+    def __init__(self, model_so_path: str, num_models: int) -> None: ...
+    @overload
+    def __init__(
+        self, model_so_path: str, num_models: int, device_str: str
+    ) -> None: ...
+    @overload
+    def __init__(
+        self, model_so_path: str, num_models: int, device_str: str, cubin_dir: str
+    ) -> None: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
+
+class AOTIModelContainerRunnerXpu:
+    @overload
+    def __init__(self, model_so_path: str, num_models: int) -> None: ...
+    @overload
+    def __init__(
+        self, model_so_path: str, num_models: int, device_str: str
+    ) -> None: ...
+    @overload
+    def __init__(
+        self, model_so_path: str, num_models: int, device_str: str, kernel_bin_dir: str
+    ) -> None: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
+
+class AOTIModelContainerRunnerMps:
+    def __init__(self, model_so_path: str, num_models: int) -> None: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_names_to_original_fqns(self) -> dict[str, str]: ...
+    def get_constant_names_to_dtypes(self) -> dict[str, int]: ...
+    def extract_constants_map(self, use_inactive: bool) -> dict[str, Tensor]: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def swap_constant_buffer(self) -> None: ...
+    def free_inactive_constant_buffer(self) -> None: ...
 
 # Defined in torch/csrc/inductor/aoti_package/pybind.cpp
-class AOTIModelPackageLoader: ...
+class AOTIModelPackageLoader:
+    def __init__(
+        self,
+        model_package_path: str,
+        model_name: str,
+        run_single_threaded: bool,
+        num_runners: int,
+        device_index: int,
+    ) -> None: ...
+    def get_metadata(self) -> dict[str, str]: ...
+    def run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def boxed_run(
+        self, inputs: list[Tensor], stream_handle: c_void_p = ...
+    ) -> list[Tensor]: ...
+    def get_call_spec(self) -> list[str]: ...
+    def get_constant_fqns(self) -> list[str]: ...
+    def load_constants(
+        self,
+        constants_map: dict[str, Tensor],
+        use_inactive: bool,
+        check_full_update: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
+    def update_constant_buffer(
+        self,
+        tensor_map: dict[str, Tensor],
+        use_inactive: bool,
+        validate_full_updates: bool,
+        user_managed: bool = ...,
+    ) -> None: ...
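Because AOTIModelContainerRunner is a typing.Protocol, the per-device runners match it structurally; none of them has to inherit from it. One consequence is that generic helpers can now be typed once for every backend. A hedged sketch of such a helper, following the double-buffered constant API declared above (the helper itself is illustrative, not part of the commit):

from typing import TYPE_CHECKING

from torch import Tensor

if TYPE_CHECKING:
    from torch._C._aoti import AOTIModelContainerRunner

def hot_swap_constants(
    runner: "AOTIModelContainerRunner", new_weights: dict[str, Tensor]
) -> None:
    # Stage the new constants into the inactive buffer (use_inactive=True),
    # validating that every constant receives a value, then swap buffers
    # and release the now-inactive copy. Any CPU/CUDA/XPU/MPS runner fits.
    runner.update_constant_buffer(new_weights, True, True)
    runner.swap_constant_buffer()
    runner.free_inactive_constant_buffer()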
torch/_export/__init__.py

@@ -16,7 +16,7 @@ from collections import OrderedDict
 from contextlib import contextmanager
 from functools import lru_cache
 
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 from unittest.mock import patch
 
 import torch
@@ -48,6 +48,9 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 from .wrappers import _wrap_submodules
 from .utils import _materialize_cpp_cia_ops
 
+if TYPE_CHECKING:
+    from torch._C._aoti import AOTIModelContainerRunner
+
 log = logging.getLogger(__name__)
 
 @dataclasses.dataclass
@@ -160,23 +163,23 @@ def aot_load(so_path: str, device: str) -> Callable:
     aot_compile_warning()
 
     if device == "cpu":
-        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
+        runner: AOTIModelContainerRunner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
     elif device == "cuda" or device.startswith("cuda:"):
-        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
+        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)
     elif device == "xpu" or device.startswith("xpu:"):
-        runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device)  # type: ignore[assignment, call-arg]
+        runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device)
     elif device == "mps" or device.startswith("mps:"):
-        runner = torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1)  # type: ignore[assignment, call-arg]
+        runner = torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1)
     else:
         raise RuntimeError("Unsupported device " + device)
 
     def optimized(*args, **kwargs):
-        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
+        call_spec = runner.get_call_spec()
         in_spec = pytree.treespec_loads(call_spec[0])
         out_spec = pytree.treespec_loads(call_spec[1])
         flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
         flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
-        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
+        flat_outputs = runner.run(flat_inputs)
         return pytree.tree_unflatten(flat_outputs, out_spec)
 
     return optimized
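For reference, a hedged end-to-end sketch of the aot_load path typed above, using the legacy torch._export.aot_compile entry point to produce the shared library (the example module and shapes are illustrative):

import torch

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2 + 1

# Compile ahead of time to a .so, then load it through the typed runner.
so_path = torch._export.aot_compile(M(), (torch.randn(8),))
optimized = torch._export.aot_load(so_path, "cpu")
out = optimized(torch.randn(8))  # inputs/outputs flow through the stored call spec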
torch/_inductor/package/package.py

@@ -120,11 +120,11 @@ def load_package(
             log.debug("Writing buffer to tmp file located at %s.", f.name)
             loader = torch._C._aoti.AOTIModelPackageLoader(
                 f.name, model_name, run_single_threaded, num_runners, device_index
-            )  # type: ignore[call-arg]
+            )
             return AOTICompiledModel(loader)
 
     path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
     loader = torch._C._aoti.AOTIModelPackageLoader(
         path, model_name, run_single_threaded, num_runners, device_index
-    )  # type: ignore[call-arg]
+    )
     return AOTICompiledModel(loader)
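The branch above exists because AOTIModelPackageLoader only accepts filesystem paths, so in-memory buffers are first spilled to a named temporary file. A generic, self-contained sketch of that workaround, with the path-based loader passed in as a callable (this helper is illustrative, not part of the commit):

import io
import tempfile
from typing import Any, Callable

def load_via_tempfile(buf: io.IOBase, load_path: Callable[[str], Any]) -> Any:
    # Spill the buffer to a real .pt2 file so a path-only loader can read it,
    # restoring the caller's stream position afterwards.
    with tempfile.NamedTemporaryFile(suffix=".pt2") as tf:
        buf.seek(0)
        tf.write(buf.read())
        tf.flush()
        buf.seek(0)
        return load_path(tf.name)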
torch/export/pt2_archive/_package.py

@@ -8,7 +8,6 @@ from dataclasses import dataclass
 from typing import Any, IO, Optional, Union
 
 import torch
-import torch._inductor
 import torch.utils._pytree as pytree
 from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact
 from torch.export._tree_utils import reorder_kwargs
@@ -366,16 +365,16 @@ class AOTICompiledModel:
         self.loader = loader
 
     def __call__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
-        call_spec = self.loader.get_call_spec()  # type: ignore[attr-defined]
+        call_spec = self.loader.get_call_spec()
         in_spec = pytree.treespec_loads(call_spec[0])
         out_spec = pytree.treespec_loads(call_spec[1])
         flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
         flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
-        flat_outputs = self.loader.boxed_run(flat_inputs)  # type: ignore[attr-defined]
+        flat_outputs = self.loader.boxed_run(flat_inputs)
         return pytree.tree_unflatten(flat_outputs, out_spec)
 
     def get_metadata(self) -> dict[str, str]:
-        return self.loader.get_metadata()  # type: ignore[attr-defined]
+        return self.loader.get_metadata()
 
     def load_constants(
         self,
@@ -394,18 +393,18 @@ class AOTICompiledModel:
             check_full_update: Whether to add check to see if all the constants
                 are updated and have values.
         """
-        self.loader.load_constants(  # type: ignore[attr-defined]
+        self.loader.load_constants(
             constants_map, False, check_full_update, user_managed
         )
 
     def get_constant_fqns(self) -> list[str]:
-        return self.loader.get_constant_fqns()  # type: ignore[attr-defined]
+        return self.loader.get_constant_fqns()
 
     def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel":
         logger.warning(
             "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied."
         )
-        return AOTICompiledModel(self.loader)  # type: ignore[attr-defined]
+        return AOTICompiledModel(self.loader)
 
 
 @dataclass
@@ -531,7 +530,7 @@ def load_pt2(
         extra_files = _load_extra_files(archive_reader, file_names)
 
         # Get a list of AOTI model names
-        aoti_model_names = set()
+        aoti_model_names: set[str] = set()
         for file in file_names:
             if file.startswith(AOTINDUCTOR_DIR):
                 file = file[len(AOTINDUCTOR_DIR) :]  # remove data/aotinductor/ prefix
@@ -540,33 +539,35 @@ def load_pt2(
             ]  # split "model_name/...cpp" into "model_name"
             aoti_model_names.add(model_name)
 
-    if isinstance(f, (io.IOBase, IO)) and len(aoti_model_names) > 0:
-        # Workaround for AOTIModelPackageLoader not reading buffers
-        with tempfile.NamedTemporaryFile(suffix=".pt2") as tf:
-            f.seek(0)
-            tf.write(f.read())
-            f.seek(0)
-            logger.debug("Writing buffer to tmp file located at %s.", tf.name)
-
-            aoti_runners = {
-                model_name: AOTICompiledModel(
-                    torch._C._aoti.AOTIModelPackageLoader(
-                        tf.name,
-                        model_name,
-                        run_single_threaded,
-                        num_runners,
-                        device_index,
-                    )  # type: ignore[call-arg]
-                )
-                for model_name in aoti_model_names
-            }
+    if isinstance(f, (io.IOBase, IO)):
+        if len(aoti_model_names) > 0:
+            # Workaround for AOTIModelPackageLoader not reading buffers
+            with tempfile.NamedTemporaryFile(suffix=".pt2") as tf:
+                f.seek(0)
+                tf.write(f.read())
+                f.seek(0)
+                logger.debug("Writing buffer to tmp file located at %s.", tf.name)
+
+                aoti_runners = {
+                    model_name: AOTICompiledModel(
+                        torch._C._aoti.AOTIModelPackageLoader(
+                            tf.name,
+                            model_name,
+                            run_single_threaded,
+                            num_runners,
+                            device_index,
+                        )
+                    )
+                    for model_name in aoti_model_names
+                }
+        else:
+            aoti_runners = {}
     else:
         aoti_runners = {
             model_name: AOTICompiledModel(
                 torch._C._aoti.AOTIModelPackageLoader(
                     f, model_name, run_single_threaded, num_runners, device_index
-                )  # type: ignore[call-arg]
+                )
            )
             for model_name in aoti_model_names
         }
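The model-name scan in load_pt2 strips the archive prefix and keeps the first path component of each remaining entry. A small self-contained sketch of that logic (the prefix value and file names are illustrative, mirroring the comment in the code above):

AOTINDUCTOR_DIR = "data/aotinductor/"  # assumed prefix, per the comment above

def collect_aoti_model_names(file_names: list[str]) -> set[str]:
    names: set[str] = set()
    for file in file_names:
        if file.startswith(AOTINDUCTOR_DIR):
            rest = file[len(AOTINDUCTOR_DIR):]  # e.g. "resnet/model.cpp"
            names.add(rest.split("/")[0])       # -> "resnet"
    return names

assert collect_aoti_model_names(["data/aotinductor/resnet/model.cpp"]) == {"resnet"}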