Add config.save.use_pinned_memory_for_d2h to serialization config (#143342)
This was benchmarked with two separate scripts on my A100:

(A) Save the `state_dict` of a llama3-style model on CUDA to disk with `torch.save`
(B) Save a `ModuleList` of 10 `nn.Linear(10000, 10000)` modules on CUDA to disk with `torch.save`

Timings are an average of 5 runs; the benchmark scripts and results are linked below.

Under both scenarios, we see a **~2x speedup in `torch.save` time with `compute_crc32=False` and `use_pinned_memory_for_d2h=True`** compared to the baseline of the current defaults (`compute_crc32=True` and `use_pinned_memory_for_d2h=False`).

(A) Save the `state_dict` of a llama3-style model on CUDA to disk with `torch.save` [[script](https://gist.github.com/mikaylagawarecki/d3a86ea1bb08045d1a839976808d7432)] [[results](https://gist.github.com/mikaylagawarecki/f61a4714e5cff703146a1fcb7e0c755c)]

| | `use_pinned_memory_for_d2h=False` (default) | `use_pinned_memory_for_d2h=True` |
|-|-|-|
| `compute_crc32=True` (default) | 28.54s | 20.76s |
| `compute_crc32=False` | 22.57s | **14.51s** |

(B) Save a `ModuleList` of 10 `nn.Linear(10000, 10000)` modules on CUDA to disk with `torch.save` [[script](https://gist.github.com/mikaylagawarecki/ecbc505436bdd4b5190ef1b3430c12b6)] [[results](https://gist.github.com/mikaylagawarecki/4e686bcf030b57de8c3ca74d8f5a88f7)]

| | `use_pinned_memory_for_d2h=False` (default) | `use_pinned_memory_for_d2h=True` |
|-|-|-|
| `compute_crc32=True` (default) | 8.38s | 5.53s |
| `compute_crc32=False` | 6.94s | **3.99s** |

Trace of (A) with `use_pinned_memory_for_d2h=True`, `compute_crc32=False`:

<img width="1745" alt="Screenshot 2024-12-16 at 7 32 33 PM" src="https://github.com/user-attachments/assets/80b87a8c-5a70-4eb9-ad66-7abc4aa7cc25" />

Baseline trace of (A) with `use_pinned_memory_for_d2h=False`, `compute_crc32=True`:

<img width="1799" alt="Screenshot 2024-12-16 at 7 38 20 PM" src="https://github.com/user-attachments/assets/13fa12d1-8f5f-424c-adc4-275b67012927" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143342
Approved by: https://github.com/albanD
ghstack dependencies: #143324
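For reference, a minimal usage sketch of the fast configuration benchmarked above (assuming a CUDA build; the linear layer and filename are stand-ins, not part of the PR):

```python
import torch
from torch.utils.serialization import config

# The fast configuration from the tables above: skip the zip checksum and
# stage device-to-host copies through pinned memory.
config.save.compute_crc32 = False
config.save.use_pinned_memory_for_d2h = True

model = torch.nn.Linear(1024, 1024, device="cuda")  # stand-in for the benchmarked models
torch.save(model.state_dict(), "checkpoint.pt")
```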
Committed by: PyTorch MergeBot
Parent: 3f63b742e6
Commit: 8e483654cb
```diff
@@ -503,6 +503,8 @@ Config
 * ``compute_crc32``: whether to compute and write the zip file checksum (Default: ``True``).
   See :func:`~torch.serialization.set_crc32_options`.
+* ``use_pinned_memory_for_d2h``: for storages that are on an accelerator when passed to ``torch.save``, whether to
+  move the storage to pinned memory or pageable memory on CPU within ``torch.save``. (Default: ``False``, i.e. pageable)
 
 ``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``.
```
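As background for the pinned-vs-pageable wording in the bullet above, a small illustration (not part of this diff, and assuming an accelerator is present, since pinning requires one):

```python
import torch

pageable = torch.empty(1024)                 # ordinary pageable CPU memory
pinned = torch.empty(1024, pin_memory=True)  # page-locked (pinned) CPU memory

print(pageable.is_pinned())  # False
print(pinned.is_pinned())    # True: eligible for faster, async device<->host copies
```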
```diff
@@ -22,8 +22,10 @@ from copy import deepcopy
 from dataclasses import dataclass
 from itertools import product
 from pathlib import Path
+from unittest.mock import patch
 
 import torch
+from torch.utils.serialization import config as serialization_config
 from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
 from torch._utils import _rebuild_tensor
 from torch._utils_internal import get_file_path_2
```
```diff
@@ -4477,6 +4479,45 @@ class TestSerialization(TestCase, SerializationMixin):
         loaded_sd = torch.load(f, weights_only=weights_only)
         self.assertEqual(sd_save, loaded_sd)
 
+    @unittest.skipIf(not torch.accelerator.is_available() or torch.accelerator.current_accelerator().type == 'mps',
+                     "accelerator not available, on mps pin memory allocator is not registered")
+    def test_use_pinned_memory_for_d2h(self):
+        device = torch.accelerator.current_accelerator().type
+
+        def patched_write_record(self, filename, data, nbytes):
+            if isinstance(data, (torch.TypedStorage, torch.UntypedStorage)):
+                if not data.is_pinned(device=device):
+                    raise RuntimeError("Expected storage to be in pinned memory")
+            return None
+
+        sd = torch.nn.Linear(3, 5, device=device).state_dict()
+
+        # Test that accelerator storages actually get moved to pinned memory on CPU
+        with patch('torch._C.PyTorchFileWriter.write_record', patched_write_record):
+            with tempfile.NamedTemporaryFile() as f:
+                with self.assertRaisesRegex(RuntimeError, "Expected storage to be in pinned memory"):
+                    torch.save(sd, f)
+
+            with tempfile.NamedTemporaryFile() as f:
+                pinned_before = serialization_config.save.use_pinned_memory_for_d2h
+                try:
+                    serialization_config.save.use_pinned_memory_for_d2h = True
+                    torch.save(sd, f)
+                finally:
+                    serialization_config.save.use_pinned_memory_for_d2h = pinned_before
+
+        # Test correctness
+        with tempfile.NamedTemporaryFile() as f:
+            pinned_before = serialization_config.save.use_pinned_memory_for_d2h
+            try:
+                serialization_config.save.use_pinned_memory_for_d2h = True
+                torch.save(sd, f)
+                f.seek(0)
+                sd_loaded = torch.load(f)
+                self.assertEqual(sd_loaded, sd)
+            finally:
+                serialization_config.save.use_pinned_memory_for_d2h = pinned_before
+
 
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
```
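The test above toggles the flag with a save-and-restore try/finally. A hypothetical convenience wrapper (`pinned_d2h_saves` is not part of this PR) that packages the same pattern as a context manager could look like:

```python
import contextlib

from torch.utils.serialization import config


@contextlib.contextmanager
def pinned_d2h_saves(enabled=True):
    # Hypothetical helper, not part of the PR: temporarily set
    # save.use_pinned_memory_for_d2h and restore the previous value on exit.
    previous = config.save.use_pinned_memory_for_d2h
    config.save.use_pinned_memory_for_d2h = enabled
    try:
        yield
    finally:
        config.save.use_pinned_memory_for_d2h = previous
```

With such a helper, each toggling block in the test body would collapse to `with pinned_d2h_saves(): torch.save(sd, f)`.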
```diff
@@ -1206,7 +1206,22 @@ def _save(
             # this means that to get tensors serialized, you need to implement
             # .cpu() on the underlying Storage
             if storage.device.type != "cpu":
-                storage = storage.cpu()
+                from torch.utils.serialization import config
+
+                if (
+                    config.save.use_pinned_memory_for_d2h
+                    and torch.accelerator.is_available()
+                    and torch.accelerator.current_accelerator().type
+                    == storage.device.type
+                ):
+                    new_storage = torch.empty(
+                        num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
+                    ).untyped_storage()
+                    new_storage.copy_(storage)
+                    torch.accelerator.current_stream(storage.device.index).synchronize()
+                    storage = new_storage
+                else:
+                    storage = storage.cpu()
             # Now that it is on the CPU we can directly copy it into the zip file
             zip_file.write_record(name, storage, num_bytes)
```
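The staging logic above is the core of the feature: instead of `storage.cpu()`, which lands in pageable memory, the bytes are copied into a pinned CPU buffer, and the current accelerator stream is synchronized because the copy into pinned memory can run asynchronously. A standalone sketch of the same pattern (assuming a CUDA device and a PyTorch recent enough to have `torch.accelerator`):

```python
import torch

src = torch.randn(1 << 20, device="cuda")
storage = src.untyped_storage()
num_bytes = storage.nbytes()

# Stage the device storage through a pinned CPU buffer; the D2H copy may be
# asynchronous on the current stream, so synchronize before using the bytes.
staging = torch.empty(
    num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
).untyped_storage()
staging.copy_(storage)
torch.accelerator.current_stream(storage.device.index).synchronize()
# `staging` now safely holds the serialized bytes on the host.
```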
```diff
@@ -17,6 +17,7 @@ class load:
 
 class save:
     compute_crc32: bool = True
+    use_pinned_memory_for_d2h: bool = False
 
 
 _install_config_module(sys.modules[__name__])
```
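Since `_install_config_module` turns these classes into a live config module, the new field is read and written as a plain attribute; a quick sketch:

```python
from torch.utils.serialization import config

print(config.save.compute_crc32)              # True (default)
print(config.save.use_pinned_memory_for_d2h)  # False (default, i.e. pageable)

config.save.use_pinned_memory_for_d2h = True  # opt in to pinned-memory staging
```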