[dcp] Add ZStandard transformer (#143360)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143360
Approved by: https://github.com/saumishr
ghstack dependencies: #143358, #143359
This commit is contained in:
Marc Horowitz
2025-01-15 18:54:44 -08:00
committed by PyTorch MergeBot
parent 9c909bf3bb
commit 7b56b039af
6 changed files with 102 additions and 4 deletions

View File

@ -363,6 +363,7 @@ pwlf==2.2.1 ; python_version >= "3.8"
astunparse
PyYAML
setuptools
zstandard
ninja==1.11.1 ; platform_machine == "aarch64"
scons==4.5.2 ; platform_machine == "aarch64"

View File

@ -19,3 +19,4 @@ setuptools
sympy==1.13.3
types-dataclasses
typing-extensions>=4.10.0
zstandard

View File

@ -8,6 +8,7 @@ from torch.distributed._tensor import (
Shard,
zeros,
)
from torch.distributed.checkpoint._extension import ZStandard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -58,7 +59,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
@with_temp_dir
@parametrize("extensions", [None, [Rot13Example()]])
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
CHECKPOINT_DIR = self.temp_dir

View File

@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
load_state_dict,
save_state_dict,
)
from torch.distributed.checkpoint._extension import ZStandard
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -165,7 +166,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
@parametrize("extensions", [None, [Rot13Example()]])
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
def test_read_write_shard_tensor(self, extensions) -> None:
paths = [tempfile.mkdtemp()]
dist.broadcast_object_list(paths)

View File

@ -22,6 +22,7 @@ from torch.distributed.checkpoint import (
save,
save_state_dict,
)
from torch.distributed.checkpoint._extension import ZStandard
from torch.distributed.checkpoint.stateful import Stateful
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -191,6 +192,39 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadZStandard(TestCase):
@parametrize("thread_count", _THREAD_COUNTS)
def test_read_write_only_tensor(self, thread_count) -> None:
with tempfile.TemporaryDirectory() as path:
state_dict_to_save = MyTestModule().state_dict()
state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
fs_writer = FileSystemWriter(
path=path,
thread_count=thread_count,
_extensions=[ZStandard()],
)
save(
state_dict=state_dict_to_save,
storage_writer=fs_writer,
)
state_dict_to_load_to = MyTestModule().state_dict()
state_dict_to_load_to["test_blob"] = BlobState(b"")
with self.assertRaises(AssertionError):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Load from file without any resharding
fs_reader = FileSystemReader(path=path)
load(
state_dict=state_dict_to_load_to,
storage_reader=fs_reader,
)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@property
def world_size(self) -> int:
@ -525,6 +559,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
instantiate_parametrized_tests(TestDistributedReshardOnLoad)
if __name__ == "__main__":

View File

@ -1,6 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import abc
import importlib.util
import sys
from collections.abc import Sequence
from typing import IO, Type
@ -9,7 +11,24 @@ from typing import IO, Type
# change. Feedback and bug fixes are always welcome.
__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
zstandard_module_name = "zstandard"
if (zstandard := sys.modules.get(zstandard_module_name, None)) is not None:
pass
elif (zstandard_spec := importlib.util.find_spec(zstandard_module_name)) is not None:
zstandard = importlib.util.module_from_spec(zstandard_spec)
sys.modules[zstandard_module_name] = zstandard
zstandard_spec.loader.exec_module(zstandard) # type: ignore[union-attr]
else:
zstandard = None
__all__ = [
"Extension",
"StreamTransformExtension",
"ZStandard",
"ExtensionRegistry",
]
class Extension(abc.ABC):
@ -72,10 +91,50 @@ class StreamTransformExtension(Extension):
"""
class ZStandard(StreamTransformExtension):
@staticmethod
def is_available() -> bool:
return zstandard is not None
@staticmethod
def from_descriptor(version: str) -> "ZStandard":
if version.partition(".")[0] != "1":
raise ValueError(f"Unknown extension {version=}")
if not ZStandard.is_available():
raise ValueError(
f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
)
return ZStandard()
@staticmethod
def registry_name() -> str:
return "stream.zstd"
def __init__(self) -> None:
super().__init__()
if not ZStandard.is_available():
raise ValueError(
f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
)
def get_descriptor(self) -> str:
return f"{self.registry_name()}/1"
def transform_to(self, output: IO[bytes]) -> IO[bytes]:
compressor = zstandard.ZstdCompressor() # type: ignore[union-attr]
return compressor.stream_writer(output)
def transform_from(self, input: IO[bytes]) -> IO[bytes]:
decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr]
return decompressor.stream_reader(input)
class ExtensionRegistry:
def __init__(self) -> None:
# Populate default registry contents
self.extensions: dict[str, Type[Extension]] = {}
self.extensions: dict[str, Type[Extension]] = {
cls.registry_name(): cls for cls in (ZStandard,)
}
def register(self, cls: Type[Extension]) -> None:
self.extensions[cls.registry_name()] = cls