Reduce random reads for offset metadata when calling torch.load under FakeTensorMode (#157931)
We already test the `_get_offset` functionality with the TORCH_SERIALIZATION_DEBUG flag that is set in CI, so I didn't add more testing specifically for FakeTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157931
Approved by: https://github.com/albanD
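For context, the branch changed below is only taken when torch.load runs inside an active FakeTensorMode. A minimal sketch of that usage, assuming an ordinary checkpoint saved beforehand (the file name and saved contents are illustrative, not part of this commit):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Save a small checkpoint normally (illustrative; any existing checkpoint works).
torch.save({"weight": torch.randn(4, 4)}, "checkpoint.pt")

# Load it under FakeTensorMode: storages are created on the "meta" device and
# only their byte offsets inside the zip container are recorded, so no tensor
# payload data is actually read from disk.
with FakeTensorMode():
    state_dict = torch.load("checkpoint.pt")

print(type(state_dict["weight"]))  # a fake tensor carrying no real payload data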
committed by PyTorch MergeBot
parent af6624023e
commit 41b2c4d119
@@ -1988,7 +1988,7 @@ def _load(
             # for a given key.
             offsets[name] = storage_offset
 
-            # Increment current_offset of offset where next zipfile header starts
+            # Increment current_offset to offset where next zipfile header starts
             current_offset = storage_offset + numel
             # add size of data descriptor after payload
             if numel > 0:
@@ -2004,7 +2004,10 @@ def _load(
         if torch._guards.detect_fake_mode(None) is not None:
             nbytes = numel * torch._utils._element_size(dtype)
             storage = torch.UntypedStorage(nbytes, device="meta")
-            storage._checkpoint_offset = zip_file.get_record_offset(name)
+            if can_calculate_storage_offsets:
+                storage._checkpoint_offset = _get_offset(key, name, numel)
+            else:
+                storage._checkpoint_offset = zip_file.get_record_offset(name)
         elif _serialization_tls.skip_data:
             nbytes = numel * torch._utils._element_size(dtype)
             storage = torch.UntypedStorage(nbytes)
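The gist of the change: when `can_calculate_storage_offsets` holds, each storage's offset inside the zip container is derived arithmetically from the previous record (local header, then payload, then an optional data descriptor) instead of issuing one random read of the zip local-file header per tensor via `zip_file.get_record_offset`. A simplified, hypothetical illustration of that sequential bookkeeping (the function name, record layout, and sizes below are assumptions for illustration, not the real `_get_offset`):

# Hypothetical sketch of sequential offset bookkeeping; not the actual
# torch.serialization implementation.
def compute_offsets(records):
    # records: iterable of (name, local_header_size, payload_nbytes, data_descriptor_size)
    offsets = {}
    current_offset = 0
    for name, header_size, nbytes, descriptor_size in records:
        # The payload starts right after this record's local zipfile header.
        storage_offset = current_offset + header_size
        offsets[name] = storage_offset
        # Advance current_offset to where the next zipfile header starts.
        current_offset = storage_offset + nbytes
        if nbytes > 0:
            # Add the size of the data descriptor that follows the payload.
            current_offset += descriptor_size
    return offsets

# Example: two records with 30-byte local headers and 16-byte data descriptors.
print(compute_offsets([("data/0", 30, 1024, 16), ("data/1", 30, 2048, 16)]))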