distributed/serialization: support zero sized tensors (#164198)

Fixes
```
[4] ValueError: both buffer length (0) and count (-1) must not be 0
```

Test plan:

```
pytest test/distributed/test_serialization.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164198
Approved by: https://github.com/amirafzali
This commit is contained in:
Tristan Rice
2025-09-30 08:11:26 +00:00
committed by PyTorch MergeBot
parent 6e5b4249a5
commit 7f4c3e7d2f
2 changed files with 19 additions and 4 deletions

View File

@ -95,6 +95,18 @@ class TestSerialization(TestCase):
result = _streaming_load(file)
torch.testing.assert_close(result, state_dict)
def test_empty_tensor(self) -> None:
state_dict = {
"empty": torch.zeros(0, 10),
}
file = BytesIO()
_streaming_save(state_dict, file)
file.seek(0)
result = _streaming_load(file, weights_only=False)
self.assertEqual(result, state_dict)
def test_dtensor(self) -> None:
dist.init_process_group(
backend="gloo", rank=0, world_size=1, store=dist.HashStore()

View File

@ -57,10 +57,13 @@ class _PseudoZipFile:
for entry in entries:
data = f.read(entry.length)
if entry.is_storage:
storage = torch.frombuffer(
data,
dtype=torch.uint8,
).untyped_storage()
if entry.length == 0:
storage = torch.UntypedStorage(0)
else:
storage = torch.frombuffer(
data,
dtype=torch.uint8,
).untyped_storage()
self.records[entry.key] = (
storage,