mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
distributed/serialization: support zero sized tensors (#164198)
Fixes ``` [4] ValueError: both buffer length (0) and count (-1) must not be 0 ``` Test plan: ``` pytest test/distributed/test_serialization.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164198 Approved by: https://github.com/amirafzali
This commit is contained in:
committed by
PyTorch MergeBot
parent
6e5b4249a5
commit
7f4c3e7d2f
@ -95,6 +95,18 @@ class TestSerialization(TestCase):
|
||||
result = _streaming_load(file)
|
||||
torch.testing.assert_close(result, state_dict)
|
||||
|
||||
def test_empty_tensor(self) -> None:
|
||||
state_dict = {
|
||||
"empty": torch.zeros(0, 10),
|
||||
}
|
||||
|
||||
file = BytesIO()
|
||||
_streaming_save(state_dict, file)
|
||||
file.seek(0)
|
||||
|
||||
result = _streaming_load(file, weights_only=False)
|
||||
self.assertEqual(result, state_dict)
|
||||
|
||||
def test_dtensor(self) -> None:
|
||||
dist.init_process_group(
|
||||
backend="gloo", rank=0, world_size=1, store=dist.HashStore()
|
||||
|
@ -57,10 +57,13 @@ class _PseudoZipFile:
|
||||
for entry in entries:
|
||||
data = f.read(entry.length)
|
||||
if entry.is_storage:
|
||||
storage = torch.frombuffer(
|
||||
data,
|
||||
dtype=torch.uint8,
|
||||
).untyped_storage()
|
||||
if entry.length == 0:
|
||||
storage = torch.UntypedStorage(0)
|
||||
else:
|
||||
storage = torch.frombuffer(
|
||||
data,
|
||||
dtype=torch.uint8,
|
||||
).untyped_storage()
|
||||
|
||||
self.records[entry.key] = (
|
||||
storage,
|
||||
|
Reference in New Issue
Block a user