mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Make record/storage alignment in torch.save configurable (#147788)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788 Approved by: https://github.com/albanD ghstack dependencies: #147786, #147787
This commit is contained in:
committed by
PyTorch MergeBot
parent
209977e6e5
commit
be0ceee1c3
@ -4653,6 +4653,34 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
self.assertTrue(opened_zipfile.has_record(".format_version"))
|
||||
self.assertEqual(opened_zipfile.get_record(".format_version"), b'1')
|
||||
|
||||
def test_storage_alignment(self):
|
||||
sd = torch.nn.Linear(10, 10).state_dict()
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(sd, f)
|
||||
f.seek(0)
|
||||
with FakeTensorMode():
|
||||
sd_fake = torch.load(f)
|
||||
self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 832)
|
||||
self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 1344)
|
||||
|
||||
storage_alignment_before = serialization_config.save.storage_alignment
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
try:
|
||||
serialization_config.save.storage_alignment = 4096
|
||||
torch.save(sd, f)
|
||||
f.seek(0)
|
||||
with FakeTensorMode():
|
||||
sd_fake = torch.load(f)
|
||||
self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 20480)
|
||||
self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 24576)
|
||||
f.seek(0)
|
||||
sd_loaded = torch.load(f)
|
||||
self.assertEqual(sd_loaded, sd)
|
||||
finally:
|
||||
serialization_config.save.storage_alignment = storage_alignment_before
|
||||
|
||||
|
||||
@parametrize('path_type', (str, Path))
|
||||
@unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows")
|
||||
def test_mmap_load_offset_calculation(self, path_type):
|
||||
|
Reference in New Issue
Block a user