Files
pytorch/test/test_appending_byte_serializer.py

99 lines
2.6 KiB
Python

# Owner(s): ["module: inductor"]
import dataclasses
from torch.testing._internal.common_utils import TestCase
from torch.utils._appending_byte_serializer import (
AppendingByteSerializer,
BytesReader,
BytesWriter,
)
class TestAppendingByteSerializer(TestCase):
def test_write_and_read_int(self) -> None:
def int_serializer(writer: BytesWriter, i: int) -> None:
writer.write_uint64(i)
def int_deserializer(reader: BytesReader) -> int:
return reader.read_uint64()
s = AppendingByteSerializer(serialize_fn=int_serializer)
data = [1, 2, 3, 4]
s.extend(data)
self.assertListEqual(
data,
AppendingByteSerializer.to_list(
s.to_bytes(), deserialize_fn=int_deserializer
),
)
data2 = [8, 9, 10, 11]
s.extend(data2)
self.assertListEqual(
data + data2,
AppendingByteSerializer.to_list(
s.to_bytes(), deserialize_fn=int_deserializer
),
)
def test_write_and_read_class(self) -> None:
@dataclasses.dataclass(frozen=True, eq=True)
class Foo:
x: int
y: str
z: bytes
@staticmethod
def serialize(writer: BytesWriter, cls: "Foo") -> None:
writer.write_uint64(cls.x)
writer.write_str(cls.y)
writer.write_bytes(cls.z)
@staticmethod
def deserialize(reader: BytesReader) -> "Foo":
x = reader.read_uint64()
y = reader.read_str()
z = reader.read_bytes()
return Foo(x, y, z)
a = Foo(5, "ok", bytes([15]))
b = Foo(10, "lol", bytes([25]))
s = AppendingByteSerializer(serialize_fn=Foo.serialize)
s.append(a)
self.assertListEqual(
[a],
AppendingByteSerializer.to_list(
s.to_bytes(), deserialize_fn=Foo.deserialize
),
)
s.append(b)
self.assertListEqual(
[a, b],
AppendingByteSerializer.to_list(
s.to_bytes(), deserialize_fn=Foo.deserialize
),
)
def test_checksum(self) -> None:
writer = BytesWriter()
writer.write_str("test")
b = writer.to_bytes()
b = bytearray(b)
b[0:1] = b"\x00"
b = bytes(b)
with self.assertRaisesRegex(
RuntimeError, r"Bytes object is corrupted, checksum does not match.*"
):
BytesReader(b)
if __name__ == "__main__":
    # When executed directly as a script, hand control to the run_tests
    # entry point provided by torch._inductor.test_case.
    from torch._inductor import test_case

    test_case.run_tests()