mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This is the result of applying the ruff `UP035` check. `Callable` is imported from `collections.abc` instead of `typing`. `TypeAlias` is also imported from `typing`. This PR is a follow-up to #163947. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054 Approved by: https://github.com/ezyang, https://github.com/Skylion007
129 lines
3.6 KiB
Python
129 lines
3.6 KiB
Python
import base64
|
|
import zlib
|
|
from collections.abc import Callable, Iterable
|
|
from typing import Generic, TypeVar
|
|
|
|
|
|
# Element type handled by AppendingByteSerializer.
T = TypeVar("T")


# Wire-format version written as the first uint64 of every serialized
# blob; bump this whenever the encoding layout changes.
_ENCODING_VERSION: int = 1


# Public API of this module; the Bytes{Writer,Reader} helpers are internal.
__all__ = ["AppendingByteSerializer"]
|
|
|
|
|
|
#######################################
|
|
# Helper classes
|
|
#######################################
|
|
|
|
# Width in bytes of the CRC32 checksum stored at the head of each buffer.
CHECKSUM_DIGEST_SIZE = 4


class BytesWriter:
    """Append-only binary buffer with a leading CRC32 checksum.

    Layout: the first ``CHECKSUM_DIGEST_SIZE`` bytes hold a big-endian
    CRC32 of the rest of the buffer; the checksum is stamped in by
    :meth:`to_bytes`. All integers are written big-endian and unsigned.
    """

    def __init__(self) -> None:
        # Reserve CHECKSUM_DIGEST_SIZE bytes for checksum; it is computed
        # over the payload that follows when to_bytes() is called.
        self._data = bytearray(CHECKSUM_DIGEST_SIZE)

    def write_uint64(self, i: int) -> None:
        """Append ``i`` as an 8-byte big-endian unsigned integer."""
        self._data.extend(i.to_bytes(8, byteorder="big", signed=False))

    def write_str(self, s: str) -> None:
        """Append ``s`` UTF-8-encoded then base64-encoded, length-prefixed."""
        payload = base64.b64encode(s.encode("utf-8"))
        self.write_bytes(payload)

    def write_bytes(self, b: bytes) -> None:
        """Append ``b`` prefixed with its length as a uint64."""
        self.write_uint64(len(b))
        self._data.extend(b)

    def to_bytes(self) -> bytes:
        """Stamp the CRC32 checksum into the header and return the buffer."""
        # Use CHECKSUM_DIGEST_SIZE rather than a literal 4 so the digest
        # width cannot drift out of sync with the reserved header.
        digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
            CHECKSUM_DIGEST_SIZE, byteorder="big", signed=False
        )
        self._data[0:CHECKSUM_DIGEST_SIZE] = digest
        return bytes(self._data)
|
|
|
|
|
|
class BytesReader:
    """Sequential reader for buffers produced by ``BytesWriter``.

    Validates the leading CRC32 checksum on construction, then reads
    values in the same order they were written.

    Raises:
        RuntimeError: if the buffer is too short to hold a checksum, or
            the stored checksum does not match the payload (corruption).
    """

    def __init__(self, data: bytes) -> None:
        # Check for data corruption. Use explicit raises rather than
        # assert so the validation is not stripped under `python -O`.
        if len(data) < CHECKSUM_DIGEST_SIZE:
            raise RuntimeError(
                "Bytes object is corrupted, too short to contain a checksum: "
                f"got {len(data)} bytes, need at least {CHECKSUM_DIGEST_SIZE}"
            )
        digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
            CHECKSUM_DIGEST_SIZE, byteorder="big", signed=False
        )
        if data[0:CHECKSUM_DIGEST_SIZE] != digest:
            raise RuntimeError(
                "Bytes object is corrupted, checksum does not match. "
                f"Expected: {data[0:CHECKSUM_DIGEST_SIZE]!r}, Got: {digest!r}"
            )

        self._data = data
        # Read cursor starts just past the checksum header.
        self._i = CHECKSUM_DIGEST_SIZE

    def is_finished(self) -> bool:
        """Return True once the entire payload has been consumed."""
        return len(self._data) == self._i

    def read_uint64(self) -> int:
        """Read the next 8 bytes as a big-endian unsigned integer."""
        result = int.from_bytes(
            self._data[self._i : self._i + 8], byteorder="big", signed=False
        )
        self._i += 8
        return result

    def read_str(self) -> str:
        """Read a length-prefixed, base64-encoded UTF-8 string."""
        return base64.b64decode(self.read_bytes()).decode("utf-8")

    def read_bytes(self) -> bytes:
        """Read a uint64 length prefix followed by that many raw bytes."""
        size = self.read_uint64()
        result = self._data[self._i : self._i + size]
        self._i += size
        return result
|
|
|
|
|
|
#######################################
|
|
# AppendingByteSerializer
|
|
#######################################
|
|
|
|
|
|
class AppendingByteSerializer(Generic[T]):
    """
    Provides efficient serialization and deserialization of list of bytes

    Note that this does not provide any guarantees around byte order
    """

    # Callback that knows how to write one element of type T.
    _serialize_fn: Callable[[BytesWriter, T], None]
    _writer: BytesWriter

    def __init__(
        self,
        *,
        serialize_fn: Callable[[BytesWriter, T], None],
    ) -> None:
        self._serialize_fn = serialize_fn
        self.clear()

    def clear(self) -> None:
        """Reset to an empty buffer containing only the version header."""
        self._writer = BytesWriter()
        # First 8-bytes are for version
        self._writer.write_uint64(_ENCODING_VERSION)

    def append(self, data: T) -> None:
        """Serialize ``data`` onto the end of the buffer."""
        self._serialize_fn(self._writer, data)

    def extend(self, elems: Iterable[T]) -> None:
        """Append each element of ``elems`` in order."""
        for elem in elems:
            self.append(elem)

    def to_bytes(self) -> bytes:
        """Return the checksummed, version-prefixed serialized buffer."""
        return self._writer.to_bytes()

    @staticmethod
    def to_list(data: bytes, *, deserialize_fn: Callable[[BytesReader], T]) -> list[T]:
        """Deserialize ``data`` back into a list of elements.

        Raises:
            RuntimeError: on checksum mismatch (via ``BytesReader``) or
                on an unsupported encoding version.
        """
        reader = BytesReader(data)
        # Raise rather than assert so the version check survives `python -O`.
        version = reader.read_uint64()
        if version != _ENCODING_VERSION:
            raise RuntimeError(
                f"Unsupported encoding version: {version}, "
                f"expected {_ENCODING_VERSION}"
            )

        result: list[T] = []
        while not reader.is_finished():
            result.append(deserialize_fn(reader))
        return result
|