mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "[dcp] Add ZStandard transformer (#143360)"
This reverts commit 7b56b039afe2b4a4038c09d8b6cb1597823f3d5f. Reverted https://github.com/pytorch/pytorch/pull/143360 on behalf of https://github.com/atalman due to Broke 3.13t builds please test with ciflow/binaries label attached ([comment](https://github.com/pytorch/pytorch/pull/143360#issuecomment-2603433066))
This commit is contained in:
@ -363,7 +363,6 @@ pwlf==2.2.1 ; python_version >= "3.8"
|
||||
astunparse
|
||||
PyYAML
|
||||
setuptools
|
||||
zstandard
|
||||
|
||||
ninja==1.11.1 ; platform_machine == "aarch64"
|
||||
scons==4.5.2 ; platform_machine == "aarch64"
|
||||
|
@ -19,4 +19,3 @@ setuptools
|
||||
sympy==1.13.3
|
||||
types-dataclasses
|
||||
typing-extensions>=4.10.0
|
||||
zstandard
|
||||
|
@ -8,7 +8,6 @@ from torch.distributed._tensor import (
|
||||
Shard,
|
||||
zeros,
|
||||
)
|
||||
from torch.distributed.checkpoint._extension import ZStandard
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -59,7 +58,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_temp_dir
|
||||
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
||||
@parametrize("extensions", [None, [Rot13Example()]])
|
||||
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
|
||||
|
@ -22,7 +22,6 @@ from torch.distributed.checkpoint import (
|
||||
load_state_dict,
|
||||
save_state_dict,
|
||||
)
|
||||
from torch.distributed.checkpoint._extension import ZStandard
|
||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
@ -166,7 +165,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl()
|
||||
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
||||
@parametrize("extensions", [None, [Rot13Example()]])
|
||||
def test_read_write_shard_tensor(self, extensions) -> None:
|
||||
paths = [tempfile.mkdtemp()]
|
||||
dist.broadcast_object_list(paths)
|
||||
|
@ -22,7 +22,6 @@ from torch.distributed.checkpoint import (
|
||||
save,
|
||||
save_state_dict,
|
||||
)
|
||||
from torch.distributed.checkpoint._extension import ZStandard
|
||||
from torch.distributed.checkpoint.stateful import Stateful
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
@ -192,39 +191,6 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
|
||||
class TestDistributedStateDictSaveLoadZStandard(TestCase):
|
||||
@parametrize("thread_count", _THREAD_COUNTS)
|
||||
def test_read_write_only_tensor(self, thread_count) -> None:
|
||||
with tempfile.TemporaryDirectory() as path:
|
||||
state_dict_to_save = MyTestModule().state_dict()
|
||||
state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
|
||||
|
||||
fs_writer = FileSystemWriter(
|
||||
path=path,
|
||||
thread_count=thread_count,
|
||||
_extensions=[ZStandard()],
|
||||
)
|
||||
save(
|
||||
state_dict=state_dict_to_save,
|
||||
storage_writer=fs_writer,
|
||||
)
|
||||
|
||||
state_dict_to_load_to = MyTestModule().state_dict()
|
||||
state_dict_to_load_to["test_blob"] = BlobState(b"")
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
# Load from file without any resharding
|
||||
fs_reader = FileSystemReader(path=path)
|
||||
load(
|
||||
state_dict=state_dict_to_load_to,
|
||||
storage_reader=fs_reader,
|
||||
)
|
||||
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
|
||||
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
@ -559,7 +525,6 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
|
||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
|
||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
|
||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
|
||||
instantiate_parametrized_tests(TestDistributedReshardOnLoad)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,8 +1,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
import abc
|
||||
import importlib.util
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import IO, Type
|
||||
|
||||
@ -11,24 +9,7 @@ from typing import IO, Type
|
||||
# change. Feedback and bug fixes are always welcome.
|
||||
|
||||
|
||||
zstandard_module_name = "zstandard"
|
||||
|
||||
if (zstandard := sys.modules.get(zstandard_module_name, None)) is not None:
|
||||
pass
|
||||
elif (zstandard_spec := importlib.util.find_spec(zstandard_module_name)) is not None:
|
||||
zstandard = importlib.util.module_from_spec(zstandard_spec)
|
||||
sys.modules[zstandard_module_name] = zstandard
|
||||
zstandard_spec.loader.exec_module(zstandard) # type: ignore[union-attr]
|
||||
else:
|
||||
zstandard = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Extension",
|
||||
"StreamTransformExtension",
|
||||
"ZStandard",
|
||||
"ExtensionRegistry",
|
||||
]
|
||||
__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
|
||||
|
||||
|
||||
class Extension(abc.ABC):
|
||||
@ -91,50 +72,10 @@ class StreamTransformExtension(Extension):
|
||||
"""
|
||||
|
||||
|
||||
class ZStandard(StreamTransformExtension):
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
return zstandard is not None
|
||||
|
||||
@staticmethod
|
||||
def from_descriptor(version: str) -> "ZStandard":
|
||||
if version.partition(".")[0] != "1":
|
||||
raise ValueError(f"Unknown extension {version=}")
|
||||
if not ZStandard.is_available():
|
||||
raise ValueError(
|
||||
f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
|
||||
)
|
||||
return ZStandard()
|
||||
|
||||
@staticmethod
|
||||
def registry_name() -> str:
|
||||
return "stream.zstd"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
if not ZStandard.is_available():
|
||||
raise ValueError(
|
||||
f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
|
||||
)
|
||||
|
||||
def get_descriptor(self) -> str:
|
||||
return f"{self.registry_name()}/1"
|
||||
|
||||
def transform_to(self, output: IO[bytes]) -> IO[bytes]:
|
||||
compressor = zstandard.ZstdCompressor() # type: ignore[union-attr]
|
||||
return compressor.stream_writer(output)
|
||||
|
||||
def transform_from(self, input: IO[bytes]) -> IO[bytes]:
|
||||
decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr]
|
||||
return decompressor.stream_reader(input)
|
||||
|
||||
|
||||
class ExtensionRegistry:
|
||||
def __init__(self) -> None:
|
||||
# Populate default registry contents
|
||||
self.extensions: dict[str, Type[Extension]] = {
|
||||
cls.registry_name(): cls for cls in (ZStandard,)
|
||||
}
|
||||
self.extensions: dict[str, Type[Extension]] = {}
|
||||
|
||||
def register(self, cls: Type[Extension]) -> None:
|
||||
self.extensions[cls.registry_name()] = cls
|
||||
|
Reference in New Issue
Block a user