mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[dcp] Add ZStandard transformer (#143360)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143360
Approved by: https://github.com/saumishr
ghstack dependencies: #143358, #143359
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							9c909bf3bb
						
					
				
				
					commit
					7b56b039af
				
			| @ -363,6 +363,7 @@ pwlf==2.2.1 ; python_version >= "3.8" | ||||
| astunparse | ||||
| PyYAML | ||||
| setuptools | ||||
| zstandard | ||||
|  | ||||
| ninja==1.11.1 ; platform_machine == "aarch64" | ||||
| scons==4.5.2 ; platform_machine == "aarch64" | ||||
|  | ||||
| @ -19,3 +19,4 @@ setuptools | ||||
| sympy==1.13.3 | ||||
| types-dataclasses | ||||
| typing-extensions>=4.10.0 | ||||
| zstandard | ||||
|  | ||||
| @ -8,6 +8,7 @@ from torch.distributed._tensor import ( | ||||
|     Shard, | ||||
|     zeros, | ||||
| ) | ||||
| from torch.distributed.checkpoint._extension import ZStandard | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
|     parametrize, | ||||
| @ -58,7 +59,7 @@ class TestDTensorReshardPlacementChange(DTensorTestBase): | ||||
|     @with_comms | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @with_temp_dir | ||||
|     @parametrize("extensions", [None, [Rot13Example()]]) | ||||
|     @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]]) | ||||
|     def test_1d_to_1d_reshard_placement_change(self, extensions) -> None: | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|  | ||||
| @ -22,6 +22,7 @@ from torch.distributed.checkpoint import ( | ||||
|     load_state_dict, | ||||
|     save_state_dict, | ||||
| ) | ||||
| from torch.distributed.checkpoint._extension import ZStandard | ||||
| from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
| @ -165,7 +166,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase): | ||||
|     @with_comms(init_rpc=False) | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @requires_nccl() | ||||
|     @parametrize("extensions", [None, [Rot13Example()]]) | ||||
|     @parametrize("extensions", [None, [Rot13Example()], [ZStandard()]]) | ||||
|     def test_read_write_shard_tensor(self, extensions) -> None: | ||||
|         paths = [tempfile.mkdtemp()] | ||||
|         dist.broadcast_object_list(paths) | ||||
|  | ||||
| @ -22,6 +22,7 @@ from torch.distributed.checkpoint import ( | ||||
|     save, | ||||
|     save_state_dict, | ||||
| ) | ||||
| from torch.distributed.checkpoint._extension import ZStandard | ||||
| from torch.distributed.checkpoint.stateful import Stateful | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
| @ -191,6 +192,39 @@ class TestDistributedStateDictSaveLoadRot13(TestCase): | ||||
|             assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) | ||||
|  | ||||
|  | ||||
class TestDistributedStateDictSaveLoadZStandard(TestCase):
    """Save/load round trip through the filesystem with the ``ZStandard`` extension."""

    @parametrize("thread_count", _THREAD_COUNTS)
    def test_read_write_only_tensor(self, thread_count) -> None:
        """Save a state dict with ``ZStandard`` enabled and load it back unchanged.

        The writer is configured with ``_extensions=[ZStandard()]``; the reader
        is built with only a path and no extension list.
        """
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            expected = MyTestModule().state_dict()
            expected["test_blob"] = BlobState(b"SomeBlobForTesting")

            writer = FileSystemWriter(
                path=checkpoint_dir,
                thread_count=thread_count,
                _extensions=[ZStandard()],
            )
            save(state_dict=expected, storage_writer=writer)

            loaded = MyTestModule().state_dict()
            loaded["test_blob"] = BlobState(b"")

            # Before loading, the freshly built target must differ from what
            # was saved; otherwise the final equality check would be vacuous.
            with self.assertRaises(AssertionError):
                assert_state_dict_equal(self, loaded, expected)

            # Load from file without any resharding
            load(
                state_dict=loaded,
                storage_reader=FileSystemReader(path=checkpoint_dir),
            )

            assert_state_dict_equal(self, loaded, expected)
|  | ||||
|  | ||||
| class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
| @ -525,6 +559,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase): | ||||
| instantiate_parametrized_tests(TestDistributedStateDictSaveLoad) | ||||
| instantiate_parametrized_tests(TestDistributedStateDictSaveLoadRot13) | ||||
| instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor) | ||||
| instantiate_parametrized_tests(TestDistributedStateDictSaveLoadZStandard) | ||||
| instantiate_parametrized_tests(TestDistributedReshardOnLoad) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates | ||||
|  | ||||
| import abc | ||||
| import importlib.util | ||||
| import sys | ||||
| from collections.abc import Sequence | ||||
| from typing import IO, Type | ||||
|  | ||||
| @ -9,7 +11,24 @@ from typing import IO, Type | ||||
| # change.  Feedback and bug fixes are always welcome. | ||||
|  | ||||
|  | ||||
| __all__ = ["Extension", "StreamTransformExtension", "ExtensionRegistry"] | ||||
zstandard_module_name = "zstandard"

# ``zstandard`` is an optional dependency: bind the module when it can be
# imported and ``None`` otherwise, so availability can be tested at runtime.
# A copy that is already loaded is reused as-is.  Otherwise the regular import
# machinery is used, which -- unlike hand-driving ``module_from_spec`` plus
# ``exec_module`` -- removes a half-initialised module from ``sys.modules``
# again if executing it fails.
zstandard = sys.modules.get(zstandard_module_name)
if zstandard is None:
    try:
        zstandard = importlib.import_module(zstandard_module_name)
    except ModuleNotFoundError as exc:
        # Only "the zstandard package itself is missing" counts as
        # "unavailable"; a failure to import one of its own dependencies is a
        # broken installation and is re-raised rather than hidden.
        if exc.name != zstandard_module_name:
            raise
|  | ||||
|  | ||||
| __all__ = [ | ||||
|     "Extension", | ||||
|     "StreamTransformExtension", | ||||
|     "ZStandard", | ||||
|     "ExtensionRegistry", | ||||
| ] | ||||
|  | ||||
|  | ||||
| class Extension(abc.ABC): | ||||
| @ -72,10 +91,50 @@ class StreamTransformExtension(Extension): | ||||
|         """ | ||||
|  | ||||
|  | ||||
class ZStandard(StreamTransformExtension):
    """Stream transform extension backed by the optional ``zstandard`` module.

    Instances can only be created when the module-level ``zstandard`` binding
    is not ``None``; otherwise construction raises ``ValueError``.
    """

    def __init__(self) -> None:
        """Create the extension.

        Raises:
            ValueError: if the ``zstandard`` module is not available.
        """
        super().__init__()
        if ZStandard.is_available():
            return
        raise ValueError(
            f"ZStandard extension is unavailable because no module named '{zstandard_module_name}'"
        )

    @staticmethod
    def registry_name() -> str:
        """Return the name this extension is registered under."""
        return "stream.zstd"

    @staticmethod
    def is_available() -> bool:
        """Return ``True`` when the ``zstandard`` module was found at import time."""
        module_missing = zstandard is None
        return not module_missing

    @staticmethod
    def from_descriptor(version: str) -> "ZStandard":
        """Build an instance for the given descriptor version.

        Args:
            version: version string; only the part before the first ``"."``
                is inspected and it must be ``"1"``.

        Raises:
            ValueError: if the major version is not ``"1"``, or if the
                ``zstandard`` module is not available.
        """
        major = version.partition(".")[0]
        if major != "1":
            raise ValueError(f"Unknown extension {version=}")
        if ZStandard.is_available():
            return ZStandard()
        raise ValueError(
            f"Stream with ZStandard compression cannot be processed because no module named '{zstandard_module_name}'"
        )

    def get_descriptor(self) -> str:
        """Return the descriptor string: the registry name followed by ``/1``."""
        name = self.registry_name()
        return f"{name}/1"

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
        """Return ``ZstdCompressor().stream_writer(output)``.

        Presumably data written to the returned stream reaches ``output``
        compressed -- confirm against the zstandard documentation.
        """
        return zstandard.ZstdCompressor().stream_writer(output)  # type: ignore[union-attr]

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
        """Return ``ZstdDecompressor().stream_reader(input)``.

        Presumably reads from the returned stream yield the decompressed
        content of ``input`` -- confirm against the zstandard documentation.
        """
        return zstandard.ZstdDecompressor().stream_reader(input)  # type: ignore[union-attr]
|  | ||||
|  | ||||
| class ExtensionRegistry: | ||||
|     def __init__(self) -> None: | ||||
|         # Populate default registry contents | ||||
|         self.extensions: dict[str, Type[Extension]] = {} | ||||
|         self.extensions: dict[str, Type[Extension]] = { | ||||
|             cls.registry_name(): cls for cls in (ZStandard,) | ||||
|         } | ||||
|  | ||||
|     def register(self, cls: Type[Extension]) -> None: | ||||
|         self.extensions[cls.registry_name()] = cls | ||||
|  | ||||
		Reference in New Issue
	
	Block a user