Check integrity of bytes in AppendingByteSerializer (#152139)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152139
Approved by: https://github.com/zou3519
This commit is contained in:
Oguz Ulgen
2025-04-24 14:25:53 -07:00
committed by PyTorch MergeBot
parent 9480ed4cd3
commit e4a1a16bef
3 changed files with 39 additions and 8 deletions

View File

@ -78,6 +78,19 @@ class TestAppendingByteSerializer(TestCase):
),
)
def test_checksum(self) -> None:
    """Check that BytesReader raises when the first serialized byte is altered."""
    writer = BytesWriter()
    writer.write_str("test")
    corrupted = bytearray(writer.to_bytes())
    # Invert the first byte instead of overwriting it with a constant, so the
    # buffer is guaranteed to differ from what the writer produced even if
    # that byte already happened to be 0x00.
    corrupted[0] ^= 0xFF
    with self.assertRaisesRegex(
        RuntimeError, r"Bytes object is corrupted, checksum does not match.*"
    ):
        BytesReader(bytes(corrupted))
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -81,7 +81,7 @@ class CompiledArtifact:
from .codecache import torch_key
writer = BytesWriter(0)
writer = BytesWriter()
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)

View File

@ -1,4 +1,5 @@
import base64
import zlib
from collections.abc import Iterable
from typing import Callable, Generic, TypeVar
@ -14,10 +15,13 @@ __all__ = ["AppendingByteSerializer"]
# Helper classes
#######################################
# Number of bytes BytesWriter reserves at the start of its buffer for the
# big-endian CRC-32 checksum of the bytes that follow; BytesReader verifies
# and then skips the same number of bytes.
CHECKSUM_DIGEST_SIZE = 4
class BytesWriter:
def __init__(self, preallocate_size: int) -> None:
self._data = bytearray(preallocate_size)
def __init__(self) -> None:
    """Create a writer whose buffer begins with a zero-filled checksum slot."""
    # The leading CHECKSUM_DIGEST_SIZE bytes are a placeholder that
    # to_bytes() overwrites with the real checksum.
    checksum_placeholder = bytearray(CHECKSUM_DIGEST_SIZE)
    self._data = checksum_placeholder
def write_uint64(self, i: int) -> None:
    """Append ``i`` to the buffer as an 8-byte big-endian unsigned integer.

    Raises:
        OverflowError: If ``i`` is negative or does not fit in 8 bytes.
    """
    encoded = i.to_bytes(8, byteorder="big", signed=False)
    self._data.extend(encoded)
@ -31,13 +35,30 @@ class BytesWriter:
self._data.extend(b)
def to_bytes(self) -> bytes:
    """Return the buffer as immutable bytes with its checksum prefix filled in.

    The CRC-32 of everything after the first ``CHECKSUM_DIGEST_SIZE`` bytes
    is written, big-endian, into those leading bytes before the buffer is
    copied out.
    """
    # Checksum through a memoryview so the payload is not copied just to be
    # hashed. Both views are released on leaving the ``with`` block, before
    # the buffer is modified.
    with memoryview(self._data) as whole, whole[CHECKSUM_DIGEST_SIZE:] as payload:
        checksum = zlib.crc32(payload)
    digest = checksum.to_bytes(4, byteorder="big", signed=False)
    assert len(digest) == CHECKSUM_DIGEST_SIZE
    self._data[0:CHECKSUM_DIGEST_SIZE] = digest
    return bytes(self._data)
class BytesReader:
def __init__(self, data: bytes) -> None:
    """Verify the checksum prefix of ``data`` and start reading just after it.

    Args:
        data: Serialized buffer whose first ``CHECKSUM_DIGEST_SIZE`` bytes
            hold the big-endian CRC-32 of the remaining bytes.

    Raises:
        RuntimeError: If ``data`` is too short to contain a checksum, or if
            the stored checksum differs from the one computed over the
            remaining bytes.
    """
    # Check for data corruption.
    # Raise instead of assert so that truncated input is rejected with the
    # same exception type whether or not assertions are enabled (python -O).
    if len(data) < CHECKSUM_DIGEST_SIZE:
        raise RuntimeError(
            "Bytes object is corrupted, expected at least "
            f"{CHECKSUM_DIGEST_SIZE} bytes for the checksum but got {len(data)}"
        )
    stored = data[0:CHECKSUM_DIGEST_SIZE]
    # Checksum through a memoryview so the payload is not copied just to be
    # hashed; both views are released on leaving the ``with`` block.
    with memoryview(data) as whole, whole[CHECKSUM_DIGEST_SIZE:] as payload:
        computed = zlib.crc32(payload)
    digest = computed.to_bytes(4, byteorder="big", signed=False)
    assert len(digest) == CHECKSUM_DIGEST_SIZE
    if stored != digest:
        raise RuntimeError(
            "Bytes object is corrupted, checksum does not match. "
            f"Expected: {stored!r}, Got: {digest!r}"
        )

    self._data = data
    # Reading starts immediately after the checksum prefix.
    self._i = CHECKSUM_DIGEST_SIZE
def is_finished(self) -> bool:
    """Report whether the current offset ``_i`` equals the buffer length."""
    remaining = len(self._data) - self._i
    return remaining == 0
@ -72,20 +93,17 @@ class AppendingByteSerializer(Generic[T]):
_serialize_fn: Callable[[BytesWriter, T], None]
_writer: BytesWriter
_preallocate_size: int
def __init__(
self,
*,
serialize_fn: Callable[[BytesWriter, T], None],
preallocate_size: int = 0,
) -> None:
self._serialize_fn = serialize_fn
self._preallocate_size = preallocate_size
self.clear()
def clear(self) -> None:
self._writer = BytesWriter(preallocate_size=self._preallocate_size)
self._writer = BytesWriter()
# First 8-bytes are for version
self._writer.write_uint64(_ENCODING_VERSION)