diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index a936b2b043fd..ecce41f8ea73 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -363,7 +363,6 @@ pwlf==2.2.1 ; python_version >= "3.8"
 astunparse
 PyYAML
 setuptools
-zstandard
 
 ninja==1.11.1 ; platform_machine == "aarch64"
 scons==4.5.2 ; platform_machine == "aarch64"
diff --git a/requirements.txt b/requirements.txt
index 64d5563c2d37..03731d2e0fc5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,4 +19,3 @@ setuptools
 sympy==1.13.3
 types-dataclasses
 typing-extensions>=4.10.0
-zstandard
diff --git a/test/distributed/checkpoint/test_dtensor_resharding.py b/test/distributed/checkpoint/test_dtensor_resharding.py
index b99e6592c5cc..f4e982c3c46f 100644
--- a/test/distributed/checkpoint/test_dtensor_resharding.py
+++ b/test/distributed/checkpoint/test_dtensor_resharding.py
@@ -8,7 +8,6 @@ from torch.distributed._tensor import (
     Shard,
     zeros,
 )
-from torch.distributed.checkpoint._extension import ZStandard
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -59,7 +58,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase):
     @with_comms
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
-    @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
+    @parametrize("extensions", [None, [Rot13Example()]])
     def test_1d_to_1d_reshard_placement_change(self, extensions) -> None:
         CHECKPOINT_DIR = self.temp_dir
 
diff --git a/test/distributed/checkpoint/test_file_system_checkpoint.py b/test/distributed/checkpoint/test_file_system_checkpoint.py
index c7c6e88b1684..dbfcef0c2f34 100644
--- a/test/distributed/checkpoint/test_file_system_checkpoint.py
+++ b/test/distributed/checkpoint/test_file_system_checkpoint.py
@@ -22,7 +22,6 @@ from torch.distributed.checkpoint import (
     load_state_dict,
     save_state_dict,
 )
-from torch.distributed.checkpoint._extension import ZStandard
 from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -166,7 +165,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(2)
     @requires_nccl()
-    @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
+    @parametrize("extensions", [None, [Rot13Example()]])
     def test_read_write_shard_tensor(self, extensions) -> None:
         paths = [tempfile.mkdtemp()]
         dist.broadcast_object_list(paths)
diff --git a/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py b/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py
index f2e1483ce3d3..a398a55cdb62 100644
--- a/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py
+++ b/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py
@@ -22,7 +22,6 @@ from torch.distributed.checkpoint import (
     save,
     save_state_dict,
 )
-from torch.distributed.checkpoint._extension import ZStandard
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -192,39 +191,6 @@ class TestDistributedStateDictSaveLoadRot13(TestCase):
             assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
 
 
-class TestDistributedStateDictSaveLoadZStandard(TestCase):
-    @parametrize("thread_count", _THREAD_COUNTS)
-    def test_read_write_only_tensor(self, thread_count) -> None:
-        with tempfile.TemporaryDirectory() as path:
-            state_dict_to_save = MyTestModule().state_dict()
-            state_dict_to_save["test_blob"] = BlobState(b"SomeBlobForTesting")
-
-            fs_writer = FileSystemWriter(
-                path=path,
-                thread_count=thread_count,
-                _extensions=[ZStandard()],
-            )
-            save(
-                state_dict=state_dict_to_save,
-                storage_writer=fs_writer,
-            )
-
-            state_dict_to_load_to = MyTestModule().state_dict()
-            state_dict_to_load_to["test_blob"] = BlobState(b"")
-
-            with self.assertRaises(AssertionError):
-                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
-
-            # Load from file without any resharding
-            fs_reader = FileSystemReader(path=path)
-            load(
-                state_dict=state_dict_to_load_to,
-                storage_reader=fs_reader,
-            )
-
-            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
-
-
 class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
     @property
     def world_size(self) -> int:
@@ -559,7 +525,6 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoad)
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13)
 instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor)
-instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard)
 instantiate_parametrized_tests(TestDistributedReshardOnLoad)
 
 if __name__ == "__main__":
diff --git a/torch/distributed/checkpoint/_extension.py b/torch/distributed/checkpoint/_extension.py
index f03d805c28ce..0cadfbc8c529 100644
--- a/torch/distributed/checkpoint/_extension.py
+++ b/torch/distributed/checkpoint/_extension.py
@@ -1,8 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
 import abc
-import importlib.util
-import sys
 from collections.abc import Sequence
 from typing import IO, Type
 
@@ -11,24 +9,7 @@ from typing import IO, Type
 # change. Feedback and bug fixes are always welcome.
 
 
-zstandard_module_name = "zstandard"
-
-if (zstandard := sys.modules.get(zstandard_module_name, None)) is not None:
-    pass
-elif (zstandard_spec := importlib.util.find_spec(zstandard_module_name)) is not None:
-    zstandard = importlib.util.module_from_spec(zstandard_spec)
-    sys.modules[zstandard_module_name] = zstandard
-    zstandard_spec.loader.exec_module(zstandard)  # type: ignore[union-attr]
-else:
-    zstandard = None
-
-
-__all__ = [
-    "Extension",
-    "StreamTransformExtension",
-    "ZStandard",
-    "ExtensionRegistry",
-]
+__all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"]
 
 
 class Extension(abc.ABC):
@@ -91,50 +72,10 @@ class StreamTransformExtension(Extension):
         """
 
 
-class ZStandard(StreamTransformExtension):
-    @staticmethod
-    def is_available() -> bool:
-        return zstandard is not None
-
-    @staticmethod
-    def from_descriptor(version: str) -> "ZStandard":
-        if version.partition(".")[0] != "1":
-            raise ValueError(f"Unknown extension {version=}")
-        if not ZStandard.is_available():
-            raise ValueError(
-                f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
-            )
-        return ZStandard()
-
-    @staticmethod
-    def registry_name() -> str:
-        return "stream.zstd"
-
-    def __init__(self) -> None:
-        super().__init__()
-        if not ZStandard.is_available():
-            raise ValueError(
-                f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
-            )
-
-    def get_descriptor(self) -> str:
-        return f"{self.registry_name()}/1"
-
-    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
-        compressor = zstandard.ZstdCompressor()  # type: ignore[union-attr]
-        return compressor.stream_writer(output)
-
-    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
-        decompressor = zstandard.ZstdDecompressor()  # type: ignore[union-attr]
-        return decompressor.stream_reader(input)
-
-
 class ExtensionRegistry:
     def __init__(self) -> None:
         # Populate default registry contents
-        self.extensions: dict[str, Type[Extension]] = {
-            cls.registry_name(): cls for cls in (ZStandard,)
-        }
+        self.extensions: dict[str, Type[Extension]] = {}
 
     def register(self, cls: Type[Extension]) -> None:
         self.extensions[cls.registry_name()] = cls
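Note for downstream users: with the built-in `ZStandard` extension removed, zstd compression of checkpoint streams can still be implemented out of tree against the `StreamTransformExtension` and `ExtensionRegistry` hooks that remain in `torch.distributed.checkpoint._extension`. The sketch below is hypothetical (the class name `MyZStandard` is made up, and it simply mirrors the deleted class); it assumes the third-party `zstandard` package is installed in the user environment, and it only shows the write path that the removed test exercised via `FileSystemWriter(_extensions=...)`. How a custom registry is handed to the read path is not shown in this diff.

```python
# Hypothetical out-of-tree replacement for the removed ZStandard extension.
# Assumes `pip install zstandard` in the user environment; none of this ships
# with PyTorch after this change.
from typing import IO

import zstandard  # third-party package, no longer a PyTorch dependency

from torch.distributed.checkpoint._extension import (
    ExtensionRegistry,
    StreamTransformExtension,
)


class MyZStandard(StreamTransformExtension):
    @staticmethod
    def is_available() -> bool:
        # The module is imported unconditionally above, so it is available here.
        return True

    @staticmethod
    def from_descriptor(version: str) -> "MyZStandard":
        # Recreate the extension from the descriptor saved with the checkpoint.
        if version.partition(".")[0] != "1":
            raise ValueError(f"Unknown extension {version=}")
        return MyZStandard()

    @staticmethod
    def registry_name() -> str:
        return "stream.zstd"

    def get_descriptor(self) -> str:
        return f"{self.registry_name()}/1"

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
        # Wrap the outgoing checkpoint stream in a compressing writer.
        return zstandard.ZstdCompressor().stream_writer(output)

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
        # Wrap the incoming checkpoint stream in a decompressing reader.
        return zstandard.ZstdDecompressor().stream_reader(input)


# Make the class discoverable from its saved descriptor, e.g. "stream.zstd/1".
registry = ExtensionRegistry()
registry.register(MyZStandard)

# Write path, mirroring the deleted test (run inside an initialized job):
#   fs_writer = FileSystemWriter(path=path, _extensions=[MyZStandard()])
#   save(state_dict=state_dict_to_save, storage_writer=fs_writer)
```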