Revert "[dcp] Add ZStandard transformer (#143360)"

This reverts commit 7b56b039afe2b4a4038c09d8b6cb1597823f3d5f.

Reverted https://github.com/pytorch/pytorch/pull/143360 on behalf of https://github.com/atalman due to Broke 3.13t builds please test with ciflow/binaries label attached ([comment](https://github.com/pytorch/pytorch/pull/143360#issuecomment-2603433066))
This commit is contained in:
PyTorch MergeBot
2025-01-21 01:10:16 +00:00
parent 5fd881a5b6
commit c6986ca2e1
6 changed files with 4 additions and 102 deletions

View File

@ -363,7 +363,6 @@ pwlf==2.2.1 ; python_version >= "3.8"
astunparse astunparse
PyYAML PyYAML
setuptools setuptools
zstandard
ninja==1.11.1 ; platform_machine == "aarch64" ninja==1.11.1 ; platform_machine == "aarch64"
scons==4.5.2 ; platform_machine == "aarch64" scons==4.5.2 ; platform_machine == "aarch64"

View File

@ -19,4 +19,3 @@ setuptools
sympy==1.13.3 sympy==1.13.3
types-dataclasses types-dataclasses
typing-extensions>=4.10.0 typing-extensions>=4.10.0
zstandard

View File

@ -8,7 +8,6 @@ from torch.distributed._tensor import (
Shard, Shard,
zeros, zeros,
) )
from torch.distributed.checkpoint._extension import ZStandard
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
parametrize, parametrize,
@ -59,7 +58,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
@with_comms @with_comms
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
@with_temp_dir @with_temp_dir
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]]) @parametrize("extensions", [None, [Rot13Example()]])
def test_1d_to_1d_reshard_placement_change(self, extensions) -> None: def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
CHECKPOINT_DIR = self.temp_dir CHECKPOINT_DIR = self.temp_dir

View File

@ -22,7 +22,6 @@ from torch.distributed.checkpoint import (
load_state_dict, load_state_dict,
save_state_dict, save_state_dict,
) )
from torch.distributed.checkpoint._extension import ZStandard
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
@ -166,7 +165,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False) @with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
@requires_nccl() @requires_nccl()
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]]) @parametrize("extensions", [None, [Rot13Example()]])
def test_read_write_shard_tensor(self, extensions) -> None: def test_read_write_shard_tensor(self, extensions) -> None:
paths = [tempfile.mkdtemp()] paths = [tempfile.mkdtemp()]
dist.broadcast_object_list(paths) dist.broadcast_object_list(paths)

View File

@ -22,7 +22,6 @@ from torch.distributed.checkpoint import (
save, save,
save_state_dict, save_state_dict,
) )
from torch.distributed.checkpoint._extension import ZStandard
from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.stateful import Stateful
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
instantiate_parametrized_tests, instantiate_parametrized_tests,
@ -192,39 +191,6 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadZStandard(TestCase):
@parametrize("thread_count", _THREAD_COUNTS)
def test_read_write_only_tensor(self, thread_count) -> None:
with tempfile.TemporaryDirectory() as path:
state_dict_to_save = MyTestModule().state_dict()
state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
fs_writer = FileSystemWriter(
path=path,
thread_count=thread_count,
_extensions=[ZStandard()],
)
save(
state_dict=state_dict_to_save,
storage_writer=fs_writer,
)
state_dict_to_load_to = MyTestModule().state_dict()
state_dict_to_load_to["test_blob"] = BlobState(b"")
with self.assertRaises(AssertionError):
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Load from file without any resharding
fs_reader = FileSystemReader(path=path)
load(
state_dict=state_dict_to_load_to,
storage_reader=fs_reader,
)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase): class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@property @property
def world_size(self) -> int: def world_size(self) -> int:
@ -559,7 +525,6 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
instantiate_parametrized_tests(TestDistributedStateDictSaveLoad) instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13) instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor) instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
instantiate_parametrized_tests(TestDistributedReshardOnLoad) instantiate_parametrized_tests(TestDistributedReshardOnLoad)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,8 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates # Copyright (c) Meta Platforms, Inc. and affiliates
import abc import abc
import importlib.util
import sys
from collections.abc import Sequence from collections.abc import Sequence
from typing import IO, Type from typing import IO, Type
@ -11,24 +9,7 @@ from typing import IO, Type
# change. Feedback and bug fixes are always welcome. # change. Feedback and bug fixes are always welcome.
zstandard_module_name = "zstandard" __all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
if (zstandard := sys.modules.get(zstandard_module_name, None)) is not None:
pass
elif (zstandard_spec := importlib.util.find_spec(zstandard_module_name)) is not None:
zstandard = importlib.util.module_from_spec(zstandard_spec)
sys.modules[zstandard_module_name] = zstandard
zstandard_spec.loader.exec_module(zstandard) # type: ignore[union-attr]
else:
zstandard = None
__all__ = [
"Extension",
"StreamTransformExtension",
"ZStandard",
"ExtensionRegistry",
]
class Extension(abc.ABC): class Extension(abc.ABC):
@ -91,50 +72,10 @@ class StreamTransformExtension(Extension):
""" """
class ZStandard(StreamTransformExtension):
@staticmethod
def is_available() -> bool:
return zstandard is not None
@staticmethod
def from_descriptor(version: str) -> "ZStandard":
if version.partition(".")[0] != "1":
raise ValueError(f"Unknown extension {version=}")
if not ZStandard.is_available():
raise ValueError(
f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
)
return ZStandard()
@staticmethod
def registry_name() -> str:
return "stream.zstd"
def __init__(self) -> None:
super().__init__()
if not ZStandard.is_available():
raise ValueError(
f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
)
def get_descriptor(self) -> str:
return f"{self.registry_name()}/1"
def transform_to(self, output: IO[bytes]) -> IO[bytes]:
compressor = zstandard.ZstdCompressor() # type: ignore[union-attr]
return compressor.stream_writer(output)
def transform_from(self, input: IO[bytes]) -> IO[bytes]:
decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr]
return decompressor.stream_reader(input)
class ExtensionRegistry: class ExtensionRegistry:
def __init__(self) -> None: def __init__(self) -> None:
# Populate default registry contents # Populate default registry contents
self.extensions: dict[str, Type[Extension]] = { self.extensions: dict[str, Type[Extension]] = {}
cls.registry_name(): cls for cls in (ZStandard,)
}
def register(self, cls: Type[Extension]) -> None: def register(self, cls: Type[Extension]) -> None:
self.extensions[cls.registry_name()] = cls self.extensions[cls.registry_name()] = cls