mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145200 Approved by: https://github.com/bobrenjc93
162 lines
5.1 KiB
Python
162 lines
5.1 KiB
Python
# mypy: allow-untyped-defs
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
|
|
import io
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
from functools import wraps
|
|
from typing import Any, Callable, cast, IO, Optional
|
|
|
|
# introduced as collections.abc.Buffer in Python 3.12
|
|
from typing_extensions import Buffer
|
|
|
|
import torch.distributed as dist
|
|
from torch.distributed.checkpoint._extension import (
|
|
ExtensionRegistry,
|
|
StreamTransformExtension,
|
|
)
|
|
|
|
|
|
class Rot13Example(StreamTransformExtension):
|
|
"""
|
|
This is an example stream transform extension which just does rot13 on each
|
|
alphanumeric character of the stream. It is mainly intended as a demonstration
|
|
and for testing; there isn't a production use case for this.
|
|
"""
|
|
|
|
def __init__(self, chunk_size: int = io.DEFAULT_BUFFER_SIZE) -> None:
|
|
super().__init__()
|
|
self._chunk_size = chunk_size
|
|
|
|
@staticmethod
|
|
def from_descriptor(version: str) -> "Rot13Example":
|
|
if version.partition(".")[0] != "1":
|
|
raise ValueError(f"Unknown extension {version=}")
|
|
return Rot13Example()
|
|
|
|
@staticmethod
|
|
def registry_name() -> str:
|
|
return "stream.rot13"
|
|
|
|
def get_descriptor(self) -> str:
|
|
return f"{self.registry_name()}/1"
|
|
|
|
@staticmethod
|
|
def _rot13bytes(b: Buffer, count: int) -> None:
|
|
b = memoryview(b)
|
|
for i in range(count):
|
|
ch = b[i]
|
|
if ch >= ord("A") and ch <= ord("Z"):
|
|
ch += ord("a") - ord("A")
|
|
elif ch >= ord("a") and ch <= ord("z"):
|
|
ch += ord("A") - ord("a")
|
|
b[i] = ch
|
|
|
|
def transform_to(self, output: IO[bytes]) -> IO[bytes]:
|
|
class Writer(io.RawIOBase):
|
|
def __init__(self, output: IO[bytes]) -> None:
|
|
self.output = output
|
|
|
|
def writeable(self) -> bool:
|
|
return True
|
|
|
|
def write(self, b: Buffer) -> Optional[int]:
|
|
# Don't mutate the input
|
|
chunk = bytearray(b)
|
|
Rot13Example._rot13bytes(chunk, len(chunk))
|
|
return self.output.write(chunk)
|
|
|
|
def flush(self) -> None:
|
|
self.output.flush()
|
|
|
|
return cast(IO[bytes], Writer(output))
|
|
|
|
def transform_from(self, input: IO[bytes]) -> IO[bytes]:
|
|
class Reader(io.RawIOBase):
|
|
def __init__(self, input: IO[bytes]) -> None:
|
|
self.input = input
|
|
|
|
def readable(self) -> bool:
|
|
return True
|
|
|
|
def readinto(self, b: Buffer) -> Optional[int]:
|
|
if hasattr(self.input, "readinto"):
|
|
count = self.input.readinto(b)
|
|
else:
|
|
# It's possible self.input is an IO[bytes] with no readinto method.
|
|
# In that case, we emulate with a read and copy. In practice,
|
|
# all of the current concrete extensions have readinto.
|
|
# 0 as a flags value is janky, but the flag values aren't available
|
|
# in python until 3.12.
|
|
view = b.__buffer__(0)
|
|
r = self.input.read(len(view))
|
|
if r is None:
|
|
count = None
|
|
else:
|
|
count = len(r)
|
|
view[:count] = r
|
|
if count == 0 or count is None:
|
|
return count
|
|
|
|
Rot13Example._rot13bytes(b, count)
|
|
return count
|
|
|
|
def seekable(self) -> bool:
|
|
return self.input.seekable()
|
|
|
|
def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
|
|
return self.input.seek(offset, whence)
|
|
|
|
def tell(self) -> int:
|
|
return self.input.tell()
|
|
|
|
return cast(IO[bytes], Reader(input))
|
|
|
|
|
|
def get_test_extension_registry() -> ExtensionRegistry:
    """
    Create a fresh ExtensionRegistry, register Rot13Example with it, and
    return the registry.
    """
    test_registry = ExtensionRegistry()
    test_registry.register(Rot13Example)
    return test_registry
|
|
|
|
|
|
def with_temp_dir(
    func: Optional[Callable] = None,
) -> Optional[Callable]:
    """
    Wrapper to initialize temp directory for distributed checkpoint.

    The decorated method runs with ``self.temp_dir`` set to a temporary
    directory path.  When a process group is initialized, rank 0 creates the
    directory and its path is broadcast to every other rank; otherwise the
    directory is created locally.  The path is removed again (errors ignored)
    once the method returns or raises.  The wrapped method's return value is
    discarded.
    """
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> None:
        if dist.is_initialized():
            # Only create temp_dir when rank is 0
            if dist.get_rank() == 0:
                temp_dir = tempfile.mkdtemp()
                print(f"Using temp directory: {temp_dir}")
            else:
                temp_dir = ""
            object_list = [temp_dir]

            # Broadcast temp_dir to all the other ranks
            os.sync()
            dist.broadcast_object_list(object_list)
            self.temp_dir = object_list[0]
            os.sync()
        else:
            temp_dir = tempfile.mkdtemp()
            print(f"No process group initialized, using temp directory: {temp_dir}")
            self.temp_dir = temp_dir

        try:
            func(self, *args, **kwargs)
        finally:
            # Every caller removes the path, whether or not a process group is
            # initialized and regardless of rank.
            # NOTE(review): the previous rank-0 check here had two identical
            # branches; confirm whether non-zero ranks are meant to skip this.
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper
|