Compare commits

...

1 Commits

Author SHA1 Message Date
22b9f783c5 weight sharing for cpp_package_only 2025-07-31 11:14:13 -07:00
6 changed files with 424 additions and 203 deletions

View File

@ -155,6 +155,9 @@ class TestAOTInductorPackage(TestCase):
env=custom_env,
check=True,
)
subprocess.run(["make"], cwd=build_path, check=True)
result = subprocess.run(
@ -430,12 +433,20 @@ class TestAOTInductorPackage(TestCase):
self.check_package_cpp_only()
class Model1(torch.nn.Module):
def __init__(self, a):
super().__init__()
self.a = a
def forward(self, x, y):
return x + y
return x + y + self.a
class Model2(torch.nn.Module):
def __init__(self, a):
super().__init__()
self.a = a
def forward(self, x, y):
return x - y
return x - y + self.a
def default(*args, **kwargs):
return None
@ -445,15 +456,17 @@ class TestAOTInductorPackage(TestCase):
torch.ones(3, 3).to(self.device),
)
a = torch.ones(3, 3).to(self.device)
package = _ExportPackage()
m1 = Model1()
m2 = Model2()
m1 = Model1(a)
m2 = Model2(a)
exporter1 = package._exporter("Plus", m1)._define_overload("default", default)
exporter2 = package._exporter("Minus", m2)._define_overload("default", default)
exporter1(*example_inputs)
exporter2(*example_inputs)
for package_example_inputs in [True, False]:
for package_example_inputs in [True]: # , False
with (
tempfile.TemporaryDirectory() as tmp_dir,
):
@ -467,14 +480,14 @@ class TestAOTInductorPackage(TestCase):
if self.device == GPU_TYPE:
self.assertEqual(
result.stdout,
"output_tensor1\n 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2\n 0 0 0\n"
" 0 0 0\n 0 0 0\n[ CUDAFloatType{3,3} ]\n",
"output_tensor1\n 3 3 3\n 3 3 3\n 3 3 3\n[ CUDAFloatType{3,3} ]\noutput_tensor2\n 1 1 1\n"
" 1 1 1\n 1 1 1\n[ CUDAFloatType{3,3} ]\n",
)
else:
self.assertEqual(
result.stdout,
"output_tensor1\n 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2\n 0 0 0\n"
" 0 0 0\n 0 0 0\n[ CPUFloatType{3,3} ]\n",
"output_tensor1\n 3 3 3\n 3 3 3\n 3 3 3\n[ CPUFloatType{3,3} ]\noutput_tensor2\n 1 1 1\n"
" 1 1 1\n 1 1 1\n[ CPUFloatType{3,3} ]\n",
)
@unittest.skipIf(
@ -483,7 +496,6 @@ class TestAOTInductorPackage(TestCase):
@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
@skipIfRocm # doesn't support multi-arch binary
@skipIfXpu # doesn't support multi-arch binary
@torch._inductor.config.patch("test_configs.use_libtorch", True)
def test_compile_with_exporter_weights(self):
self.check_package_cpp_only()
@ -512,7 +524,7 @@ class TestAOTInductorPackage(TestCase):
tempfile.TemporaryDirectory() as tmp_dir,
):
package._compiled_and_package(
tmp_dir + "/package.pt2", True, package_example_inputs
tmp_dir + "/package.pt2", True, package_example_inputs, weight_share=True
)
# Test compiling generated files
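For reference, the expected stdout changes from 2/0 to 3/1 because both models now add the shared constant `a`; a minimal eager-mode sketch of the arithmetic, using the same ones-filled inputs as the test:

import torch

a = torch.ones(3, 3)  # shared constant passed to Model1 and Model2
x = torch.ones(3, 3)  # example inputs
y = torch.ones(3, 3)

print(x + y + a)  # "Plus" model: every element is 3 -> output_tensor1
print(x - y + a)  # "Minus" model: every element is 1 -> output_tensor2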

View File

@ -87,6 +87,7 @@ from torch._inductor.runtime.compile_tasks import _reload_python_module
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import (
ALIGN_BYTES,
IndentedBuffer,
clear_on_fresh_cache,
is_linux,
is_windows,
@ -1621,6 +1622,206 @@ class CudaKernelParamCache:
return cls.cache.keys()
def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes:
def _pad_to_alignment(raw_bytes: bytes) -> bytes:
padded_bytes = raw_bytes.ljust(
(len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES,
b"\x00",
)
return padded_bytes
# This serializes the tensor's untyped_storage to bytes by accessing
# the raw data of the underlying structure.
import ctypes
if t.numel() == 0:
return b""
if t.is_mkldnn:
data_ptr = torch.ops.mkldnn.data_ptr(t)
nbytes = torch.ops.mkldnn._nbytes(t)
else:
t_cpu = t.untyped_storage().cpu()
data_ptr = t_cpu.data_ptr()
nbytes = t_cpu.nbytes()
raw_array = ctypes.cast(
data_ptr,
ctypes.POINTER(ctypes.c_ubyte * nbytes),
)
raw_bytes = bytes(raw_array.contents)
return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes)
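# Worked example of the _pad_to_alignment rule above (a sketch; ALIGN_BYTES is assumed to be
# 64, its value in torch/_inductor/utils.py):
ALIGN_BYTES = 64
raw = b"\x01" * 36  # e.g. the raw bytes of a 3x3 float32 tensor
padded = raw.ljust((len(raw) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, b"\x00")
assert len(padded) == 64  # rounded up to the next multiple of ALIGN_BYTES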
def _compile_consts(consts: bytes, platform: str, mutating: bool, device_type: str, aot_mode: bool, specified_sub_dir: Path | str, use_relative_path: bool) -> str:
# Load from aot_inductor, and update the value on demand.
use_asm_build: bool = config.aot_inductor.use_consts_asm_build
section_attr = ""
if platform == "linux":
if mutating:
# .data section is between .text and .bss. When the size of .data is large,
# during the linking, the relocation of .text against .bss may overflow.
# Rename it to .ldata so that it won't be in between the .text and .bss section
if len(consts) > 2_000_000_000:
raise ValueError(
"Models with buffer mutation included doesn't support constants greater than 2GB!"
)
section_attr = '.ldata, "aw"'
else:
section_attr = '.lrodata, "a"'
symbol_prefix = ""
elif platform == "darwin":
section_attr = "__DATA,__data"
symbol_prefix = "_"
elif platform == "win32":
symbol_prefix = ""
# ASM build is not supported on Windows; fall back to the CPP build.
use_asm_build = False
else:
raise RuntimeError(f"Unsupported platform: {platform}")
# Intel compiler failed to compile this manually constructed assembly file.
# Switch XPU to use consts cpp build.
if device_type == "xpu":
use_asm_build = False
is_large_consts = len(consts) > 1024
def format_consts_to_asm(
consts: bytes,
align_bytes: int,
symbol_prefix: str,
is_large_consts: bool,
) -> tuple[str, str]:
consts_asm = f"\t.section\t{section_attr}\n"
consts_asm += f"\t.balign {align_bytes}\n"
consts_asm += f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n"
consts_asm += f"{symbol_prefix}_binary_constants_bin_start:\n"
if not is_large_consts:
for c in consts:
consts_asm += f"\t.byte {c}\n"
# Add one element even if constants are empty
# Otherwise assembler will not put them in data section
if not consts:
consts_asm += "\t.space 1\n"
else:
consts_asm += "\t.quad 0x1234567899abcdef\n"
consts_asm += f"\t.space {len(consts) - 8}\n"
consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n"
consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n"
return consts_asm, "weights.S"
# Using C++ to convert the consts to an object file supports more compilers, such as MSVC and ICX.
def format_consts_to_cpp(
consts: bytes, align_bytes: int, symbol_prefix: str
) -> tuple[str, str]:
consts_size = len(consts)
asan_attr = """#if defined(__clang__) || defined (__GNUC__)\t\n\
#define ATTRIBUTE_NO_SANITIZE_ADDRESS __attribute__((no_sanitize("address")))\t\n\
#else\t\n\
#define ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n\
#endif\t\n\
\t\n\
ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
const_cpp = asan_attr
const_cpp += f"alignas({align_bytes}) extern "
const_cpp += f"unsigned char {symbol_prefix}_binary_constants_bin_start[{consts_size}] = {{\t\n"
count_bytes = 0
for c in consts:
const_cpp += f"{c}, "
count_bytes = count_bytes + 1
if count_bytes % 16 == 0:
const_cpp += "\t\n"
const_cpp += "};\t\n"
const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n"
return const_cpp, "weights.cpp"
if use_asm_build:
consts_code, code_ext = format_consts_to_asm(
consts, ALIGN_BYTES, symbol_prefix, is_large_consts
)
else:
consts_code, code_ext = format_consts_to_cpp(
consts, ALIGN_BYTES, symbol_prefix
)
_, consts_s = write(
consts_code,
code_ext,
specified_dir=str(specified_sub_dir),
key=config.aot_inductor.model_name_for_generated_files,
)
consts_s = Path(consts_s)
object_build_options = CppTorchDeviceOptions(
device_type=device_type,
aot_mode=aot_mode,
compile_only=True,
use_relative_path=use_relative_path,
)
object_builder = CppBuilder(
name=str(consts_s.stem),
sources=str(consts_s),
output_dir=str(consts_s.parent),
BuildOption=object_build_options,
)
consts_o = object_builder.get_target_file_path()
object_builder.build()
if is_large_consts and use_asm_build:
with open(consts_o, "r+b") as f:
f.seek(0)
hdr = f.read(1024)
# Search for magic number and write the actual data over it
start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12")
assert start_idx != -1
f.seek(start_idx)
pos = 0
while pos < len(consts):
rc = f.write(consts[pos:])
pos += rc
# Remove the .S file to save space
os.remove(consts_s)
return consts_o
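# Note on the byte search above (a sketch): the placeholder `.quad 0x1234567899abcdef` emitted
# by format_consts_to_asm is stored little-endian, so its on-disk bytes are exactly the magic
# pattern scanned for before the real constant data is written over it:
import struct
assert struct.pack("<Q", 0x1234567899ABCDEF) == b"\xef\xcd\xab\x99\x78\x56\x34\x12"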
def generate_aot_consts_mapping(weights_config: dict[str, Any], model_name: str, total_size: int) -> str:
ib = IndentedBuffer()
ib.writelines(
[
'#include <unordered_map>',
'#include <string>\n',
'namespace torch::aot_inductor {',
f'const std::unordered_map<std::string, int>& get_{model_name}_consts_mapping() {{',
' static const std::unordered_map<std::string, int> index = {'
]
)
for weight_name, val in weights_config.items():
start_offset, size, shape, stride, tensor_offset = val
ib.writelines([f' {{"{weight_name}", {start_offset}}},'])
ib.writelines(
[
" };",
" return index;",
"}",
"",
f"int get_{model_name}_total_const_bytes() {{",
f" return {total_size};",
"}",
"}",
" // namespace torch::aot_inductor",
]
)
return ib.getvalue()
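# Hypothetical usage sketch of generate_aot_consts_mapping (the model name "Plus", the weight
# name "a", and the offsets/sizes below are illustrative, not taken from a real run):
#
#   print(generate_aot_consts_mapping({"a": (0, 36, (3, 3), (3, 1), 0)}, "Plus", 64))
#
# emits, roughly:
#
#   #include <unordered_map>
#   #include <string>
#
#   namespace torch::aot_inductor {
#   const std::unordered_map<std::string, int>& get_Plus_consts_mapping() {
#       static const std::unordered_map<std::string, int> index = {
#           {"a", 0},
#       };
#       return index;
#   }
#
#   int get_Plus_total_const_bytes() {
#       return 64;
#   }
#   }  // namespace torch::aot_inductor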
class AotCodeCompiler:
"""
Compile AOT Inductor generated code.
@ -1789,136 +1990,6 @@ class AotCodeCompiler:
specified_sub_dir.mkdir(exist_ok=True)
cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt")
def _compile_consts(consts: bytes, platform: str) -> str:
# Load from aot_inductor, and update the value on demand.
use_asm_build: bool = config.aot_inductor.use_consts_asm_build
if platform == "linux":
if graph.mutated_buffers & OrderedSet(graph.constants.keys()):
# .data section is between .text and .bss. When the size of .data is large,
# during the linking, the relocation of .text against .bss may overflow.
# Rename it to .ldata so that it won't be in between the .text and .bss section
if len(consts) > 2_000_000_000:
raise ValueError(
"Models with buffer mutation included doesn't support constants greater than 2GB!"
)
section_attr = '.ldata, "aw"'
else:
section_attr = '.lrodata, "a"'
symbol_prefix = ""
elif platform == "darwin":
section_attr = "__DATA,__data"
symbol_prefix = "_"
elif platform == "win32":
symbol_prefix = ""
# ASM build is not supported on Windows, force use CPP build.
use_asm_build = False
else:
raise RuntimeError(f"Unsupported platform: {platform}")
# Intel compiler failed to compile this manually constructed assembly file.
# Switch XPU to use consts cpp build.
if device_type == "xpu":
use_asm_build = False
is_large_consts = len(consts) > 1024
def format_consts_to_asm(
consts: bytes,
align_bytes: int,
symbol_prefix: str,
is_large_consts: bool,
) -> tuple[str, str]:
consts_asm = f"\t.section\t{section_attr}\n"
consts_asm += f"\t.balign {align_bytes}\n"
consts_asm += f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n"
consts_asm += f"{symbol_prefix}_binary_constants_bin_start:\n"
if not is_large_consts:
for c in consts:
consts_asm += f"\t.byte {c}\n"
# Add one element even if constants are empty
# Otherwise assembler will not put them in data section
if not consts:
consts_asm += "\t.space 1\n"
else:
consts_asm += "\t.quad 0x1234567899abcdef\n"
consts_asm += f"\t.space {len(consts) - 8}\n"
consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n"
consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n"
return consts_asm, "S"
# Use c++ to convert consts to object file can support more compilers, such as msvc and icx.
def format_consts_to_cpp(
consts: bytes, align_bytes: int, symbol_prefix: str
) -> tuple[str, str]:
consts_size = len(consts)
asan_attr = """#if defined(__clang__) || defined (__GNUC__)\t\n\
#define ATTRIBUTE_NO_SANITIZE_ADDRESS __attribute__((no_sanitize("address")))\t\n\
#else\t\n\
#define ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n\
#endif\t\n\
\t\n\
ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
const_cpp = asan_attr
const_cpp += f"alignas({align_bytes}) extern "
const_cpp += f"unsigned char {symbol_prefix}_binary_constants_bin_start[{consts_size}] = {{\t\n"
count_bytes = 0
for c in consts:
const_cpp += f"{c}, "
count_bytes = count_bytes + 1
if count_bytes % 16 == 0:
const_cpp += "\t\n"
const_cpp += "};\t\n"
const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n"
return const_cpp, "cpp"
if use_asm_build:
consts_code, code_ext = format_consts_to_asm(
consts, ALIGN_BYTES, symbol_prefix, is_large_consts
)
else:
consts_code, code_ext = format_consts_to_cpp(
consts, ALIGN_BYTES, symbol_prefix
)
_, consts_s = write(
consts_code,
code_ext,
specified_dir=str(specified_sub_dir),
)
consts_s = Path(consts_s)
object_build_options = CppTorchDeviceOptions(
device_type=device_type,
aot_mode=graph.aot_mode,
compile_only=True,
use_relative_path=use_relative_path,
)
object_builder = CppBuilder(
name=str(consts_s.stem),
sources=str(consts_s),
output_dir=str(consts_s.parent),
BuildOption=object_build_options,
)
consts_o = object_builder.get_target_file_path()
object_builder.build()
if is_large_consts and use_asm_build:
with open(consts_o, "r+b") as f:
f.seek(0)
hdr = f.read(1024)
# Search for magic number and write the actual data over it
start_idx = hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12")
assert start_idx != -1
f.seek(start_idx)
pos = 0
while pos < len(consts):
rc = f.write(consts[pos:])
pos += rc
# Remove the .S file to save space
os.remove(consts_s)
return consts_o
from torch.utils._filelock import FileLock
@ -1977,36 +2048,6 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
if name not in graph.folded_constants
)
def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes:
def _pad_to_alignment(raw_bytes: bytes) -> bytes:
padded_bytes = raw_bytes.ljust(
(len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES,
b"\x00",
)
return padded_bytes
# This serializes the tensor's untyped_storage to bytes by accessing
# the raw data of the underlying structure.
import ctypes
if t.numel() == 0:
return b""
if t.is_mkldnn:
data_ptr = torch.ops.mkldnn.data_ptr(t)
nbytes = torch.ops.mkldnn._nbytes(t)
else:
t_cpu = t.untyped_storage().cpu()
data_ptr = t_cpu.data_ptr()
nbytes = t_cpu.nbytes()
raw_array = ctypes.cast(
data_ptr,
ctypes.POINTER(ctypes.c_ubyte * nbytes),
)
raw_bytes = bytes(raw_array.contents)
return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes)
if config.aot_inductor.package_constants_in_so:
serialized_weights = b"".join(
_to_bytes(graph.get_original_value_of_constant(name), all_cuda)
@ -2016,6 +2057,7 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
else:
serialized_weights = b""
use_mmap_weights = False
if config.aot_inductor.package_constants_on_disk:
# We need to return a storage key here because the original value tensor might be a clone
weights_dict = Weights(
@ -2029,13 +2071,13 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
}
)
generated_files.append(weights_dict)
else:
consts_size = len(serialized_weights)
consts_size = len(serialized_weights)
# TODO: Fix mmap weights with cuda
use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000
if config.aot_inductor.force_mmap_weights:
use_mmap_weights = True
# TODO: Fix mmap weights with cuda
use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000
if config.aot_inductor.force_mmap_weights:
use_mmap_weights = True
compile_command: dict[str, Any] = {
"aot_mode": graph.aot_mode,
@ -2125,7 +2167,8 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
)
aot_constants = struct.pack("qq", consts_size + 8, magic_number)
consts_o = _compile_consts(aot_constants, sys.platform)
mutating = len(graph.mutated_buffers & OrderedSet(graph.constants.keys())) > 0
consts_o = _compile_consts(aot_constants, sys.platform, mutating, device_type, graph.aot_mode, specified_sub_dir, use_relative_path)
custom_obj_idx = 0
# Note that custom_objs_config.json file is different from the model_constants_config.json file produced
# in package_sigmoid(). The keys in custom_objs_config.json directly correspond to the arg name in extern
@ -2269,7 +2312,7 @@ ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n"""
f_weights.write(struct.pack("q", magic_number))
generated_files.append(weight_file)
else:
elif not config.aot_inductor.package_constants_on_disk:
# TODO: unify to always use mmap_weights
generated_files.append(consts_o)
so_builder.save_src_to_cmake(cmake_path, consts_o)

View File

@ -113,6 +113,46 @@ namespace torch::aot_inductor {
using ConstantMap =
std::unordered_map<std::string, MaybeOwningAtenTensorHandle>;
inline RAIIDataPtr device_allocate(size_t blob_size){
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
return RAII_gpuMalloc(blob_size);
#else
return RAII_cpuMalloc(blob_size);
#endif
}
// Allocate memory for constants and copy from
// _binary_constants_bin_start to the allocated memory.
inline RAIIDataPtr load_constants_blob(size_t data_size, bool skip_copy = false) {
auto constant_blob = device_allocate(data_size);
uint8_t* internal_ptr = static_cast<uint8_t*>(constant_blob.get());
if (!skip_copy) {
#ifdef USE_XPU
sycl::queue* queue_ptr = nullptr;
aoti_torch_get_current_sycl_queue((void**)&queue_ptr);
queue_ptr
->memcpy(internal_ptr, _binary_constants_bin_start, data_size)
.wait();
#elif USE_CUDA
AOTI_RUNTIME_CUDA_CHECK(cudaMemcpy(
internal_ptr,
_binary_constants_bin_start,
data_size,
cudaMemcpyHostToDevice));
#elif USE_MPS
aoti_torch_mps_memcpy(
internal_ptr,
0,
0,
data_size,
_binary_constants_bin_start);
#else
memcpy(internal_ptr, _binary_constants_bin_start, data_size);
#endif
}
return constant_blob;
}
// valid device strs are: cpu, cuda, cuda:0, cuda:1, ...
// Update the list here if more devices are supported in the future
inline void parse_device_str(
@ -323,12 +363,7 @@ class AOTInductorModelBase {
if (!include_weights) {
return;
}
#if defined(USE_CUDA) || defined(USE_XPU) || defined(USE_MPS)
constant_blob_ = RAII_gpuMalloc(blob_size);
#else
constant_blob_ = RAII_cpuMalloc(blob_size);
#endif
constant_blob_ = device_allocate(blob_size);
size_t bytes_read = 0;
for (size_t i = 0; i < num_constants; i++) {
bool from_folded = this->constant_from_folded(i);
@ -382,6 +417,65 @@ class AOTInductorModelBase {
}
}
void load_dedup_constants(uint8_t* constants_ptr, const std::unordered_map<std::string, int>& offsets) {
if (include_weights) {
std::cout << "Please call load_constants() instead \n";
return;
}
size_t num_constants = this->num_constants();
size_t num_folded_constants = this->num_folded_constants();
constants_map_->reserve(num_constants);
for (size_t i = 0; i < num_constants; i++) {
bool from_folded = this->constant_from_folded(i);
if (from_folded) {
continue;
}
std::string name = this->constant_name(i);
std::string original_fqn = this->constant_original_fqn(i);
// storage offset of the tensor within the shared constants blob
int storage_offset = offsets.at(original_fqn);
size_t data_size = this->constant_data_size(i);
uint8_t* internal_ptr = (data_size != 0)
? constants_ptr + storage_offset
: nullptr;
// Create at::Tensor from copied memory.
auto dtype = this->constant_dtype(i);
auto ndim = this->constant_ndim(i);
auto size = this->constant_shape(i);
auto stride = this->constant_stride(i);
#ifdef USE_MPS
auto offset = this->constant_offset(i) +
(constants_internal_offset[i] / aoti_torch_dtype_element_size(dtype));
#else
auto offset = this->constant_offset(i);
#endif
auto layout = this->constant_layout(i);
auto opaque_metadata_ptr = this->opaque_metadata(i);
auto opaque_metadata_size = this->opaque_metadata_size(i);
AtenTensorHandle tensor_handle = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2(
internal_ptr,
ndim,
size,
stride,
offset,
dtype,
device_type_,
device_idx_,
&tensor_handle,
layout,
opaque_metadata_ptr,
opaque_metadata_size));
constants_map_->emplace(std::move(name), tensor_handle);
}
if (constants_map_) {
this->update_constants_array_from_map();
}
}
RAIIDataPtr&& release_constant_blob() {
return std::move(constant_blob_);
}

View File

@ -355,14 +355,15 @@ class _ExportPackage:
f: torch.types.FileLike,
standalone: bool = False,
package_example_inputs: bool = False,
weight_share: bool = False,
) -> None:
options: dict[str, typing.Any] = {
"aot_inductor.package": True,
"aot_inductor.package_cpp_only": True,
"always_keep_tensor_constants": True,
# we'll change this back to False once we enable weight deduping for standalone mode
"aot_inductor.package_constants_in_so": standalone,
"aot_inductor.package_constants_in_so": not weight_share,
"aot_inductor.compile_standalone": standalone,
"aot_inductor.package_constants_on_disk": weight_share,
}
aoti_files_map = {}
model_names = []
@ -417,13 +418,13 @@ class _ExportPackage:
path = Path(base_directory) / f"{name}_input_{i}.pt"
torch.save(t, path)
cmake_file_str = _get_make_file(package_name, model_names, use_cuda)
cmake_file_str = _get_make_file(package_name, model_names, use_cuda, weight_share=weight_share)
with open(Path(base_directory) / "CMakeLists.txt", "w") as file:
file.write(cmake_file_str)
main_file_str = _get_main_cpp_file(
package_name, model_names, use_cuda, example_inputs_map
package_name, model_names, use_cuda, example_inputs_map, weight_share=weight_share
)
with open(Path(base_directory) / "main.cpp", "w") as file:
file.write(main_file_str)
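For reference, a minimal usage sketch of the new weight_share flag (mirroring the weight-sharing test above; the output path is illustrative):

package = _ExportPackage()
# ... register exporters and trace them on example inputs, as in the test file above ...
package._compiled_and_package(
    "/tmp/package.pt2",
    standalone=True,
    package_example_inputs=True,
    weight_share=True,
)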

View File

@ -13,6 +13,7 @@ def _get_main_cpp_file(
model_names: list[str],
cuda: bool,
example_inputs_map: typing.Optional[dict[str, int]],
weight_share: bool = False,
) -> str:
"""
Generates a main.cpp file for AOTInductor standalone models in the specified package.
@ -54,6 +55,9 @@ def _get_main_cpp_file(
ib.writeline(
f'#include "{package_name}/data/aotinductor/{model_name}/{model_name}.h"'
)
if weight_share:
ib.writeline(
f'#include "{package_name}/data/aotinductor/{model_name}/aot_consts_mapping.h"'
)
ib.newline()
for model_name in model_names:
@ -69,6 +73,14 @@ def _get_main_cpp_file(
)
with ib.indent():
if weight_share:
ib.writelines([
f"size_t blob_size = torch::aot_inductor::get_{model_names[0]}_total_const_bytes();",
"auto constant_blob = torch::aot_inductor::load_constants_blob(blob_size);",
"uint8_t* constants_ptr = static_cast<uint8_t*>(constant_blob.get());",
])
ib.writeline(f'std::string device_str = "{"cuda" if cuda else "cpu"}";')
ib.writeline("try {")
@ -132,9 +144,13 @@ def _get_main_cpp_file(
f" std::move(constants_array{i + 1}),",
" device_str,",
f' "{package_name}/data/aotinductor/{model_name}/");',
f"model{i + 1}->load_constants();",
]
)
if weight_share:
ib.writelines(
[
f"auto model{i + 1}_mapping = torch::aot_inductor::get_{model_name}_consts_mapping();",
f"model{i + 1}->load_dedup_constants(constants_ptr, model{i + 1}_mapping);",
]
)
else:
ib.writeline(f"model{i + 1}->load_constants();")
if example_inputs_map is not None:
ib.writeline("\n// Run the models")
@ -181,7 +197,7 @@ def _get_main_cpp_file(
return ib.getvalue()
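# Rough sketch of the extra main.cpp lines the weight_share path above emits, using the
# "Plus"/"Minus" model names from the test file (exact formatting may differ):
#
#   size_t blob_size = torch::aot_inductor::get_Plus_total_const_bytes();
#   auto constant_blob = torch::aot_inductor::load_constants_blob(blob_size);
#   uint8_t* constants_ptr = static_cast<uint8_t*>(constant_blob.get());
#   ...
#   auto model1_mapping = torch::aot_inductor::get_Plus_consts_mapping();
#   model1->load_dedup_constants(constants_ptr, model1_mapping);
#   auto model2_mapping = torch::aot_inductor::get_Minus_consts_mapping();
#   model2->load_dedup_constants(constants_ptr, model2_mapping);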
def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str:
def _get_make_file(package_name: str, model_names: list[str], cuda: bool, weight_share: bool) -> str:
ib = IndentedBuffer()
ib.writelines(
@ -201,7 +217,9 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str
for model_name in model_names:
ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)")
ib.writeline("\nadd_executable(main main.cpp)")
ib.newline()
if weight_share:
ib.writeline(f"add_executable(main main.cpp {package_name}/data/weights/deduped_weights.o)")
if cuda:
ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)")

View File

@ -8,6 +8,7 @@ import zipfile
from dataclasses import dataclass
from typing import Any, IO, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
import sys
import torch
import torch.utils._pytree as pytree
@ -271,31 +272,83 @@ def _package_aoti_files(
if len(all_weights) > 0:
# Dedup weights
grouped_tensors: list[OrderedSet[tuple[str, str]]] = group_weights(all_weights)
complete_tensors: list[torch.Tensor] = []
for idx, group in enumerate(grouped_tensors):
filename = f"{WEIGHT_FILENAME_PREFIX}{idx}"
model_name, weight_name = get_complete(group, all_weights)
complete_tensor, _ = all_weights[model_name].get_weight(weight_name)
buffer = io.BytesIO()
torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol)
archive_writer.write_bytes(
os.path.join(WEIGHTS_DIR, filename), buffer.getvalue()
)
complete_tensors.append(complete_tensor)
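# Sketch of the dedup grouping above for the test in the first file (names assumed for
# illustration): both the "Plus" and "Minus" models hold the same tensor `a`, so
# group_weights collapses the two references into one group and only one copy is kept:
#
#   grouped_tensors ~ [{("Plus", "a"), ("Minus", "a")}]
#   complete_tensors ~ [a]  # a single 3x3 float32 tensor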
from torch._inductor.codecache import _to_bytes, _compile_consts, generate_aot_consts_mapping
from torch._inductor import config as inductor_config
all_cuda = all([t.is_cuda for t in complete_tensors])
weight_bytes = []
offset = 0
for idx, group in enumerate(grouped_tensors):
weight_byte = _to_bytes(complete_tensors[idx], all_cuda)
weight_bytes.append(weight_byte)
size = len(weight_byte)
for model_name, weight_name in group:
_, w_property = all_weights[model_name].get_weight(weight_name)
weights_configs[model_name][weight_name] = (
filename,
offset, # new
size, # new
w_property.shape,
w_property.stride,
w_property.offset,
)
offset += size
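# Offset bookkeeping sketch (illustrative sizes): with two deduped groups whose serialized
# (and, for non-CUDA weights, alignment-padded) bytes are 64 and 128 long, the loop above records
#
#   "a" -> start_offset 0,  size 64
#   "b" -> start_offset 64, size 128
#
# and the final offset (64 + 128 = 192) is later passed to generate_aot_consts_mapping as
# the blob's total size.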
aot_constants = b"".join(t for t in weight_bytes)
# TODO: determine device type more carefully
device_type = "cpu"
if all_cuda:
device_type = "cuda"
# Meta internal AOTInductor CPU
use_relative_path = (
inductor_config.is_fbcode() and device_type == "cpu"
)
specified_sub_dir = WEIGHTS_DIR
# TODO: determine if any constant is mutating
mutating = False
consts_o = _compile_consts(aot_constants, sys.platform, mutating, device_type, True, specified_sub_dir, use_relative_path)
# move consts_o into archive
archive_writer.write_file(
os.path.join(WEIGHTS_DIR, "deduped_weights.o"),
consts_o,
)
for model_name, weights_config in weights_configs.items():
aot_consts_mapping = generate_aot_consts_mapping(weights_config, model_name, offset)
new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, "aot_consts_mapping.h")
archive_writer.write_string(
os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"),
json.dumps(weights_config),
str(new_filepath),
aot_consts_mapping,
)
logger.debug("packaging weights_config for model %s", model_name)
logger.debug(weights_config)
def _package_exported_programs(