Compare commits


1 Commit

Author SHA1 Message Date
26f67ef050 Add an option to store large mmap weights on disk (#164526)
Summary:
As title

On Windows, we cannot modify the .dll to append weights at the end; the Windows DLL loader will complain that it's not a valid .dll file. So we store the weight blob as a separate file.

1. We add the following APIs, which allow querying the size of the weight blob and passing in a pointer to load the weights from it:

```cpp
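// Get the size of the constant blob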
AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize(
    AOTInductorModelContainerHandle container_handle,
    uint64_t* ret_size);

// Load weights from a single blob in weight_blob_ptr
AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob(
    AOTInductorModelContainerHandle container_handle,
    const uint8_t* weight_blob_ptr);
```
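
For reference, a minimal sketch of how a caller might drive these two APIs, assuming the model .so was built with `USE_MMAP_EXTERNAL` and a container handle has already been created; the blob filename follows the `<wrapper stem>_weights.blob` convention used in this PR, and error-code checking is elided:

```cpp
#include <torch/csrc/inductor/aoti_runtime/interface.h> // AOTI C interface

#include <cstdint>
#include <cstdio>
#include <vector>

void load_external_weights(AOTInductorModelContainerHandle container) {
  // Ask the container how many bytes of weights it expects.
  uint64_t blob_size = 0;
  AOTInductorModelContainerGetConstantsBlobSize(container, &blob_size);

  // Read the side-car blob written next to the wrapper .so. (The runner
  // mmaps it instead; a plain read keeps the sketch short.)
  std::vector<uint8_t> blob(blob_size);
  std::FILE* f = std::fopen("model.wrapper_weights.blob", "rb");
  if (f == nullptr) return;
  std::fread(blob.data(), 1, blob_size, f);
  std::fclose(f);

  // Hand the pointer to the container; the model copies the data into its
  // own storage, so the buffer may be freed once this returns.
  AOTInductorModelUpdateConstantsFromBlob(container, blob.data());
}
```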

2. We also add a method in ModelContainerRunner to load the weights (see the sketch after this list):

If the runner sees a `.blob` file in the package, it will mmap the `.blob` file and use its contents to load the constants.

3. We also add the `USE_MMAP_EXTERNAL` macro. When this macro is defined, the model expects its weights to be loaded from an externally mmap'd blob.
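
Taken together, the runner-side flow from item 2 looks roughly like the sketch below for a CPU model. This is illustrative only, not part of this PR's test plan: `AOTIModelPackageLoader` performs the blob step automatically when it finds a `.blob` file in the package, so the explicit call is only needed when driving the runner directly, and the file names and input shape are placeholders.

```cpp
#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>

int main() {
  // The wrapper .so carries no weights when package_constants_in_so=False
  // (include_weights is codegen'd to "false").
  torch::inductor::AOTIModelContainerRunnerCpu runner("model.wrapper.so");

  // mmap the side-car blob and copy its contents into the container.
  runner.update_constant_buffer_from_blob("model.wrapper_weights.blob");

  // Run as usual once the constants are loaded.
  std::vector<at::Tensor> inputs = {torch::randn({1, 512})};
  auto outputs = runner.run(inputs);
  return 0;
}
```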


Test Plan:
```
buck run mode/dev-nosan caffe2/test/inductor:test_aot_inductor -- -r test_large_mmaped_weights_on_disk
```

Also tested Windows cross-compilation with 6542566585/demo/main_voxtral.cpp

```
Loaded model.dll
audio_encoder loaded
C:\Users\shangdiy\source\repos\torchnative\demo\token_embedding\data\aotinductor\model\model.wrapper.so
Loaded model.dll
token_embedding loaded
C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper.so
Loaded model.dll
Loading weights from C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper_weights.blob
text_decoder loaded
Load latency (ms):
  audio_encoder: 1011.234
    archive extraction: 0.000
    .so loading: 1011.197
  token_embedding: 525.773
    archive extraction: 0.000
    .so loading: 525.704
  text_decoder: 3324.130
    archive extraction: 0.000
    .so loading: 3323.979
Run latency (ms):
  audio_encoder: 285.958
    audio_encoder output: dtype=bfloat16, shape=[1, 1125, 3072], numel=3456000
  token_embedding: 6.676
    token_embedding output: dtype=bfloat16, shape=[1, 1138, 3072], numel=3495936
  text_decoder: 576.519
    text_decoder output: dtype=bfloat16, shape=[1, 1138, 131072], numel=149159936
```


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

Differential Revision: D84093310

Pulled By: yushangdi
2025-10-09 15:21:58 -07:00
16 changed files with 370 additions and 23 deletions

View File

@@ -701,6 +701,24 @@ class AOTInductorTestsTemplate:
with config.patch({"aot_inductor.force_mmap_weights": True}):
self.check_model(Model(), example_inputs)
def test_large_mmaped_weights_on_disk(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(512, 250112)
def forward(self, x, y):
return x + self.linear(y)
example_inputs = (
torch.randn(1, 250112, device=self.device),
torch.randn(1, 512, device=self.device),
)
with config.patch(
{"aot_inductor.package_constants_on_disk_format": "binary_blob"}
):
self.check_model(Model(), example_inputs)
def test_with_offset(self):
class Model(torch.nn.Module):
def __init__(self, device):
@@ -5891,7 +5909,7 @@ class AOTInductorTestsTemplate:
{
"always_keep_tensor_constants": True,
"aot_inductor.package_constants_in_so": False,
"aot_inductor.package_constants_on_disk": True,
"aot_inductor.package_constants_on_disk_format": "pickle_weights",
"aot_inductor.package": True,
}
),
@@ -7431,6 +7449,14 @@ class TestAOTInductorConfig(TestCase):
with self.assertRaises(RuntimeError):
maybe_aoti_standalone_config(patches)
def test_compile_standalone_cross_compile_windows_package_format(self):
patches = {
"aot_inductor.cross_target_platform": "windows",
"aot_inductor.package_constants_in_so": True,
}
with self.assertRaises(RuntimeError):
maybe_aoti_standalone_config(patches)
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)

View File

@@ -950,7 +950,7 @@ class TestAOTInductorPackage(TestCase):
"aot_inductor.package_cpp_only": self.package_cpp_only,
"always_keep_tensor_constants": True,
"aot_inductor.package_constants_in_so": False,
"aot_inductor.package_constants_on_disk": True,
"aot_inductor.package_constants_on_disk_format": "pickle_weights",
}
class Bar(torch.nn.Module):
@@ -1034,7 +1034,7 @@ class TestAOTInductorPackage(TestCase):
"aot_inductor.package_cpp_only": self.package_cpp_only,
"always_keep_tensor_constants": True,
"aot_inductor.package_constants_in_so": False,
"aot_inductor.package_constants_on_disk": True,
"aot_inductor.package_constants_on_disk_format": "pickle_weights",
}
# linear.weight's node name is linear_weight.

View File

@@ -48,13 +48,17 @@ class TestAOTInductorWindowsCrossCompilation(TestCase):
"aot_inductor.model_name_for_generated_files": "model",
"aot_inductor.cross_target_platform": "windows",
"aot_inductor.link_libtorch": False,
# TODO: need to add aoti_shim_library_path for CI
"aot_inductor.aoti_shim_library": "executorch",
# no fallback ops
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON,CPP",
"max_autotune_conv_backends": "TRITON,CPP",
"aot_inductor.embed_kernel_binary": True,
# simplify things for now
"aot_inductor.precompile_headers": False,
"aot_inductor.package_constants_on_disk_format": "binary_blob",
"aot_inductor.package_constants_in_so": False,
},
)

View File

@@ -94,6 +94,7 @@ from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import (
ALIGN_BYTES,
clear_on_fresh_cache,
determine_aoti_mmap_flags,
is_linux,
is_windows,
)
@@ -2175,7 +2176,10 @@ end
raw_bytes = bytes(raw_array.contents)
return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes)
if config.aot_inductor.package_constants_in_so:
if (
config.aot_inductor.package_constants_in_so
or config.aot_inductor.package_constants_on_disk_format == "binary_blob"
):
serialized_weights = b"".join(
_to_bytes(graph.get_original_value_of_constant(name), all_cuda)
for name in graph.constants.keys()
@@ -2184,7 +2188,7 @@ end
else:
serialized_weights = b""
if config.aot_inductor.package_constants_on_disk:
if config.aot_inductor.package_constants_on_disk_format == "pickle_weights":
# We need to return a storage key here because the original value tensor might be a clone
weights_dict = Weights(
{
@@ -2200,15 +2204,27 @@ end
consts_size = len(serialized_weights)
# TODO: Fix mmap weights with cuda
use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000
if config.aot_inductor.force_mmap_weights:
use_mmap_weights = True
use_external_weights, use_mmap_weights = determine_aoti_mmap_flags(
consts_size
)
if use_external_weights and use_mmap_weights:
# Should never reach here, just a check for sanity
raise RuntimeError(
"use_external_weights and use_mmap_weights cannot both be True."
)
external_weights_path = None
if use_external_weights:
external_weights_filename = f"{wrapper_path_operator.stem}_weights.blob"
external_weights_path = str(
wrapper_path_operator.with_name(external_weights_filename)
)
compile_command: dict[str, Any] = {
"aot_mode": graph.aot_mode,
"device_type": device_type,
"use_mmap_weights": use_mmap_weights,
"use_mmap_weights_external": use_external_weights,
"use_relative_path": use_relative_path,
"vec_isa": picked_vec_isa,
}
@@ -2287,7 +2303,15 @@ end
if not use_mmap_weights:
aot_constants = serialized_weights
magic_number = 0
if use_external_weights:
aot_constants = struct.pack("q", consts_size)
assert external_weights_path is not None
# For external weights, write weights to separate file and embed minimal placeholder
with open(external_weights_path, "wb") as f_weights:
f_weights.write(serialized_weights)
generated_files.append(external_weights_path)
else:
# we'll append weights binary to the end of .so file and mmap it when loading
magic_number = cast(
int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item()
)
@@ -2468,6 +2492,10 @@ end
os.remove(o_file)
if use_mmap_weights:
if config.aot_inductor.cross_target_platform == "windows":
raise RuntimeError(
"when cross_target_platform is windows, use_mmap_weights should not be true."
)
def get_page_size() -> int:
# Don't use resource.getpagesize() on Windows, as it is a Unix specific package

View File

@@ -66,6 +66,7 @@ AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
size_t num_models,
const char* device_str,
const char* cubin_dir) {
if (num_models == 0) {
std::cerr << "Error: num_models must be positive, but got 0\n";
return AOTI_RUNTIME_FAILURE;
@@ -82,6 +83,7 @@ AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(
})
}
AOTIRuntimeError AOTInductorModelContainerDelete(
AOTInductorModelContainerHandle container_handle) {
CONVERT_EXCEPTION_TO_ERROR_CODE({
@@ -460,4 +462,27 @@ AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
})
}
AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize(
AOTInductorModelContainerHandle container_handle,
uint64_t* ret_size) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *ret_size = container->constant_blob_size(); })
}
// Load weights from a single blob in weight_blob_ptr
AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob(
AOTInductorModelContainerHandle container_handle,
const uint8_t* weight_blob_ptr){
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{container->update_constants_from_blob(weight_blob_ptr); })
}
} // extern "C"

View File

@@ -767,7 +767,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
num_outputs = len(V.graph.graph_outputs)
num_constants = len(V.graph.constants)
include_weights = (
"true" if config.aot_inductor.package_constants_in_so else "false"
"true"
if config.aot_inductor.package_constants_in_so
and config.aot_inductor.package_constants_on_disk_format != "binary_blob"
else "false"
)
self.prefix.splice(
f"""

View File

@@ -1583,10 +1583,22 @@ class aot_inductor:
)
# Experimental. Flag to control whether to include weight in .so
# Not supported for cross_target_platform="windows".
package_constants_in_so: bool = True
# Experimental. Flag to control whether to package weight separately on disk
package_constants_on_disk: bool = False
# Experimental. Flag to control whether to package weight separately on disk and which
# format to package it in.
# Options:
# None:
# Do not package weight separately on disk.
# "pickle_weights":
# Each weight is pickled and stored separately in data/weights. We also store the
# FQN names of each weight in a weights_config.json in each model's data/aot_inductor/model folder.
# Can only be loaded back from Python using the torch._inductor.aoti_load_package API for now.
# "binary_blob":
# Stores all weights in a single binary blob in data/aot_inductor/model folder for each model.
# This option and config.aot_inductor.force_mmap_weights cannot both be True
package_constants_on_disk_format: Optional[str] = None
# Experimental. Controls automatic precompiling of common AOTI include files.
precompile_headers: bool = not is_fbcode()

View File

@@ -1385,10 +1385,19 @@ def _get_libstdcxx_args() -> tuple[list[str], list[str]]:
return lib_dir_paths, libs
def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]:
def get_mmap_self_macro(
use_mmap_weights: bool, use_mmap_weights_external: bool
) -> list[str]:
macros = []
if use_mmap_weights and use_mmap_weights_external:
raise RuntimeError(
"Only one of use_mmap_weights and use_mmap_weights_external should be true"
)
if use_mmap_weights:
macros.append(" USE_MMAP_SELF")
elif use_mmap_weights_external:
macros.append(" USE_MMAP_EXTERNAL")
return macros
@@ -1408,6 +1417,7 @@ def get_cpp_torch_options(
aot_mode: bool,
use_relative_path: bool,
use_mmap_weights: bool,
use_mmap_weights_external: bool,
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
"""
This function is used to get the build args of torch related build options.
@@ -1456,7 +1466,7 @@ def get_cpp_torch_options(
fb_macro_passthrough_args = _use_fb_internal_macros()
mmap_self_macros = get_mmap_self_macro(use_mmap_weights)
mmap_self_macros = get_mmap_self_macro(use_mmap_weights, use_mmap_weights_external)
caching_allocator_macros = get_caching_allocator_macro()
definitions = (
@@ -1513,6 +1523,7 @@ class CppTorchOptions(CppOptions):
compile_only: bool = False,
use_relative_path: bool = False,
use_mmap_weights: bool = False,
use_mmap_weights_external: bool = False,
shared: bool = True,
extra_flags: Sequence[str] = (),
compiler: str = "",
@@ -1548,6 +1559,7 @@ class CppTorchOptions(CppOptions):
aot_mode=aot_mode,
use_relative_path=use_relative_path,
use_mmap_weights=use_mmap_weights,
use_mmap_weights_external=use_mmap_weights_external,
)
_append_list(self._definitions, torch_definitions)
@@ -1725,6 +1737,7 @@ class CppTorchDeviceOptions(CppTorchOptions):
compile_only: bool = False,
use_relative_path: bool = False,
use_mmap_weights: bool = False,
use_mmap_weights_external: bool = False,
shared: bool = True,
extra_flags: Sequence[str] = (),
min_optimize: bool = False,
@@ -1738,6 +1751,7 @@ class CppTorchDeviceOptions(CppTorchOptions):
compile_only=compile_only,
use_relative_path=use_relative_path,
use_mmap_weights=use_mmap_weights,
use_mmap_weights_external=use_mmap_weights_external,
extra_flags=extra_flags,
min_optimize=min_optimize,
precompiling=precompiling,

View File

@@ -3624,9 +3624,66 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An
)
force_patch_config(config_patches, "aot_inductor.dynamic_linkage", False)
cross_target_platform = config_patches.get(
"aot_inductor.cross_target_platform",
config.aot_inductor.cross_target_platform,
)
package_constants_in_so = config_patches.get(
"aot_inductor.package_constants_in_so",
config.aot_inductor.package_constants_in_so,
)
if cross_target_platform == "windows" and package_constants_in_so:
raise RuntimeError(
"config.aot_inductor.package_constants_in_so is not supported for windows cross-compilation. "
"Please use config.aot_inductor.package_constants_on_disk_format = binary_blob."
)
return config_patches
def determine_aoti_mmap_flags(consts_size: int) -> tuple[bool, bool]:
"""
Decide whether we should mmap the weights, and whether to store them alongside the .so.
If the force_mmap_weights or package_constants_on_disk_format == "binary_blob" config is set, respect it.
Returns tuple (use_external_weights, use_mmap_weights).
"""
if (
config.aot_inductor.force_mmap_weights
and config.aot_inductor.package_constants_on_disk_format == "binary_blob"
):
raise RuntimeError(
"config.aot_inductor.package_constants_on_disk_format = binary_blob and "
"config.aot_inductor.force_mmap_weights cannot both be True."
)
if config.aot_inductor.force_mmap_weights:
if config.aot_inductor.cross_target_platform == "windows":
raise RuntimeError(
"when cross_target_platform is windows, use_mmap_weights should not be true."
)
use_mmap_weights = True
use_external_weights = False
return use_external_weights, use_mmap_weights
if config.aot_inductor.package_constants_on_disk_format == "binary_blob":
use_external_weights = True
use_mmap_weights = False
return use_external_weights, use_mmap_weights
if consts_size <= 2_000_000_000:
return False, False
use_external_weights = False
use_mmap_weights = not config.is_fbcode()
return use_external_weights, use_mmap_weights
def is_valid_aoti_model_name() -> bool:
"""
Validates if a model name is suitable for use in code generation.

View File

@@ -718,6 +718,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
std::string so_filename;
std::string cpp_filename;
std::string weight_blob_filename;
std::vector<std::string> obj_filenames;
std::string model_directory = normalize_path_separator(
file_prefix + "data" + k_separator + "aotinductor" + k_separator +
@@ -782,6 +783,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
obj_filenames.push_back(output_file_path);
} else if (filename_extension == extension_file_ext()) {
so_filename = output_file_path;
} else if (filename_extension == ".blob") {
weight_blob_filename = output_file_path;
}
}
}
@@ -838,6 +841,10 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
std::string cubin_dir = temp_dir_ + k_separator + model_directory;
runner_ = registered_aoti_runner[device_key](
so_path, num_runners, device.str(), cubin_dir, run_single_threaded);
if (weight_blob_filename != "") {
runner_->update_constant_buffer_from_blob(weight_blob_filename);
}
}
AOTIModelPackageLoader::~AOTIModelPackageLoader() {

View File

@@ -7,6 +7,19 @@
#include <c10/util/FileSystem.h>
#include <fcntl.h>
#ifdef _WIN32
#include <errno.h>
#include <io.h>
#include <sys/stat.h>
#include <windows.h>
#include <functional> // std::function
#else // !_WIN32
#include <dlfcn.h>
#include <sys/mman.h>
#include <unistd.h>
#endif // _WIN32
namespace torch::inductor {
AOTIModelContainerRunner::AOTIModelContainerRunner(
@@ -88,6 +101,12 @@ consider rebuild your model with the latest AOTInductor.");
TRY_LOAD_SYMBOL(
update_user_managed_constant_buffer_func_,
"AOTInductorModelContainerUpdateUserManagedConstantBuffer")
TRY_LOAD_SYMBOL(
get_constants_blob_size_func_,
"AOTInductorModelContainerGetConstantsBlobSize")
TRY_LOAD_SYMBOL(
update_constants_from_blob_func_,
"AOTInductorModelUpdateConstantsFromBlob")
#undef TRY_LOAD_SYMBOL
// Hack to find the json file name from the model so file
@@ -251,6 +270,81 @@ void AOTIModelContainerRunner::update_constant_buffer(
}
}
void AOTIModelContainerRunner::update_constant_buffer_from_blob(
const std::string& weights_path) {
uint64_t weights_size;
AOTI_RUNTIME_ERROR_CODE_CHECK(
get_constants_blob_size_func_(container_handle_, &weights_size));
#ifdef _WIN32
// Proper Windows file mapping implementation
HANDLE hFile = CreateFileA(
weights_path.c_str(),
GENERIC_READ,
FILE_SHARE_READ,
NULL,
OPEN_EXISTING,
FILE_ATTRIBUTE_NORMAL,
NULL);
if (hFile == INVALID_HANDLE_VALUE) {
throw std::runtime_error(
"Failed to open external weights file: " + weights_path);
}
// Get actual file size for validation
LARGE_INTEGER fileSize;
if (!GetFileSizeEx(hFile, &fileSize)) {
CloseHandle(hFile);
throw std::runtime_error("Failed to get file size");
}
if (static_cast<uint64_t>(fileSize.QuadPart) < weights_size) {
CloseHandle(hFile);
throw std::runtime_error("File size smaller than expected weights size");
}
HANDLE hMapping = CreateFileMapping(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
CloseHandle(hFile); // Close file handle, keep mapping handle
if (hMapping == NULL) {
throw std::runtime_error("CreateFileMapping failed");
}
uint8_t* ptr = static_cast<uint8_t*>(
MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, weights_size));
if (ptr == NULL) {
CloseHandle(hMapping);
throw std::runtime_error("MapViewOfFile failed");
}
#else
// Unix/Linux implementation
int fd = open(weights_path.c_str(), O_RDONLY);
TORCH_CHECK(fd >= 0, "Failed to open external weights file: " + weights_path);
uint8_t* ptr = static_cast<uint8_t*>(
mmap(NULL, weights_size, PROT_READ, MAP_PRIVATE, fd, 0));
close(fd);
TORCH_CHECK(ptr != MAP_FAILED, "mmap() failed");
#endif
AOTI_RUNTIME_ERROR_CODE_CHECK(
update_constants_from_blob_func_(container_handle_, ptr));
// After update_constants_from_blob_func_ returns, the model has copied
// all the data from the mmap'd memory to its own internal storage,
// so we can safely unmap the memory now.
#ifdef _WIN32
UnmapViewOfFile(ptr);
CloseHandle(hMapping);
#else
munmap(ptr, weights_size);
#endif
}
void AOTIModelContainerRunner::update_inactive_constant_buffer(
const TensorConstantMap& const_map) {
AOTI_RUNTIME_ERROR_CODE_CHECK(update_inactive_constant_buffer_func_(

View File

@@ -55,6 +55,7 @@ class TORCH_API AOTIModelContainerRunner {
AOTInductorStreamHandle cuda_stream_handle = nullptr);
void swap_constant_buffer();
void free_inactive_constant_buffer();
void update_constant_buffer_from_blob(const std::string& weights_path);
std::vector<std::string> get_call_spec();
@@ -99,6 +100,10 @@ class TORCH_API AOTIModelContainerRunner {
decltype(&AOTInductorModelContainerFreeInactiveConstantBuffer)
free_inactive_constant_buffer_func_{nullptr};
decltype(&AOTInductorModelContainerGetCallSpec) get_call_spec_func_{nullptr};
decltype(&AOTInductorModelContainerGetConstantsBlobSize)
get_constants_blob_size_func_{nullptr};
decltype(&AOTInductorModelUpdateConstantsFromBlob)
update_constants_from_blob_func_{nullptr};
AOTInductorModelContainerHandle container_handle_ = nullptr;

View File

@@ -51,7 +51,11 @@ void initAOTIRunnerBindings(PyObject* module) {
&AOTIModelContainerRunnerCpu::swap_constant_buffer)
.def(
"free_inactive_constant_buffer",
&AOTIModelContainerRunnerCpu::free_inactive_constant_buffer);
&AOTIModelContainerRunnerCpu::free_inactive_constant_buffer)
.def(
"update_constant_buffer_from_blob",
&AOTIModelContainerRunnerCpu::update_constant_buffer_from_blob,
py::arg("weights_path"));
#ifdef USE_CUDA
py::class_<AOTIModelContainerRunnerCuda>(m, "AOTIModelContainerRunnerCuda")
@@ -91,7 +95,11 @@ void initAOTIRunnerBindings(PyObject* module) {
&AOTIModelContainerRunnerCuda::swap_constant_buffer)
.def(
"free_inactive_constant_buffer",
&AOTIModelContainerRunnerCuda::free_inactive_constant_buffer);
&AOTIModelContainerRunnerCuda::free_inactive_constant_buffer)
.def(
"update_constant_buffer_from_blob",
&AOTIModelContainerRunnerCuda::update_constant_buffer_from_blob,
py::arg("weights_path"));
#endif
#ifdef USE_XPU
py::class_<AOTIModelContainerRunnerXpu>(m, "AOTIModelContainerRunnerXpu")
@@ -131,8 +139,11 @@ void initAOTIRunnerBindings(PyObject* module) {
&AOTIModelContainerRunnerXpu::swap_constant_buffer)
.def(
"free_inactive_constant_buffer",
&AOTIModelContainerRunnerXpu::free_inactive_constant_buffer);
&AOTIModelContainerRunnerXpu::free_inactive_constant_buffer)
.def(
"update_constant_buffer_from_blob",
&AOTIModelContainerRunnerXpu::update_constant_buffer_from_blob,
py::arg("weights_path"));
#endif
#if defined(USE_MPS) && defined(__APPLE__) && \
!(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
@@ -167,8 +178,11 @@ void initAOTIRunnerBindings(PyObject* module) {
&AOTIModelContainerRunnerMps::swap_constant_buffer)
.def(
"free_inactive_constant_buffer",
&AOTIModelContainerRunnerMps::free_inactive_constant_buffer);
&AOTIModelContainerRunnerMps::free_inactive_constant_buffer)
.def(
"update_constant_buffer_from_blob",
&AOTIModelContainerRunnerMps::update_constant_buffer_from_blob,
py::arg("weights_path"));
#endif
m.def(

View File

@@ -242,6 +242,16 @@ AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsMap(
AOTInductorModelHandle model_handle,
AOTInductorConstantMapHandle constant_map_handle);
// Get the size of the constant blob
AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize(
AOTInductorModelContainerHandle container_handle,
uint64_t* ret_size);
// Load weights from a single blob in weight_blob_ptr
AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob(
AOTInductorModelContainerHandle container_handle,
const uint8_t* weight_blob_ptr);
// Delete an AOTInductorModel created by AOTInductorModelCreate.
AOTI_API AOTIRuntimeError
AOTInductorModelDelete(AOTInductorModelHandle model_handle);

View File

@@ -581,7 +581,14 @@ class AOTInductorModelBase {
return folded_constants;
}
void load_constants() {
void update_constants_from_blob(const uint8_t* weight_blob_ptr) {
#if defined(USE_MMAP_EXTERNAL)
user_managed_mmap = const_cast<uint8_t*>(weight_blob_ptr);
load_constants(true);
#endif
}
void load_constants(bool force = false) {
size_t num_constants = this->num_constants();
size_t num_folded_constants = this->num_folded_constants();
constants_map_->reserve(num_constants);
@@ -590,7 +597,7 @@ class AOTInductorModelBase {
num_constants - num_folded_constants);
size_t blob_size = 0;
compute_constant_blob(blob_size, constants_internal_offset);
if (!include_weights) {
if (!force && !include_weights) {
return;
}
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
@@ -817,6 +824,17 @@ class AOTInductorModelBase {
return out_spec_.c_str();
}
uint64_t constant_blob_size() const {
#if defined(USE_MMAP_SELF) || defined(USE_MMAP_EXTERNAL)
const uint64_t weights_size =
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[0];
return weights_size;
#else
throw std::runtime_error{
"constant blob size is only available for mmap'd weights"};
#endif
}
void update_constants_array_from_map() {
if (!constants_map_) {
throw std::runtime_error{
@@ -903,6 +921,15 @@ class AOTInductorModelBase {
protected:
uint8_t* _get_constants_start() {
#if defined(USE_MMAP_EXTERNAL)
if (!user_managed_mmap) {
throw std::runtime_error{
"Constants are not mmap'd. Use AOTInductorModelUpdateConstantsBlob to initialize the constants first."};
}
// Mapped memory for weights
return user_managed_mmap;
#endif
#ifndef USE_MMAP_SELF
// NOLINTNEXTLINE(*const-cast*)
return const_cast<uint8_t*>(_binary_constants_bin_start);
@ -942,6 +969,7 @@ class AOTInductorModelBase {
return self_mmap;
#endif
}
struct ParamInfo {
const char* name = nullptr;
};
@@ -973,10 +1001,16 @@ class AOTInductorModelBase {
// Holds the blob storage for constants' at::Tensor.
RAIIDataPtr constant_blob_;
#ifdef USE_MMAP_SELF
#if defined(USE_MMAP_SELF)
// Mapped memory for weights
uint8_t* self_mmap = NULL;
#endif
#if defined(USE_MMAP_EXTERNAL)
// Mapped memory for weights
uint8_t* user_managed_mmap = NULL;
#endif
// A directory with CUDA binary files, e.g. compiled kernels, etc.
const std::optional<std::string> cubin_dir_;

View File

@@ -255,6 +255,20 @@ class AOTInductorModelContainer {
return models_[0]->constant_dtype(static_cast<int64_t>(idx));
}
uint64_t constant_blob_size() const {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->constant_blob_size();
}
void update_constants_from_blob(const uint8_t* weight_blob_ptr) {
if (this->num_models() == 0) {
throw std::runtime_error("No available models in container!");
}
return models_[0]->update_constants_from_blob(weight_blob_ptr);
}
void run_const_fold(
bool inactive_buffer,
DeviceStreamType stream,