mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dcp] Add ZStandard transformer (#143360)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143360 Approved by: https://github.com/saumishr, https://github.com/albanD ghstack dependencies: #145528
This commit is contained in:
committed by
PyTorch MergeBot
parent
f2ad2cdf1c
commit
efebec5ef5
@ -362,6 +362,7 @@ pwlf==2.2.1 ; python_version >= "3.8"
|
||||
# To build PyTorch itself
|
||||
astunparse
|
||||
PyYAML
|
||||
pyzstd
|
||||
setuptools
|
||||
|
||||
ninja==1.11.1 ; platform_machine == "aarch64"
|
||||
|
@ -8,6 +8,7 @@ from torch.distributed._tensor import (
|
||||
Shard,
|
||||
zeros,
|
||||
)
|
||||
from torch.distributed.checkpoint._extension import ZStandard
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -58,7 +59,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_temp_dir
|
||||
@parametrize("extensions", [None, [Rot13Example()]])
|
||||
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
||||
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
|
||||
|
@ -21,6 +21,7 @@ from torch.distributed.checkpoint import (
|
||||
load_state_dict,
|
||||
save_state_dict,
|
||||
)
|
||||
from torch.distributed.checkpoint._extension import ZStandard
|
||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
@ -164,7 +165,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl()
|
||||
@parametrize("extensions", [None, [Rot13Example()]])
|
||||
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
||||
def test_read_write_shard_tensor(self, extensions) -> None:
|
||||
paths = [tempfile.mkdtemp()]
|
||||
dist.broadcast_object_list(paths)
|
||||
|
@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
|
||||
save,
|
||||
save_state_dict,
|
||||
)
|
||||
from torch.distributed.checkpoint._extension import ZStandard
|
||||
from torch.distributed.checkpoint.stateful import Stateful
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
@ -191,6 +192,39 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
|
||||
class TestDistributedStateDictSaveLoadZStandard(TestCase):
    """Round-trip a state dict through a checkpoint written with ZStandard."""

    @parametrize("thread_count", _THREAD_COUNTS)
    def test_read_write_only_tensor(self, thread_count) -> None:
        """Save with the ZStandard extension, load it back, compare."""
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            # Source state: module parameters plus an opaque byte blob.
            expected = MyTestModule().state_dict()
            expected["test_blob"] = BlobState(b"SomeBlobForTesting")

            save(
                state_dict=expected,
                storage_writer=FileSystemWriter(
                    path=checkpoint_dir,
                    thread_count=thread_count,
                    _extensions=[ZStandard()],
                ),
            )

            # Destination state: a fresh module and an empty blob.
            loaded = MyTestModule().state_dict()
            loaded["test_blob"] = BlobState(b"")

            # Sanity check: before loading, the two dicts must differ.
            with self.assertRaises(AssertionError):
                assert_state_dict_equal(self, loaded, expected)

            # Load from file without any resharding
            load(
                state_dict=loaded,
                storage_reader=FileSystemReader(path=checkpoint_dir),
            )

            assert_state_dict_equal(self, loaded, expected)
|
||||
|
||||
|
||||
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
@ -525,6 +559,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
|
||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
|
||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
|
||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
|
||||
instantiate_parametrized_tests(TestDistributedReshardOnLoad)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,15 +1,31 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
import abc
|
||||
import io
|
||||
from collections.abc import Sequence
|
||||
from typing import IO, Type
|
||||
from typing import cast, IO, Optional, Type
|
||||
|
||||
# introduced as collections.abc.Buffer in Python 3.12
|
||||
from typing_extensions import Buffer
|
||||
|
||||
from torch._utils import try_import
|
||||
|
||||
|
||||
# NOTE: everything in this file is experimental, and subject to
|
||||
# change. Feedback and bug fixes are always welcome.
|
||||
|
||||
# Optional compression backends. Each name is bound to whatever
# ``try_import`` returns for that module; ``ZStandard.is_available()``
# treats ``None`` as "backend not installed".
pyzstd_module_name = "pyzstd"
pyzstd = try_import(pyzstd_module_name)
zstandard_module_name = "zstandard"
zstandard = try_import(zstandard_module_name)

# Public API of this module. (A superseded one-line ``__all__`` that was
# immediately overwritten by this list has been dropped.)
__all__ = [
    "Extension",
    "StreamTransformExtension",
    "ZStandard",
    "ExtensionRegistry",
]
|
||||
|
||||
|
||||
class Extension(abc.ABC):
|
||||
@ -72,10 +88,110 @@ class StreamTransformExtension(Extension):
|
||||
"""
|
||||
|
||||
|
||||
class ZStandard(StreamTransformExtension):
    """Stream transform that (de)compresses data with Zstandard.

    The ``zstandard`` module is used when it was importable; otherwise the
    implementation falls back to ``pyzstd``.  Constructing the extension
    raises ``ValueError`` when neither module is available.
    """

    @staticmethod
    def is_available() -> bool:
        """Return True if at least one supported zstd module was imported."""
        return zstandard is not None or pyzstd is not None

    @staticmethod
    def from_descriptor(version: str) -> "ZStandard":
        """Build an instance for a stored descriptor version.

        Args:
            version: version string; only the part before the first "." is
                checked, and it must be "1".

        Raises:
            ValueError: if the major version is not "1", or if no zstd
                module is available to process the stream.
        """
        if version.partition(".")[0] != "1":
            raise ValueError(f"Unknown extension {version=}")
        if not ZStandard.is_available():
            raise ValueError(
                f"Stream with ZStandard compression cannot be processed because "
                f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
            )
        return ZStandard()

    @staticmethod
    def registry_name() -> str:
        """Return the name under which this extension is registered."""
        return "stream.zstd"

    def __init__(self) -> None:
        """Create the extension.

        Raises:
            ValueError: if neither zstd module is available.
        """
        super().__init__()
        if not ZStandard.is_available():
            raise ValueError(
                f"ZStandard extension is unavailable because no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
            )

    def get_descriptor(self) -> str:
        """Return the descriptor string: the registry name plus "/1"."""
        return f"{self.registry_name()}/1"

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
        """Return a writable stream that compresses into ``output``.

        Args:
            output: destination stream receiving the compressed bytes.

        Returns:
            A stream whose ``write`` compresses data and forwards it to
            ``output``.
        """
        if zstandard is not None:
            compressor = zstandard.ZstdCompressor()  # type: ignore[union-attr]
            return compressor.stream_writer(output)

        class Writer(io.RawIOBase):
            """pyzstd-backed compressing writer."""

            def __init__(self, output: IO[bytes]) -> None:
                self.output = output
                self.compressor = pyzstd.ZstdCompressor()  # type: ignore[union-attr]

            def writable(self) -> bool:
                # The io protocol method is spelled ``writable``; without
                # this override RawIOBase reports the stream as read-only.
                return True

            # Backward-compatible alias for the original misspelled name.
            writeable = writable

            def write(self, b: Buffer) -> Optional[int]:
                outdata = self.compressor.compress(b)
                if outdata:
                    self.output.write(outdata)
                # Report bytes consumed, not element count, so buffers with
                # multi-byte items are accounted for correctly.
                return memoryview(b).nbytes

            def flush(self) -> None:
                outdata = self.compressor.flush()
                if outdata:
                    self.output.write(outdata)
                self.output.flush()

        return cast(IO[bytes], Writer(output))

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
        """Return a readable stream that decompresses data from ``input``.

        Args:
            input: source stream holding zstd-compressed bytes.

        Returns:
            A non-seekable stream yielding the decompressed bytes.
        """
        if zstandard is not None:
            decompressor = zstandard.ZstdDecompressor()  # type: ignore[union-attr]
            return decompressor.stream_reader(input)

        # Each read must pull enough input to decompress something,
        # otherwise the empty result would look like EOF.  That means
        # reading at least one block.  The max block size is 128KB, so
        # we read that plus some overhead to be sure.
        read_size = (128 + 6) * 1024

        class Reader(io.RawIOBase):
            """pyzstd-backed decompressing reader."""

            def __init__(self, input: IO[bytes]) -> None:
                self.input = input
                self.decompressor = pyzstd.EndlessZstdDecompressor()  # type: ignore[union-attr]

            def readable(self) -> bool:
                return True

            def readinto(self, b: Buffer) -> Optional[int]:
                bview = memoryview(b)
                blen = len(bview)
                if blen == 0:
                    return 0

                # Keep feeding the decompressor until it yields output or
                # the input is exhausted.  A short read from ``input`` may
                # not contain a full block; returning 0 in that case would
                # be mistaken for EOF and silently truncate the stream.
                while True:
                    if self.decompressor.needs_input:
                        indata = self.input.read(read_size)
                        if indata is None:
                            # Underlying stream has no data available yet.
                            return None
                        if not indata:
                            # Real end of the compressed input.
                            return 0
                    else:
                        indata = b""

                    outdata = self.decompressor.decompress(indata, blen)
                    if outdata:
                        count = len(outdata)
                        bview[:count] = outdata
                        return count

            def seekable(self) -> bool:
                return False

        return cast(IO[bytes], Reader(input))
|
||||
|
||||
|
||||
class ExtensionRegistry:
|
||||
def __init__(self) -> None:
    """Create a registry pre-populated with the built-in extensions.

    ``self.extensions`` maps each extension's ``registry_name()`` to its
    class; ``ZStandard`` is registered by default.
    """
    # Populate default registry contents.  (A redundant empty-dict
    # assignment that was immediately overwritten has been removed.)
    self.extensions: dict[str, Type[Extension]] = {
        cls.registry_name(): cls for cls in (ZStandard,)
    }
|
||||
|
||||
def register(self, cls: Type[Extension]) -> None:
    """Store ``cls`` in the registry under the name from ``registry_name()``.

    An existing entry with the same name is replaced.
    """
    name = cls.registry_name()
    self.extensions[name] = cls
|
||||
|
@ -88,9 +88,7 @@ class Rot13Example(StreamTransformExtension):
|
||||
# It's possible self.input is an IO[bytes] with no readinto method.
|
||||
# In that case, we emulate with a read and copy. In practice,
|
||||
# all of the current concrete extensions have readinto.
|
||||
# 0 as a flags value is janky, but the flag values aren't available
|
||||
# in python until 3.12.
|
||||
view = b.__buffer__(0)
|
||||
view = memoryview(b)
|
||||
r = self.input.read(len(view))
|
||||
if r is None:
|
||||
count = None
|
||||
|
Reference in New Issue
Block a user