[dcp] Add ZStandard transformer (#143360)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143360
Approved by: https://github.com/saumishr, https://github.com/albanD
ghstack dependencies: #145528
Author: Marc Horowitz
Date: 2025-01-23 23:33:40 -08:00
Committed by: PyTorch MergeBot
parent f2ad2cdf1c
commit efebec5ef5
6 changed files with 160 additions and 8 deletions
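Example usage, as a minimal sketch distilled from the tests added in this PR: it assumes a single process, requires pyzstd or zstandard to be installed, and substitutes a plain nn.Linear for the tests' MyTestModule. The _extensions keyword is the private, experimental knob introduced here.

    import tempfile

    import torch.nn as nn
    from torch.distributed.checkpoint import (
        FileSystemReader,
        FileSystemWriter,
        load,
        save,
    )
    from torch.distributed.checkpoint._extension import ZStandard

    model = nn.Linear(4, 4)

    with tempfile.TemporaryDirectory() as path:
        # Write the checkpoint with ZStandard-compressed streams.
        writer = FileSystemWriter(path=path, _extensions=[ZStandard()])
        save(state_dict=model.state_dict(), storage_writer=writer)

        # Read it back; the reader reconstructs the transform from the
        # descriptor stored in the checkpoint, so no extension is passed.
        state_dict = nn.Linear(4, 4).state_dict()
        load(state_dict=state_dict, storage_reader=FileSystemReader(path=path))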

View File

@@ -362,6 +362,7 @@ pwlf==2.2.1 ; python_version >= "3.8"
 # To build PyTorch itself
 astunparse
 PyYAML
+pyzstd
 setuptools
 ninja==1.11.1 ; platform_machine == "aarch64"

View File

@@ -8,6 +8,7 @@ from torch.distributed._tensor import (
     Shard,
     zeros,
 )
+from torch.distributed.checkpoint._extension import ZStandard
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -58,7 +59,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
     @with_comms
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
-    @parametrize("extensions", [None, [Rot13Example()]])
+    @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
     def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
         CHECKPOINT_DIR = self.temp_dir

View File

@@ -21,6 +21,7 @@ from torch.distributed.checkpoint import (
     load_state_dict,
     save_state_dict,
 )
+from torch.distributed.checkpoint._extension import ZStandard
 from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -164,7 +165,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(2)
     @requires_nccl()
-    @parametrize("extensions", [None, [Rot13Example()]])
+    @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
     def test_read_write_shard_tensor(self, extensions) -> None:
         paths = [tempfile.mkdtemp()]
         dist.broadcast_object_list(paths)

View File

@@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
     save,
     save_state_dict,
 )
+from torch.distributed.checkpoint._extension import ZStandard
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -191,6 +192,39 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
         assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
+
+
+class TestDistributedStateDictSaveLoadZStandard(TestCase):
+    @parametrize("thread_count", _THREAD_COUNTS)
+    def test_read_write_only_tensor(self, thread_count) -> None:
+        with tempfile.TemporaryDirectory() as path:
+            state_dict_to_save = MyTestModule().state_dict()
+            state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
+
+            fs_writer = FileSystemWriter(
+                path=path,
+                thread_count=thread_count,
+                _extensions=[ZStandard()],
+            )
+            save(
+                state_dict=state_dict_to_save,
+                storage_writer=fs_writer,
+            )
+
+            state_dict_to_load_to = MyTestModule().state_dict()
+            state_dict_to_load_to["test_blob"] = BlobState(b"")
+            with self.assertRaises(AssertionError):
+                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
+
+            # Load from file without any resharding
+            fs_reader = FileSystemReader(path=path)
+            load(
+                state_dict=state_dict_to_load_to,
+                storage_reader=fs_reader,
+            )
+
+            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
+
+
 class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
     @property
     def world_size(self) -> int:
@@ -525,6 +559,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
+instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
 instantiate_parametrized_tests(TestDistributedReshardOnLoad)

 if __name__ == "__main__":

View File

@@ -1,15 +1,31 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import abc
 import io

 from collections.abc import Sequence
-from typing import IO, Type
+from typing import cast, IO, Optional, Type
+
+# introduced as collections.abc.Buffer in Python 3.12
+from typing_extensions import Buffer
+
+from torch._utils import try_import

 # NOTE: everything in this file is experimental, and subject to
 # change. Feedback and bug fixes are always welcome.

+pyzstd_module_name = "pyzstd"
+pyzstd = try_import(pyzstd_module_name)
+zstandard_module_name = "zstandard"
+zstandard = try_import(zstandard_module_name)
+
-__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
+__all__ = [
+    "Extension",
+    "StreamTransformExtension",
+    "ZStandard",
+    "ExtensionRegistry",
+]
class Extension(abc.ABC):
@@ -72,10 +88,110 @@ class StreamTransformExtension(Extension):
         """

+
+class ZStandard(StreamTransformExtension):
+    @staticmethod
+    def is_available() -> bool:
+        return zstandard is not None or pyzstd is not None
+
+    @staticmethod
+    def from_descriptor(version: str) -> "ZStandard":
+        if version.partition(".")[0] != "1":
+            raise ValueError(f"Unknown extension {version=}")
+        if not ZStandard.is_available():
+            raise ValueError(
+                "Stream with ZStandard compression cannot be processed because "
+                f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
+            )
+        return ZStandard()
+
+    @staticmethod
+    def registry_name() -> str:
+        return "stream.zstd"
+
+    def __init__(self) -> None:
+        super().__init__()
+        if not ZStandard.is_available():
+            raise ValueError(
+                "ZStandard extension is unavailable because no module named "
+                f"'{zstandard_module_name}' or '{pyzstd_module_name}'"
+            )
+
+    def get_descriptor(self) -> str:
+        return f"{self.registry_name()}/1"
+    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
+        if zstandard is not None:
+            compressor = zstandard.ZstdCompressor()  # type: ignore[union-attr]
+            return compressor.stream_writer(output)
+
+        # Fall back to pyzstd: wrap its incremental compressor in a raw
+        # IO stream that compresses on write and flushes on flush.
+        class Writer(io.RawIOBase):
+            def __init__(self, output: IO[bytes]) -> None:
+                self.output = output
+                self.compressor = pyzstd.ZstdCompressor()  # type: ignore[union-attr]
+
+            def writable(self) -> bool:
+                return True
+
+            def write(self, b: Buffer) -> Optional[int]:
+                outdata = self.compressor.compress(b)
+                if outdata:
+                    self.output.write(outdata)
+                return len(memoryview(b))
+
+            def flush(self) -> None:
+                outdata = self.compressor.flush()
+                if outdata:
+                    self.output.write(outdata)
+                self.output.flush()
+
+        return cast(IO[bytes], Writer(output))
+    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
+        if zstandard is not None:
+            decompressor = zstandard.ZstdDecompressor()  # type: ignore[union-attr]
+            return decompressor.stream_reader(input)
+
+        class Reader(io.RawIOBase):
+            def __init__(self, input: IO[bytes]) -> None:
+                self.input = input
+                self.decompressor = pyzstd.EndlessZstdDecompressor()  # type: ignore[union-attr]
+
+            def readable(self) -> bool:
+                return True
+
+            def readinto(self, b: Buffer) -> Optional[int]:
+                # This needs to read enough so it can decompress
+                # something so the output doesn't look like EOF. This
+                # means reading at least one block. The max block
+                # size is 128KB, so we read that plus some overhead
+                # to be sure.
+                if self.decompressor.needs_input:
+                    indata = self.input.read((128 + 6) * 1024)
+                else:
+                    indata = b""
+
+                bview = memoryview(b)
+                blen = len(bview)
+                outdata = self.decompressor.decompress(indata, blen)
+                if outdata is None:
+                    return None
+
+                count = len(outdata)
+                bview[:count] = outdata
+                return count
+
+            def seekable(self) -> bool:
+                return False
+
+        return cast(IO[bytes], Reader(input))
+
 class ExtensionRegistry:
     def __init__(self) -> None:
         # Populate default registry contents
-        self.extensions: dict[str, Type[Extension]] = {}
+        self.extensions: dict[str, Type[Extension]] = {
+            cls.registry_name(): cls for cls in (ZStandard,)
+        }

     def register(self, cls: Type[Extension]) -> None:
         self.extensions[cls.registry_name()] = cls
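To make the descriptor handshake concrete, here is a short sketch using only names from the diff above: get_descriptor() stamps "stream.zstd/1" into the checkpoint, and on load the registry maps the registry name back to the class, whose from_descriptor() validates the version. The name/version split below is an assumption for illustration; it mirrors the "registry_name()/1" format that get_descriptor() produces.

    from torch.distributed.checkpoint._extension import ExtensionRegistry, ZStandard

    registry = ExtensionRegistry()             # ZStandard is pre-registered
    assert registry.extensions["stream.zstd"] is ZStandard

    descriptor = ZStandard().get_descriptor()  # "stream.zstd/1"
    name, _, version = descriptor.partition("/")
    ext = registry.extensions[name].from_descriptor(version)  # "1" or "1.x" accepted
    assert isinstance(ext, ZStandard)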

View File

@@ -88,9 +88,7 @@ class Rot13Example(StreamTransformExtension):
                 # It's possible self.input is an IO[bytes] with no readinto method.
                 # In that case, we emulate with a read and copy. In practice,
                 # all of the current concrete extensions have readinto.
-                # 0 as a flags value is janky, but the flag values aren't available
-                # in python until 3.12.
-                view = b.__buffer__(0)
+                view = memoryview(b)
                 r = self.input.read(len(view))
                 if r is None:
                     count = None
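For context on this last change: memoryview() accepts any object implementing the buffer protocol, so it does the same job as the b.__buffer__(0) call it replaces without needing the flag constants that only became available in Python 3.12. A self-contained sketch of the read-and-copy emulation the comment describes, with hypothetical names:

    import io
    from typing import IO, Optional

    def readinto_emulated(src: IO[bytes], b) -> Optional[int]:
        # Emulate readinto() for streams that only provide read():
        # read at most len(b) bytes and copy them into the caller's buffer.
        view = memoryview(b)
        data = src.read(len(view))
        if data is None:  # non-blocking source with nothing available yet
            return None
        view[: len(data)] = data
        return len(data)

    buf = bytearray(5)
    assert readinto_emulated(io.BytesIO(b"hello world"), buf) == 5
    assert bytes(buf) == b"hello"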