Make record/storage alignment in torch.save configurable (#147788)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788
Approved by: https://github.com/albanD
ghstack dependencies: #147786, #147787
This commit is contained in:
Mikayla Gawarecki
2025-03-06 08:50:55 +00:00
committed by PyTorch MergeBot
parent 209977e6e5
commit be0ceee1c3
8 changed files with 132 additions and 38 deletions

View File

@ -4653,6 +4653,34 @@ class TestSerialization(TestCase, SerializationMixin):
self.assertTrue(opened_zipfile.has_record(".format_version"))
self.assertEqual(opened_zipfile.get_record(".format_version"), b'1')
def test_storage_alignment(self):
    """Saved-storage offsets honor serialization_config.save.storage_alignment.

    First saves with the default alignment and checks the known byte offsets
    of the two Linear storages inside the zip archive, then raises the
    alignment to 4096 and checks the offsets move to 4096-aligned positions
    while the checkpoint still round-trips correctly.
    """
    sd = torch.nn.Linear(10, 10).state_dict()

    # Default alignment: golden offsets for a 10x10 Linear checkpoint.
    with tempfile.NamedTemporaryFile() as f:
        torch.save(sd, f)
        f.seek(0)
        # Load under FakeTensorMode so storages keep their archive offset
        # (_checkpoint_offset) instead of materializing real data.
        with FakeTensorMode():
            fake_sd = torch.load(f)
        self.assertEqual(fake_sd['weight'].untyped_storage()._checkpoint_offset, 832)
        self.assertEqual(fake_sd['bias'].untyped_storage()._checkpoint_offset, 1344)

    # Custom alignment: offsets become multiples of 4096; restore the
    # global config afterwards so other tests are unaffected.
    saved_alignment = serialization_config.save.storage_alignment
    try:
        serialization_config.save.storage_alignment = 4096
        with tempfile.NamedTemporaryFile() as f:
            torch.save(sd, f)
            f.seek(0)
            with FakeTensorMode():
                fake_sd = torch.load(f)
            self.assertEqual(fake_sd['weight'].untyped_storage()._checkpoint_offset, 20480)
            self.assertEqual(fake_sd['bias'].untyped_storage()._checkpoint_offset, 24576)
            # A normal (non-fake) load must still reproduce the state dict.
            f.seek(0)
            self.assertEqual(torch.load(f), sd)
    finally:
        serialization_config.save.storage_alignment = saved_alignment
@parametrize('path_type', (str, Path))
@unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows")
def test_mmap_load_offset_calculation(self, path_type):