[AOTI] Enbale mmaped weights when CUDA is used (#124346)

By refactoring the logic that returns the start to constant pointer into `_get_constants_start()` method and call it from both CUDA and CPU readers

It has no runtime impact, but export time is down from 10m to 3m if mmaped weights are used on AWS p4d.24xlarge

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124346
Approved by: https://github.com/mikekgfb, https://github.com/desertfire
This commit is contained in:
Nikita Shulga
2024-04-19 04:47:27 +00:00
committed by PyTorch MergeBot
parent 87f44d70b1
commit 1ba85b34dd
2 changed files with 44 additions and 42 deletions

View File

@ -1827,10 +1827,8 @@ class AotCodeCompiler:
if name not in graph.folded_constants
)
# TODO: Fix mmap weights with cuda
use_mmap_weights = (
not cuda and not config.is_fbcode() and consts_size > 2_000_000_000
)
if config.aot_inductor.force_mmap_weights and not cuda:
use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000
if config.aot_inductor.force_mmap_weights:
use_mmap_weights = True
compile_cmd = cpp_compile_command(
input=input_path,

View File

@ -268,51 +268,16 @@ class AOTInductorModelBase {
if (!skip_copy) {
AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy(
internal_ptr,
_binary_constants_bin_start + bytes_read,
_get_constants_start() + bytes_read,
data_size,
cudaMemcpyHostToDevice));
}
return internal_ptr;
#elif USE_MMAP_SELF
// get pointer to constant which is packed in model during compile time.
AOTI_RUNTIME_CHECK(!skip_copy, "pure cpu mode doesn't support skip copy");
if (!self_mmap) {
Dl_info dl_info;
// get pointer to constant which are appended to the binary
AOTI_RUNTIME_CHECK(
dladdr(__func__, &dl_info), "Can't find shared library name");
int fd = open(dl_info.dli_fname, O_RDONLY);
AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened");
auto fsize = lseek(fd, 0, SEEK_END);
auto weights_size =
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[0];
auto magic_number =
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[1];
auto weights_offset = fsize - weights_size;
AOTI_RUNTIME_CHECK(
(weights_offset & 0x3fff) == 0,
"weights_offset must be aligned to 16K boundary");
auto ptr = mmap(
NULL,
weights_size,
PROT_READ | PROT_WRITE,
MAP_PRIVATE,
fd,
weights_offset);
close(fd);
AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed");
self_mmap = static_cast<uint8_t*>(ptr);
AOTI_RUNTIME_CHECK(
reinterpret_cast<uint64_t*>(
self_mmap + weights_size - sizeof(uint64_t))[0] == magic_number,
"Weigths data seems corrupt");
}
return self_mmap + bytes_read;
#else // !USE_CUDA&& !USE_MMAP_SELF
#else
// get pointer to constant which is packed in model during compile time.
AOTI_RUNTIME_CHECK(!skip_copy, "pure cpu mode doesn't support skip copy");
return const_cast<uint8_t*>(_binary_constants_bin_start) + bytes_read;
return _get_constants_start() + bytes_read;
#endif // USE_CUDA
}
@ -470,6 +435,45 @@ class AOTInductorModelBase {
}
protected:
uint8_t* _get_constants_start() {
#ifndef USE_MMAP_SELF
return const_cast<uint8_t*>(_binary_constants_bin_start);
#else
if (self_mmap) {
return self_mmap;
}
Dl_info dl_info;
// get pointer to constant which are appended to the binary
AOTI_RUNTIME_CHECK(
dladdr(__func__, &dl_info), "Can't find shared library name");
int fd = open(dl_info.dli_fname, O_RDONLY);
AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened");
auto fsize = lseek(fd, 0, SEEK_END);
auto weights_size =
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[0];
auto magic_number =
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[1];
auto weights_offset = fsize - weights_size;
AOTI_RUNTIME_CHECK(
(weights_offset & 0x3fff) == 0,
"weights_offset must be aligned to 16K boundary");
auto ptr = mmap(
NULL,
weights_size,
PROT_READ | PROT_WRITE,
MAP_PRIVATE,
fd,
weights_offset);
close(fd);
AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed");
self_mmap = static_cast<uint8_t*>(ptr);
AOTI_RUNTIME_CHECK(
reinterpret_cast<uint64_t*>(
self_mmap + weights_size - sizeof(uint64_t))[0] == magic_number,
"Weigths data seems corrupt");
return self_mmap;
#endif
}
struct ParamInfo {
const char* name = nullptr;
};