mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Check integrity of bytes in AppendingByteSerializer (#152139)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152139 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
9480ed4cd3
commit
e4a1a16bef
@ -78,6 +78,19 @@ class TestAppendingByteSerializer(TestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_checksum(self) -> None:
|
||||
writer = BytesWriter()
|
||||
writer.write_str("test")
|
||||
b = writer.to_bytes()
|
||||
b = bytearray(b)
|
||||
b[0:1] = b"\x00"
|
||||
b = bytes(b)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Bytes object is corrupted, checksum does not match.*"
|
||||
):
|
||||
BytesReader(b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
@ -81,7 +81,7 @@ class CompiledArtifact:
|
||||
|
||||
from .codecache import torch_key
|
||||
|
||||
writer = BytesWriter(0)
|
||||
writer = BytesWriter()
|
||||
writer.write_bytes(torch_key())
|
||||
writer.write_str(key)
|
||||
writer.write_bytes(artifact_bytes)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import zlib
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
@ -14,10 +15,13 @@ __all__ = ["AppendingByteSerializer"]
|
||||
# Helper classes
|
||||
#######################################
|
||||
|
||||
CHECKSUM_DIGEST_SIZE = 4
|
||||
|
||||
|
||||
class BytesWriter:
|
||||
def __init__(self, preallocate_size: int) -> None:
|
||||
self._data = bytearray(preallocate_size)
|
||||
def __init__(self) -> None:
|
||||
# Reserve CHECKSUM_DIGEST_SIZE bytes for checksum
|
||||
self._data = bytearray(CHECKSUM_DIGEST_SIZE)
|
||||
|
||||
def write_uint64(self, i: int) -> None:
|
||||
self._data.extend(i.to_bytes(8, byteorder="big", signed=False))
|
||||
@ -31,13 +35,30 @@ class BytesWriter:
|
||||
self._data.extend(b)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
|
||||
4, byteorder="big", signed=False
|
||||
)
|
||||
assert len(digest) == CHECKSUM_DIGEST_SIZE
|
||||
self._data[0:CHECKSUM_DIGEST_SIZE] = digest
|
||||
return bytes(self._data)
|
||||
|
||||
|
||||
class BytesReader:
|
||||
def __init__(self, data: bytes) -> None:
|
||||
# Check for data corruption
|
||||
assert len(data) >= CHECKSUM_DIGEST_SIZE
|
||||
digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
|
||||
4, byteorder="big", signed=False
|
||||
)
|
||||
assert len(digest) == CHECKSUM_DIGEST_SIZE
|
||||
if data[0:CHECKSUM_DIGEST_SIZE] != digest:
|
||||
raise RuntimeError(
|
||||
"Bytes object is corrupted, checksum does not match. "
|
||||
f"Expected: {data[0:CHECKSUM_DIGEST_SIZE]!r}, Got: {digest!r}"
|
||||
)
|
||||
|
||||
self._data = data
|
||||
self._i = 0
|
||||
self._i = CHECKSUM_DIGEST_SIZE
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return len(self._data) == self._i
|
||||
@ -72,20 +93,17 @@ class AppendingByteSerializer(Generic[T]):
|
||||
|
||||
_serialize_fn: Callable[[BytesWriter, T], None]
|
||||
_writer: BytesWriter
|
||||
_preallocate_size: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
serialize_fn: Callable[[BytesWriter, T], None],
|
||||
preallocate_size: int = 0,
|
||||
) -> None:
|
||||
self._serialize_fn = serialize_fn
|
||||
self._preallocate_size = preallocate_size
|
||||
self.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._writer = BytesWriter(preallocate_size=self._preallocate_size)
|
||||
self._writer = BytesWriter()
|
||||
# First 8-bytes are for version
|
||||
self._writer.write_uint64(_ENCODING_VERSION)
|
||||
|
||||
|
Reference in New Issue
Block a user