Reduce random reads for offset metadata when calling torch.load under FakeTensorMode (#157931)

We already test the `_get_offset` functionality via the TORCH_SERIALIZATION_DEBUG flag that is set in CI, so I didn't add more testing specifically for FakeTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157931
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2025-07-17 19:06:56 +00:00
committed by PyTorch MergeBot
parent af6624023e
commit 41b2c4d119

View File

@ -1988,7 +1988,7 @@ def _load(
# for a given key.
offsets[name] = storage_offset
# Increment current_offset of offset where next zipfile header starts
# Increment current_offset to offset where next zipfile header starts
current_offset = storage_offset + numel
# add size of data descriptor after payload
if numel > 0:
@ -2004,7 +2004,10 @@ def _load(
if torch._guards.detect_fake_mode(None) is not None:
nbytes = numel * torch._utils._element_size(dtype)
storage = torch.UntypedStorage(nbytes, device="meta")
storage._checkpoint_offset = zip_file.get_record_offset(name)
if can_calculate_storage_offsets:
storage._checkpoint_offset = _get_offset(key, name, numel)
else:
storage._checkpoint_offset = zip_file.get_record_offset(name)
elif _serialization_tls.skip_data:
nbytes = numel * torch._utils._element_size(dtype)
storage = torch.UntypedStorage(nbytes)