mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152139 Approved by: https://github.com/zou3519
99 lines
2.6 KiB
Python
99 lines
2.6 KiB
Python
# Owner(s): ["module: inductor"]
|
|
|
|
import dataclasses
|
|
|
|
from torch.testing._internal.common_utils import TestCase
|
|
from torch.utils._appending_byte_serializer import (
|
|
AppendingByteSerializer,
|
|
BytesReader,
|
|
BytesWriter,
|
|
)
|
|
|
|
|
|
class TestAppendingByteSerializer(TestCase):
|
|
def test_write_and_read_int(self) -> None:
|
|
def int_serializer(writer: BytesWriter, i: int) -> None:
|
|
writer.write_uint64(i)
|
|
|
|
def int_deserializer(reader: BytesReader) -> int:
|
|
return reader.read_uint64()
|
|
|
|
s = AppendingByteSerializer(serialize_fn=int_serializer)
|
|
|
|
data = [1, 2, 3, 4]
|
|
s.extend(data)
|
|
self.assertListEqual(
|
|
data,
|
|
AppendingByteSerializer.to_list(
|
|
s.to_bytes(), deserialize_fn=int_deserializer
|
|
),
|
|
)
|
|
|
|
data2 = [8, 9, 10, 11]
|
|
s.extend(data2)
|
|
self.assertListEqual(
|
|
data + data2,
|
|
AppendingByteSerializer.to_list(
|
|
s.to_bytes(), deserialize_fn=int_deserializer
|
|
),
|
|
)
|
|
|
|
def test_write_and_read_class(self) -> None:
|
|
@dataclasses.dataclass(frozen=True, eq=True)
|
|
class Foo:
|
|
x: int
|
|
y: str
|
|
z: bytes
|
|
|
|
@staticmethod
|
|
def serialize(writer: BytesWriter, cls: "Foo") -> None:
|
|
writer.write_uint64(cls.x)
|
|
writer.write_str(cls.y)
|
|
writer.write_bytes(cls.z)
|
|
|
|
@staticmethod
|
|
def deserialize(reader: BytesReader) -> "Foo":
|
|
x = reader.read_uint64()
|
|
y = reader.read_str()
|
|
z = reader.read_bytes()
|
|
return Foo(x, y, z)
|
|
|
|
a = Foo(5, "ok", bytes([15]))
|
|
b = Foo(10, "lol", bytes([25]))
|
|
|
|
s = AppendingByteSerializer(serialize_fn=Foo.serialize)
|
|
s.append(a)
|
|
self.assertListEqual(
|
|
[a],
|
|
AppendingByteSerializer.to_list(
|
|
s.to_bytes(), deserialize_fn=Foo.deserialize
|
|
),
|
|
)
|
|
|
|
s.append(b)
|
|
self.assertListEqual(
|
|
[a, b],
|
|
AppendingByteSerializer.to_list(
|
|
s.to_bytes(), deserialize_fn=Foo.deserialize
|
|
),
|
|
)
|
|
|
|
def test_checksum(self) -> None:
|
|
writer = BytesWriter()
|
|
writer.write_str("test")
|
|
b = writer.to_bytes()
|
|
b = bytearray(b)
|
|
b[0:1] = b"\x00"
|
|
b = bytes(b)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, r"Bytes object is corrupted, checksum does not match.*"
|
|
):
|
|
BytesReader(b)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._inductor.test_case import run_tests
|
|
|
|
run_tests()
|