Add config.save.use_pinned_memory_for_d2h to serialization config (#143342)

This was benchmarked with two separate scripts on my A100:
(A) Save the state_dict of a llama3-style model on CUDA to disk with ``torch.save``
(B) Save a `ModuleList` of 10 `nn.Linear(10000, 10000)` modules on CUDA to disk with `torch.save`
Timings are an average of 5 runs; the benchmark scripts and results are linked below.

Under both scenarios, we see a **~2x speedup in ``torch.save`` time with ``compute_crc32=False`` and ``use_pinned_memory_for_d2h=True``** compared to the baseline of the current defaults (``compute_crc32=True`` and ``use_pinned_memory_for_d2h=False``).
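
For example, enabling this combination looks like the following (a minimal sketch, assuming a CUDA device; both knobs are set via `torch.utils.serialization.config`, as shown in the diffs below):

```python
import torch
from torch.utils.serialization import config

# Skip the zip checksum and stage accelerator storages through
# pinned CPU memory during torch.save.
config.save.compute_crc32 = False
config.save.use_pinned_memory_for_d2h = True

sd = torch.nn.Linear(8, 8, device="cuda").state_dict()
torch.save(sd, "checkpoint.pt")
```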

(A) Save the state_dict of a llama3-style model on CUDA to disk with ``torch.save`` [[script](https://gist.github.com/mikaylagawarecki/d3a86ea1bb08045d1a839976808d7432)][[results](https://gist.github.com/mikaylagawarecki/f61a4714e5cff703146a1fcb7e0c755c)]

| | `use_pinned_memory_for_d2h=False` (Default) | `use_pinned_memory_for_d2h=True` |
|-|-|-|
| `compute_crc32=True` (Default) | 28.54s | 20.76s |
| `compute_crc32=False` | 22.57s | **14.51s** |

(B) Save a `ModuleList` of 10 `nn.Linear(10000, 10000)` modules on CUDA to disk with `torch.save` [[script](https://gist.github.com/mikaylagawarecki/ecbc505436bdd4b5190ef1b3430c12b6)][[results](https://gist.github.com/mikaylagawarecki/4e686bcf030b57de8c3ca74d8f5a88f7)]

| | `use_pinned_memory_for_d2h=False` (Default) | `use_pinned_memory_for_d2h=True` |
|-|-|-|
| `compute_crc32=True` (Default) | 8.38s | 5.53s |
| `compute_crc32=False` | 6.94s | **3.99s** |
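
For reference, a rough sketch of what benchmark (B) measures (the exact scripts are in the gists linked above; the checkpoint path here is illustrative):

```python
import time
import torch
from torch.utils.serialization import config

config.save.compute_crc32 = False
config.save.use_pinned_memory_for_d2h = True

model = torch.nn.ModuleList(
    [torch.nn.Linear(10000, 10000) for _ in range(10)]
).cuda()
sd = model.state_dict()

# Average torch.save time over 5 runs.
times = []
for _ in range(5):
    torch.cuda.synchronize()
    start = time.perf_counter()
    torch.save(sd, "/tmp/ckpt.pt")  # illustrative path
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
print(f"avg torch.save time: {sum(times) / len(times):.2f}s")
```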

Trace of (A) with `use_pinned_memory_for_d2h=True`, `compute_crc32=False`
<img width="1745" alt="Screenshot 2024-12-16 at 7 32 33 PM" src="https://github.com/user-attachments/assets/80b87a8c-5a70-4eb9-ad66-7abc4aa7cc25" />

Baseline trace of (A) with `use_pinned_memory_for_d2h=False`, `compute_crc32=True`
<img width="1799" alt="Screenshot 2024-12-16 at 7 38 20 PM" src="https://github.com/user-attachments/assets/13fa12d1-8f5f-424c-adc4-275b67012927" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143342
Approved by: https://github.com/albanD
ghstack dependencies: #143324
Author: Mikayla Gawarecki
Date: 2024-12-20 09:57:24 -08:00
Committed by: PyTorch MergeBot
Parent: 3f63b742e6
Commit: 8e483654cb
4 changed files with 60 additions and 1 deletion


@@ -503,6 +503,8 @@ Config
 * ``compute_crc32``: whether to compute and write the zip file checksum (Default: ``True``).
   See :func:`~torch.serialization.set_crc32_options`.
+* ``use_pinned_memory_for_d2h``: for storages that are on an accelerator when passed to ``torch.save``, whether to
+  move the storage to pinned memory or pageable memory on CPU within ``torch.save``. (Default: ``False``, i.e. pageable)
 
 ``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``.


@@ -22,8 +22,10 @@ from copy import deepcopy
 from dataclasses import dataclass
 from itertools import product
 from pathlib import Path
+from unittest.mock import patch
 
 import torch
+from torch.utils.serialization import config as serialization_config
 from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
 from torch._utils import _rebuild_tensor
 from torch._utils_internal import get_file_path_2
@@ -4477,6 +4479,45 @@ class TestSerialization(TestCase, SerializationMixin):
             loaded_sd = torch.load(f, weights_only=weights_only)
             self.assertEqual(sd_save, loaded_sd)
 
+    @unittest.skipIf(not torch.accelerator.is_available() or torch.accelerator.current_accelerator().type == 'mps',
+                     "accelerator not available; on mps the pin memory allocator is not registered")
+    def test_use_pinned_memory_for_d2h(self):
+        device = torch.accelerator.current_accelerator().type
+
+        def patched_write_record(self, filename, data, nbytes):
+            if isinstance(data, (torch.TypedStorage, torch.UntypedStorage)):
+                if not data.is_pinned(device=device):
+                    raise RuntimeError("Expected storage to be in pinned memory")
+            return None
+
+        sd = torch.nn.Linear(3, 5, device=device).state_dict()
+        # Test that accelerator storages actually get moved to pinned memory on CPU
+        with patch('torch._C.PyTorchFileWriter.write_record', patched_write_record):
+            with tempfile.NamedTemporaryFile() as f:
+                with self.assertRaisesRegex(RuntimeError, "Expected storage to be in pinned memory"):
+                    torch.save(sd, f)
+            with tempfile.NamedTemporaryFile() as f:
+                pinned_before = serialization_config.save.use_pinned_memory_for_d2h
+                try:
+                    serialization_config.save.use_pinned_memory_for_d2h = True
+                    torch.save(sd, f)
+                finally:
+                    serialization_config.save.use_pinned_memory_for_d2h = pinned_before
+
+        # Test correctness
+        with tempfile.NamedTemporaryFile() as f:
+            pinned_before = serialization_config.save.use_pinned_memory_for_d2h
+            try:
+                serialization_config.save.use_pinned_memory_for_d2h = True
+                torch.save(sd, f)
+                f.seek(0)
+                sd_loaded = torch.load(f)
+                self.assertEqual(sd_loaded, sd)
+            finally:
+                serialization_config.save.use_pinned_memory_for_d2h = pinned_before
+
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):


@@ -1206,7 +1206,22 @@ def _save(
     # this means that to get tensors serialized, you need to implement
     # .cpu() on the underlying Storage
     if storage.device.type != "cpu":
-        storage = storage.cpu()
+        from torch.utils.serialization import config
+
+        if (
+            config.save.use_pinned_memory_for_d2h
+            and torch.accelerator.is_available()
+            and torch.accelerator.current_accelerator().type
+            == storage.device.type
+        ):
+            new_storage = torch.empty(
+                num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
+            ).untyped_storage()
+            new_storage.copy_(storage)
+            torch.accelerator.current_stream(storage.device.index).synchronize()
+            storage = new_storage
+        else:
+            storage = storage.cpu()
 
     # Now that it is on the CPU we can directly copy it into the zip file
     zip_file.write_record(name, storage, num_bytes)
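
For context on why the pinned-memory path is faster: device-to-host copies into pageable CPU memory are staged through an internal pinned buffer, while copies into pinned memory can be DMA'd directly (and may be asynchronous, which is presumably why the new code synchronizes the current stream after the copy). A standalone illustration of the difference, not from this PR, assuming a CUDA device:

```python
import time
import torch

# Compare a device-to-host copy into pageable vs. pinned CPU memory.
src = torch.randn(256 * 1024 * 1024 // 4, device="cuda")  # ~256 MB of float32

for pin in (False, True):
    dst = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=pin)
    torch.cuda.synchronize()
    start = time.perf_counter()
    dst.copy_(src)
    torch.cuda.synchronize()
    print(f"pin_memory={pin}: {time.perf_counter() - start:.4f}s")
```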


@@ -17,6 +17,7 @@ class load:
 
 class save:
     compute_crc32: bool = True
+    use_pinned_memory_for_d2h: bool = False
 
 _install_config_module(sys.modules[__name__])