mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
As title
On Windows, we cannot append the weights to the end of the .dll: the Windows .dll loader will reject it as an invalid .dll file. So we store the weight blob as a separate file.
1. We add the following APIs, which let the caller query the size of the weight blob and pass in a pointer to the weight blob:
```cpp
AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize(
AOTInductorModelContainerHandle container_handle,
uint64_t* ret_size);
// Load weights from a single blob in weight_blob_ptr
AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob(
AOTInductorModelContainerHandle container_handle,
const uint8_t* weight_blob_ptr);
```
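A minimal sketch of how a caller might combine the two functions: query the expected size, check the buffer against it, then hand over the pointer. The function signatures are the ones above; everything else is an assumption made so the snippet compiles on its own. `StubContainer`, the stub function bodies, and the helper `load_constants_from_blob` are hypothetical stand-ins, not the runtime's implementation, and the size check is a caller-side precaution rather than something this PR states is required.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Stand-ins so the sketch is self-contained. In a real build the types and
// the two functions come from the AOTInductor runtime, not from this file.
using AOTIRuntimeError = int32_t;
struct StubContainer {
  std::vector<uint8_t> constants;  // pretend storage for the model constants
};
using AOTInductorModelContainerHandle = StubContainer*;

// Stub: reports how many bytes of constants the container expects.
AOTIRuntimeError AOTInductorModelContainerGetConstantsBlobSize(
    AOTInductorModelContainerHandle container_handle,
    uint64_t* ret_size) {
  if (container_handle == nullptr || ret_size == nullptr) return 1;
  *ret_size = container_handle->constants.size();
  return 0;
}

// Stub: copies the blob into the container. The copy only makes the sketch
// observable; how the real runtime consumes the pointer is not shown here.
AOTIRuntimeError AOTInductorModelUpdateConstantsFromBlob(
    AOTInductorModelContainerHandle container_handle,
    const uint8_t* weight_blob_ptr) {
  if (container_handle == nullptr || weight_blob_ptr == nullptr) return 1;
  std::memcpy(container_handle->constants.data(), weight_blob_ptr,
              container_handle->constants.size());
  return 0;
}

// Hypothetical helper: refuse a buffer that is smaller than the container
// expects, then load the constants from it.
bool load_constants_from_blob(AOTInductorModelContainerHandle handle,
                              const uint8_t* blob,
                              uint64_t blob_size) {
  assert(handle != nullptr);
  uint64_t expected = 0;
  if (AOTInductorModelContainerGetConstantsBlobSize(handle, &expected) != 0) {
    return false;
  }
  if (blob_size < expected) return false;  // blob is truncated or mismatched
  return AOTInductorModelUpdateConstantsFromBlob(handle, blob) == 0;
}
```

Since `AOTInductorModelUpdateConstantsFromBlob` takes only a pointer, the caller is the only place where a size mismatch can be caught before the load, which is why the sketch queries the size first.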
2. We also add a method in ModelContainerRunner to load the weights:
If the runner sees that there is a `.blob` file in the package, it will mmap the `.blob` file and use its contents to load the constants.
3. We also add the `USE_MMAP_EXTERNAL` macro. When this macro is defined, the model expects its weights to be loaded from an external mmap'd blob.
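The mmap step in item 2 could look roughly like the sketch below. It is a POSIX-only illustration, not the code in ModelContainerRunner: the class name `MappedBlob` and its members are hypothetical, and on Windows the same idea would go through `CreateFileMappingW` and `MapViewOfFile` instead of `mmap`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

// Hypothetical RAII wrapper: maps a weights file read-only and unmaps it
// when the object is destroyed.
class MappedBlob {
 public:
  explicit MappedBlob(const std::string& path) {
    int fd = ::open(path.c_str(), O_RDONLY);
    if (fd < 0) return;  // file missing or unreadable: ok() stays false
    struct stat st;
    if (::fstat(fd, &st) == 0 && st.st_size > 0) {
      void* p = ::mmap(nullptr, static_cast<size_t>(st.st_size), PROT_READ,
                       MAP_PRIVATE, fd, 0);
      if (p != MAP_FAILED) {
        data_ = static_cast<const uint8_t*>(p);
        size_ = static_cast<size_t>(st.st_size);
      }
    }
    ::close(fd);  // the mapping remains valid after the descriptor is closed
  }

  ~MappedBlob() {
    if (data_ != nullptr) ::munmap(const_cast<uint8_t*>(data_), size_);
  }

  MappedBlob(const MappedBlob&) = delete;
  MappedBlob& operator=(const MappedBlob&) = delete;

  bool ok() const { return data_ != nullptr; }
  const uint8_t* data() const { return data_; }
  size_t size() const { return size_; }

  // Bounds-checked byte access (checked only in debug builds).
  uint8_t at(size_t i) const {
    assert(i < size_);
    return data_[i];
  }

 private:
  const uint8_t* data_ = nullptr;
  size_t size_ = 0;
};
```

A caller would construct `MappedBlob` with the path of the `.blob` file, check `ok()`, and pass `data()` as the `weight_blob_ptr` argument. The mapping object has to stay alive for as long as the mapped bytes may still be read.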
Test Plan:
```
buck run @mode/dev-nosan caffe2/test/inductor:test_aot_inductor -- -r test_large_mmaped_weights_on_disk
```
Also tested for Windows cross-compilation with 6542566585/demo/main_voxtral.cpp
```
Loaded model.dll
audio_encoder loaded
C:\Users\shangdiy\source\repos\torchnative\demo\token_embedding\data\aotinductor\model\model.wrapper.so
Loaded model.dll
token_embedding loaded
C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper.so
Loaded model.dll
Loading weights from C:\Users\shangdiy\source\repos\torchnative\demo\text_decoder\data\aotinductor\model\model.wrapper_weights.blob
text_decoder loaded
Load latency (ms):
audio_encoder: 1011.234
archive extraction: 0.000
.so loading: 1011.197
token_embedding: 525.773
archive extraction: 0.000
.so loading: 525.704
text_decoder: 3324.130
archive extraction: 0.000
.so loading: 3323.979
Run latency (ms):
audio_encoder: 285.958
audio_encoder output: dtype=bfloat16, shape=[1, 1125, 3072], numel=3456000
token_embedding: 6.676
token_embedding output: dtype=bfloat16, shape=[1, 1138, 3072], numel=3495936
text_decoder: 576.519
text_decoder output: dtype=bfloat16, shape=[1, 1138, 131072], numel=149159936
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164526
Approved by: https://github.com/desertfire