Supporting non-tensor-data write_size in planner write items. (#149699)

Summary:
1\ The current write item structure does not contain the amount of data that needs to be written.
2\ The planner item already has a size primitive, 'tensor_storage_size' (https://fburl.com/code/7a0gsmw7), but only for tensors.
3\ Right now, the only way the writer layer can get hold of this property for non-tensor data is to
first do a lookup into the actual tensor/bytes, and
then calculate the nbytes.
This change introduces a way to capture the non-tensor data size within a write-plan item.
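For illustration, a minimal sketch of the pre-existing workaround described above: the writer has to resolve the item down to its underlying `io.BytesIO` payload just to learn its size. The lookup table and function names here are hypothetical stand-ins, not the actual PyTorch planner/writer APIs.

```python
import io

# Hypothetical mapping from a write item's index to its payload,
# standing in for the planner's resolve step.
payloads = {"state_dict.extra": io.BytesIO(b"checkpoint-metadata")}

def nbytes_via_lookup(index: str) -> int:
    # Old behaviour: resolve the actual bytes object first...
    data = payloads[index]
    # ...then compute the size from its buffer.
    return data.getbuffer().nbytes

print(nbytes_via_lookup("state_dict.extra"))  # 19
```

With the size recorded on the write-plan item itself, this resolve-then-measure step becomes unnecessary.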

Test Plan: Existing unit tests.

Differential Revision: D71599725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149699
Approved by: https://github.com/MeetVadakkanchery
This commit is contained in:
Pradeep Fernando
2025-03-21 18:09:09 +00:00
committed by PyTorch MergeBot
parent f7d1b966c2
commit 1b08aaeafe
2 changed files with 12 additions and 0 deletions


@@ -85,6 +85,9 @@ The following types define the planner interface used during checkpoint:
.. autoclass:: torch.distributed.checkpoint.planner.WriteItem
:members:
.. autoclass:: torch.distributed.checkpoint.planner.BytesIOWriteData
:members:
We provide a filesystem based storage layer:
.. autoclass:: torch.distributed.checkpoint.FileSystemReader


@@ -20,6 +20,7 @@ from torch.distributed.checkpoint.metadata import (
__all__ = [
"WriteItemType",
"LoadItemType",
"BytesIOWriteData",
"TensorWriteData",
"WriteItem",
"ReadItem",
@@ -41,6 +42,11 @@ class LoadItemType(Enum):
BYTE_IO = auto()
@dataclass(frozen=True)
class BytesIOWriteData:
nbytes: int
@dataclass(frozen=True)
class TensorWriteData:
chunk: ChunkStorageMetadata
@@ -55,6 +61,9 @@ class WriteItem:
index: MetadataIndex
type: WriteItemType
# Size of bytesIO data to be written.
bytes_io_data: Optional[BytesIOWriteData] = None
# Value present if it's a tensor write
tensor_data: Optional[TensorWriteData] = None
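The shape of the change can be mirrored with plain dataclasses. This is a self-contained sketch, not the actual `torch.distributed.checkpoint.planner` classes; `index` is simplified to a string instead of a `MetadataIndex`.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class BytesIOWriteData:
    # Size of the bytesIO payload, captured at planning time.
    nbytes: int

@dataclass(frozen=True)
class WriteItem:
    index: str  # simplified stand-in for MetadataIndex
    # Present if this item is a non-tensor (bytesIO) write.
    bytes_io_data: Optional[BytesIOWriteData] = None

# The writer layer can now read the size directly off the plan item,
# without resolving the underlying bytes object.
item = WriteItem(index="state_dict.extra", bytes_io_data=BytesIOWriteData(nbytes=19))
assert item.bytes_io_data is not None
print(item.bytes_io_data.nbytes)  # 19
```

Keeping the dataclasses frozen matches the diff above, so plan items stay hashable and immutable once produced by the planner.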