mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] Enable mmaped weights when CUDA is used (#124346)
By refactoring the logic that returns the pointer to the start of the constants into a `_get_constants_start()` method and calling it from both the CUDA and CPU readers. It has no runtime impact, but export time is down from 10m to 3m if mmaped weights are used on AWS p4d.24xlarge. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124346 Approved by: https://github.com/mikekgfb, https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
87f44d70b1
commit
1ba85b34dd
@ -1827,10 +1827,8 @@ class AotCodeCompiler:
|
||||
if name not in graph.folded_constants
|
||||
)
|
||||
# TODO: Fix mmap weights with cuda
|
||||
use_mmap_weights = (
|
||||
not cuda and not config.is_fbcode() and consts_size > 2_000_000_000
|
||||
)
|
||||
if config.aot_inductor.force_mmap_weights and not cuda:
|
||||
use_mmap_weights = not config.is_fbcode() and consts_size > 2_000_000_000
|
||||
if config.aot_inductor.force_mmap_weights:
|
||||
use_mmap_weights = True
|
||||
compile_cmd = cpp_compile_command(
|
||||
input=input_path,
|
||||
|
@ -268,51 +268,16 @@ class AOTInductorModelBase {
|
||||
if (!skip_copy) {
|
||||
AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy(
|
||||
internal_ptr,
|
||||
_binary_constants_bin_start + bytes_read,
|
||||
_get_constants_start() + bytes_read,
|
||||
data_size,
|
||||
cudaMemcpyHostToDevice));
|
||||
}
|
||||
return internal_ptr;
|
||||
#elif USE_MMAP_SELF
|
||||
// get pointer to constant which is packed in model during compile time.
|
||||
AOTI_RUNTIME_CHECK(!skip_copy, "pure cpu mode doesn't support skip copy");
|
||||
if (!self_mmap) {
|
||||
Dl_info dl_info;
|
||||
// get pointer to constant which are appended to the binary
|
||||
AOTI_RUNTIME_CHECK(
|
||||
dladdr(__func__, &dl_info), "Can't find shared library name");
|
||||
int fd = open(dl_info.dli_fname, O_RDONLY);
|
||||
AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened");
|
||||
auto fsize = lseek(fd, 0, SEEK_END);
|
||||
auto weights_size =
|
||||
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[0];
|
||||
auto magic_number =
|
||||
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[1];
|
||||
auto weights_offset = fsize - weights_size;
|
||||
AOTI_RUNTIME_CHECK(
|
||||
(weights_offset & 0x3fff) == 0,
|
||||
"weights_offset must be aligned to 16K boundary");
|
||||
auto ptr = mmap(
|
||||
NULL,
|
||||
weights_size,
|
||||
PROT_READ | PROT_WRITE,
|
||||
MAP_PRIVATE,
|
||||
fd,
|
||||
weights_offset);
|
||||
close(fd);
|
||||
AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed");
|
||||
self_mmap = static_cast<uint8_t*>(ptr);
|
||||
AOTI_RUNTIME_CHECK(
|
||||
reinterpret_cast<uint64_t*>(
|
||||
self_mmap + weights_size - sizeof(uint64_t))[0] == magic_number,
|
||||
"Weigths data seems corrupt");
|
||||
}
|
||||
return self_mmap + bytes_read;
|
||||
|
||||
#else // !USE_CUDA&& !USE_MMAP_SELF
|
||||
#else
|
||||
// get pointer to constant which is packed in model during compile time.
|
||||
AOTI_RUNTIME_CHECK(!skip_copy, "pure cpu mode doesn't support skip copy");
|
||||
return const_cast<uint8_t*>(_binary_constants_bin_start) + bytes_read;
|
||||
return _get_constants_start() + bytes_read;
|
||||
#endif // USE_CUDA
|
||||
}
|
||||
|
||||
@ -470,6 +435,45 @@ class AOTInductorModelBase {
|
||||
}
|
||||
|
||||
protected:
|
||||
uint8_t* _get_constants_start() {
|
||||
#ifndef USE_MMAP_SELF
|
||||
return const_cast<uint8_t*>(_binary_constants_bin_start);
|
||||
#else
|
||||
if (self_mmap) {
|
||||
return self_mmap;
|
||||
}
|
||||
Dl_info dl_info;
|
||||
// get pointer to constant which are appended to the binary
|
||||
AOTI_RUNTIME_CHECK(
|
||||
dladdr(__func__, &dl_info), "Can't find shared library name");
|
||||
int fd = open(dl_info.dli_fname, O_RDONLY);
|
||||
AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened");
|
||||
auto fsize = lseek(fd, 0, SEEK_END);
|
||||
auto weights_size =
|
||||
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[0];
|
||||
auto magic_number =
|
||||
reinterpret_cast<const uint64_t*>(_binary_constants_bin_start)[1];
|
||||
auto weights_offset = fsize - weights_size;
|
||||
AOTI_RUNTIME_CHECK(
|
||||
(weights_offset & 0x3fff) == 0,
|
||||
"weights_offset must be aligned to 16K boundary");
|
||||
auto ptr = mmap(
|
||||
NULL,
|
||||
weights_size,
|
||||
PROT_READ | PROT_WRITE,
|
||||
MAP_PRIVATE,
|
||||
fd,
|
||||
weights_offset);
|
||||
close(fd);
|
||||
AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed");
|
||||
self_mmap = static_cast<uint8_t*>(ptr);
|
||||
AOTI_RUNTIME_CHECK(
|
||||
reinterpret_cast<uint64_t*>(
|
||||
self_mmap + weights_size - sizeof(uint64_t))[0] == magic_number,
|
||||
"Weigths data seems corrupt");
|
||||
return self_mmap;
|
||||
#endif
|
||||
}
|
||||
// Describes a single parameter; the only field visible here is its name,
// a C-string pointer that defaults to nullptr.
struct ParamInfo {
|
||||
const char* name = nullptr;
|
||||
};
|
||||
|
Reference in New Issue
Block a user