From e4a1a16befb42b03c3cd8c961130a292b6875cae Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 24 Apr 2025 14:25:53 -0700 Subject: [PATCH] Check integrity of bytes in AppendingByteSerializer (#152139) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152139 Approved by: https://github.com/zou3519 --- test/test_appending_byte_serializer.py | 13 +++++++++ torch/_inductor/standalone_compile.py | 2 +- torch/utils/_appending_byte_serializer.py | 32 ++++++++++++++++++----- 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/test/test_appending_byte_serializer.py b/test/test_appending_byte_serializer.py index e650fad1eac7..d21e1d694957 100644 --- a/test/test_appending_byte_serializer.py +++ b/test/test_appending_byte_serializer.py @@ -78,6 +78,19 @@ class TestAppendingByteSerializer(TestCase): ), ) + def test_checksum(self) -> None: + writer = BytesWriter() + writer.write_str("test") + b = writer.to_bytes() + b = bytearray(b) + b[0:1] = b"\x00" + b = bytes(b) + + with self.assertRaisesRegex( + RuntimeError, r"Bytes object is corrupted, checksum does not match.*" + ): + BytesReader(b) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/standalone_compile.py b/torch/_inductor/standalone_compile.py index f07231c29e2d..389f9ed8c7c7 100644 --- a/torch/_inductor/standalone_compile.py +++ b/torch/_inductor/standalone_compile.py @@ -81,7 +81,7 @@ class CompiledArtifact: from .codecache import torch_key - writer = BytesWriter(0) + writer = BytesWriter() writer.write_bytes(torch_key()) writer.write_str(key) writer.write_bytes(artifact_bytes) diff --git a/torch/utils/_appending_byte_serializer.py b/torch/utils/_appending_byte_serializer.py index b5cb54aa1e8e..91936ab6fc06 100644 --- a/torch/utils/_appending_byte_serializer.py +++ b/torch/utils/_appending_byte_serializer.py @@ -1,4 +1,5 @@ import base64 +import zlib from collections.abc import Iterable from typing import Callable, Generic, TypeVar @@ -14,10 +15,13 @@ __all__ = ["AppendingByteSerializer"] # Helper classes ####################################### +CHECKSUM_DIGEST_SIZE = 4 + class BytesWriter: - def __init__(self, preallocate_size: int) -> None: - self._data = bytearray(preallocate_size) + def __init__(self) -> None: + # Reserve CHECKSUM_DIGEST_SIZE bytes for checksum + self._data = bytearray(CHECKSUM_DIGEST_SIZE) def write_uint64(self, i: int) -> None: self._data.extend(i.to_bytes(8, byteorder="big", signed=False)) @@ -31,13 +35,30 @@ class BytesWriter: self._data.extend(b) def to_bytes(self) -> bytes: + digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes( + 4, byteorder="big", signed=False + ) + assert len(digest) == CHECKSUM_DIGEST_SIZE + self._data[0:CHECKSUM_DIGEST_SIZE] = digest return bytes(self._data) class BytesReader: def __init__(self, data: bytes) -> None: + # Check for data corruption + assert len(data) >= CHECKSUM_DIGEST_SIZE + digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes( + 4, byteorder="big", signed=False + ) + assert len(digest) == CHECKSUM_DIGEST_SIZE + if data[0:CHECKSUM_DIGEST_SIZE] != digest: + raise RuntimeError( + "Bytes object is corrupted, checksum does not match. " + f"Expected: {data[0:CHECKSUM_DIGEST_SIZE]!r}, Got: {digest!r}" + ) + self._data = data - self._i = 0 + self._i = CHECKSUM_DIGEST_SIZE def is_finished(self) -> bool: return len(self._data) == self._i @@ -72,20 +93,17 @@ class AppendingByteSerializer(Generic[T]): _serialize_fn: Callable[[BytesWriter, T], None] _writer: BytesWriter - _preallocate_size: int def __init__( self, *, serialize_fn: Callable[[BytesWriter, T], None], - preallocate_size: int = 0, ) -> None: self._serialize_fn = serialize_fn - self._preallocate_size = preallocate_size self.clear() def clear(self) -> None: - self._writer = BytesWriter(preallocate_size=self._preallocate_size) + self._writer = BytesWriter() # First 8-bytes are for version self._writer.write_uint64(_ENCODING_VERSION)