Files
pytorch/torch/distributed/_checkpointable.py
Aaron Gokaslan 6b1211df29 [BE]: Backport runtime_checkable perf improvements/behavior from 3.12 (#155130)
Backports some behavior changes and performance improvements for runtime_checkable in 3.12 to older versions of Python. Should be a free performance improvement for runtime type checking of protocols, since everything already works this way on Python 3.12.

The difference between the two versions of runtime_checkable is [these lines](40e22ebb2c/src/typing_extensions.py (L800-L823)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155130
Approved by: https://github.com/rec, https://github.com/aorenste
2025-06-06 13:28:05 +00:00
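As background for why this matters: `runtime_checkable` lets `isinstance()` work against a `Protocol`, but the check only verifies that the required attributes exist on the object, not that their signatures match. A minimal, version-agnostic sketch (the `HasPing`/`Pinger` names are illustrative, not from the PR):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class HasPing(Protocol):
    def ping(self, count: int) -> str: ...

class Pinger:
    # Signature differs from the protocol, but runtime_checkable only
    # checks that a callable attribute named "ping" is present.
    def ping(self):
        return "pong"

class NoPing:
    pass

print(isinstance(Pinger(), HasPing))  # True: attribute presence is enough
print(isinstance(NoPing(), HasPing))  # False: required member is missing
```

Because each `isinstance()` call walks the protocol's members, making that walk cheaper (as the 3.12 implementation does) directly speeds up code that checks protocols on hot paths.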

38 lines
1.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
from typing_extensions import Protocol, runtime_checkable

import torch


@runtime_checkable
class _Checkpointable(Protocol):  # noqa: PYI046
    """
    Interface for checkpointable objects.
    Implemented as a protocol, implicit subtyping is supported so subclasses do not need to inherit this explicitly.
    This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface.
    """

    def __create_write_items__(self, fqn: str, object: object) -> list[object]:
        """
        Return a list of WriteItems based on object's contents.
        """
        raise NotImplementedError(
            "_Checkpointable._create_write_items is not implemented"
        )

    def __create_chunk_list__(self) -> list[object]:
        """
        Return a list of `ChunkStorageMetadata` based on object's contents.
        """
        raise NotImplementedError(
            "_Checkpointable._create_chunk_list is not implemented"
        )

    def __get_tensor_shard__(self, index: int) -> torch.Tensor:
        """
        Return a 'torch.Tensor' shard based on 'MetadataIndex'.
        """
        raise NotImplementedError(
            "_Checkpointable._get_tensor_shard is not implemented"
        )