mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "[dcp] Add ZStandard transformer (#143360)"
This reverts commit 7b56b039afe2b4a4038c09d8b6cb1597823f3d5f. Reverted https://github.com/pytorch/pytorch/pull/143360 on behalf of https://github.com/atalman because it broke the 3.13t builds; please test with the ciflow/binaries label attached ([comment](https://github.com/pytorch/pytorch/pull/143360#issuecomment-2603433066)).
This commit is contained in:
@ -363,7 +363,6 @@ pwlf==2.2.1 ; python_version >= "3.8"
|
|||||||
astunparse
|
astunparse
|
||||||
PyYAML
|
PyYAML
|
||||||
setuptools
|
setuptools
|
||||||
zstandard
|
|
||||||
|
|
||||||
ninja==1.11.1 ; platform_machine == "aarch64"
|
ninja==1.11.1 ; platform_machine == "aarch64"
|
||||||
scons==4.5.2 ; platform_machine == "aarch64"
|
scons==4.5.2 ; platform_machine == "aarch64"
|
||||||
|
@ -19,4 +19,3 @@ setuptools
|
|||||||
sympy==1.13.3
|
sympy==1.13.3
|
||||||
types-dataclasses
|
types-dataclasses
|
||||||
typing-extensions>=4.10.0
|
typing-extensions>=4.10.0
|
||||||
zstandard
|
|
||||||
|
@ -8,7 +8,6 @@ from torch.distributed._tensor import (
|
|||||||
Shard,
|
Shard,
|
||||||
zeros,
|
zeros,
|
||||||
)
|
)
|
||||||
from torch.distributed.checkpoint._extension import ZStandard
|
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
parametrize,
|
parametrize,
|
||||||
@ -59,7 +58,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
|
|||||||
@with_comms
|
@with_comms
|
||||||
@skip_if_lt_x_gpu(2)
|
@skip_if_lt_x_gpu(2)
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
@parametrize("extensions", [None, [Rot13Example()]])
|
||||||
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
|
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
|
||||||
CHECKPOINT_DIR = self.temp_dir
|
CHECKPOINT_DIR = self.temp_dir
|
||||||
|
|
||||||
|
@ -22,7 +22,6 @@ from torch.distributed.checkpoint import (
|
|||||||
load_state_dict,
|
load_state_dict,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
)
|
)
|
||||||
from torch.distributed.checkpoint._extension import ZStandard
|
|
||||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
@ -166,7 +165,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
|||||||
@with_comms(init_rpc=False)
|
@with_comms(init_rpc=False)
|
||||||
@skip_if_lt_x_gpu(2)
|
@skip_if_lt_x_gpu(2)
|
||||||
@requires_nccl()
|
@requires_nccl()
|
||||||
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
@parametrize("extensions", [None, [Rot13Example()]])
|
||||||
def test_read_write_shard_tensor(self, extensions) -> None:
|
def test_read_write_shard_tensor(self, extensions) -> None:
|
||||||
paths = [tempfile.mkdtemp()]
|
paths = [tempfile.mkdtemp()]
|
||||||
dist.broadcast_object_list(paths)
|
dist.broadcast_object_list(paths)
|
||||||
|
@ -22,7 +22,6 @@ from torch.distributed.checkpoint import (
|
|||||||
save,
|
save,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
)
|
)
|
||||||
from torch.distributed.checkpoint._extension import ZStandard
|
|
||||||
from torch.distributed.checkpoint.stateful import Stateful
|
from torch.distributed.checkpoint.stateful import Stateful
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
@ -192,39 +191,6 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
|
|||||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||||
|
|
||||||
|
|
||||||
class TestDistributedStateDictSaveLoadZStandard(TestCase):
|
|
||||||
@parametrize("thread_count", _THREAD_COUNTS)
|
|
||||||
def test_read_write_only_tensor(self, thread_count) -> None:
|
|
||||||
with tempfile.TemporaryDirectory() as path:
|
|
||||||
state_dict_to_save = MyTestModule().state_dict()
|
|
||||||
state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
|
|
||||||
|
|
||||||
fs_writer = FileSystemWriter(
|
|
||||||
path=path,
|
|
||||||
thread_count=thread_count,
|
|
||||||
_extensions=[ZStandard()],
|
|
||||||
)
|
|
||||||
save(
|
|
||||||
state_dict=state_dict_to_save,
|
|
||||||
storage_writer=fs_writer,
|
|
||||||
)
|
|
||||||
|
|
||||||
state_dict_to_load_to = MyTestModule().state_dict()
|
|
||||||
state_dict_to_load_to["test_blob"] = BlobState(b"")
|
|
||||||
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
|
||||||
|
|
||||||
# Load from file without any resharding
|
|
||||||
fs_reader = FileSystemReader(path=path)
|
|
||||||
load(
|
|
||||||
state_dict=state_dict_to_load_to,
|
|
||||||
storage_reader=fs_reader,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||||
@property
|
@property
|
||||||
def world_size(self) -> int:
|
def world_size(self) -> int:
|
||||||
@ -559,7 +525,6 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
|||||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
|
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
|
||||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
|
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
|
||||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
|
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
|
||||||
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
|
|
||||||
instantiate_parametrized_tests(TestDistributedReshardOnLoad)
|
instantiate_parametrized_tests(TestDistributedReshardOnLoad)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib.util
|
|
||||||
import sys
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import IO, Type
|
from typing import IO, Type
|
||||||
|
|
||||||
@ -11,24 +9,7 @@ from typing import IO, Type
|
|||||||
# change. Feedback and bug fixes are always welcome.
|
# change. Feedback and bug fixes are always welcome.
|
||||||
|
|
||||||
|
|
||||||
zstandard_module_name = "zstandard"
|
__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
|
||||||
|
|
||||||
if (zstandard := sys.modules.get(zstandard_module_name, None)) is not None:
|
|
||||||
pass
|
|
||||||
elif (zstandard_spec := importlib.util.find_spec(zstandard_module_name)) is not None:
|
|
||||||
zstandard = importlib.util.module_from_spec(zstandard_spec)
|
|
||||||
sys.modules[zstandard_module_name] = zstandard
|
|
||||||
zstandard_spec.loader.exec_module(zstandard) # type: ignore[union-attr]
|
|
||||||
else:
|
|
||||||
zstandard = None
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Extension",
|
|
||||||
"StreamTransformExtension",
|
|
||||||
"ZStandard",
|
|
||||||
"ExtensionRegistry",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Extension(abc.ABC):
|
class Extension(abc.ABC):
|
||||||
@ -91,50 +72,10 @@ class StreamTransformExtension(Extension):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ZStandard(StreamTransformExtension):
|
|
||||||
@staticmethod
|
|
||||||
def is_available() -> bool:
|
|
||||||
return zstandard is not None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_descriptor(version: str) -> "ZStandard":
|
|
||||||
if version.partition(".")[0] != "1":
|
|
||||||
raise ValueError(f"Unknown extension {version=}")
|
|
||||||
if not ZStandard.is_available():
|
|
||||||
raise ValueError(
|
|
||||||
f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
|
|
||||||
)
|
|
||||||
return ZStandard()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def registry_name() -> str:
|
|
||||||
return "stream.zstd"
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
if not ZStandard.is_available():
|
|
||||||
raise ValueError(
|
|
||||||
f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_descriptor(self) -> str:
|
|
||||||
return f"{self.registry_name()}/1"
|
|
||||||
|
|
||||||
def transform_to(self, output: IO[bytes]) -> IO[bytes]:
|
|
||||||
compressor = zstandard.ZstdCompressor() # type: ignore[union-attr]
|
|
||||||
return compressor.stream_writer(output)
|
|
||||||
|
|
||||||
def transform_from(self, input: IO[bytes]) -> IO[bytes]:
|
|
||||||
decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr]
|
|
||||||
return decompressor.stream_reader(input)
|
|
||||||
|
|
||||||
|
|
||||||
class ExtensionRegistry:
|
class ExtensionRegistry:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
# Populate default registry contents
|
# Populate default registry contents
|
||||||
self.extensions: dict[str, Type[Extension]] = {
|
self.extensions: dict[str, Type[Extension]] = {}
|
||||||
cls.registry_name(): cls for cls in (ZStandard,)
|
|
||||||
}
|
|
||||||
|
|
||||||
def register(self, cls: Type[Extension]) -> None:
|
def register(self, cls: Type[Extension]) -> None:
|
||||||
self.extensions[cls.registry_name()] = cls
|
self.extensions[cls.registry_name()] = cls
|
||||||
|
Reference in New Issue
Block a user