From 835e770bad6a2787b6d107808eed3a513de1d8b1 Mon Sep 17 00:00:00 2001
From: Randolf Scholz
Date: Mon, 27 Jan 2025 18:08:05 +0000
Subject: [PATCH] Use `typing.IO[bytes]` instead of `io.BytesIO` in
 annotations (#144994)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #144976

Using approach ① `IO[bytes]`; a protocol-based annotation could also be tried.

## Notes:

- moved `torch.serialization.FILE_LIKE` to `torch.types.FileLike`
- used the `FileLike` annotation where it makes sense
- made sure those functions also support `os.PathLike`
- replaced `isinstance(x, io.BytesIO)` with `isinstance(x, (io.IOBase, IO))` where appropriate
- replaced `BinaryIO` with `IO[bytes]` (the two ABCs are almost identical; the only difference is that `BinaryIO` allows `bytearray` input to `write`, whereas `IO[bytes]` accepts only `bytes`)
- needed to make `torch.serialization._opener` generic to avoid LSP violations
- skipped `torch/onnx/verification` for now (its functions use `BytesIO.getvalue`, which is not part of the `IO[bytes]` ABC; this appears redundant anyway, since e.g. `onnx.load` supports `str | PathLike[str] | IO[bytes]` directly)

Two illustrative sketches of the `BinaryIO`/`IO[bytes]` difference and the `FileLike` handling follow after the diff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144994
Approved by: https://github.com/ezyang, https://github.com/Skylion007
---
 test/allowlist_for_publicAPI.json             |  1 +
 torch/_C/__init__.pyi.in                      | 20 +++----
 torch/_dynamo/replay_record.py                |  4 +-
 torch/_inductor/__init__.py                   | 19 ++++---
 .../_inductor/compile_worker/subproc_pool.py  | 10 ++--
 torch/_inductor/debug.py                      |  4 +-
 torch/_inductor/package/package.py            | 35 +++++++-----
 torch/export/__init__.py                      | 10 ++--
 torch/onnx/_internal/fx/serialization.py      | 20 ++++---
 torch/package/package_exporter.py             | 15 +++---
 torch/package/package_importer.py             |  5 +-
 torch/serialization.py                        | 53 ++++++++++---------
 torch/types.py                                |  7 ++-
 torch/utils/show_pickle.py                    |  4 +-
 14 files changed, 120 insertions(+), 87 deletions(-)

diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index b81fe3929eb9..5e9faf9fe9ed 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -1028,6 +1028,7 @@
   "torch.types": [
     "Any",
     "Device",
+    "FileLike",
     "List",
     "Number",
     "Sequence",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 9bafcbe3c9c6..1f0b0ac33fb9 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -8,11 +8,11 @@ from pathlib import Path
 from typing import (
     Any,
     AnyStr,
-    BinaryIO,
     Callable,
     ContextManager,
     Dict,
     Generic,
+    IO,
     Iterable,
     Iterator,
     List,
@@ -478,20 +478,20 @@ def _load_for_lite_interpreter(
     map_location: Optional[DeviceLikeType],
 ): ...
 def _load_for_lite_interpreter_from_buffer(
-    buffer: BinaryIO,
+    buffer: IO[bytes],
     map_location: Optional[DeviceLikeType],
 ): ...
 def _export_operator_list(module: LiteScriptModule): ...
 def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
 def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
-def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
+def _get_model_bytecode_version_from_buffer(buffer: IO[bytes]) -> _int: ...
 def _backport_for_mobile(
     filename_input: Union[str, Path],
     filename_output: Union[str, Path],
     to_version: _int,
 ) -> None: ...
 def _backport_for_mobile_from_buffer(
-    buffer: BinaryIO,
+    buffer: IO[bytes],
     filename_output: Union[str, Path],
     to_version: _int,
 ) -> None: ...
@@ -500,13 +500,13 @@ def _backport_for_mobile_to_buffer(
     to_version: _int,
 ) -> bytes: ...
 def _backport_for_mobile_from_buffer_to_buffer(
-    buffer: BinaryIO,
+    buffer: IO[bytes],
     to_version: _int,
 ) -> bytes: ...
 def _get_model_ops_and_info(filename: Union[str, Path]): ...
-def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
+def _get_model_ops_and_info_from_buffer(buffer: IO[bytes]): ...
 def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
-def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...
+def _get_mobile_model_contained_types_from_buffer(buffer: IO[bytes]): ...
 def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
 def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
 def _set_graph_executor_optimize(optimize: _bool): ...
@@ -730,7 +730,7 @@ def import_ir_module(
 ) -> ScriptModule: ...
 def import_ir_module_from_buffer(
     cu: CompilationUnit,
-    buffer: BinaryIO,
+    buffer: IO[bytes],
     map_location: Optional[DeviceLikeType],
     extra_files: Dict[str, Any],
 ) -> ScriptModule: ...
@@ -1465,7 +1465,7 @@ class PyTorchFileReader:
     @overload
     def __init__(self, name: str) -> None: ...
     @overload
-    def __init__(self, buffer: BinaryIO) -> None: ...
+    def __init__(self, buffer: IO[bytes]) -> None: ...
     def get_record(self, name: str) -> bytes: ...
     def serialization_id(self) -> str: ...

@@ -1473,7 +1473,7 @@ class PyTorchFileWriter:
     @overload
     def __init__(self, name: str, compute_crc32 = True) -> None: ...
     @overload
-    def __init__(self, buffer: BinaryIO, compute_crc32 = True) -> None: ...
+    def __init__(self, buffer: IO[bytes], compute_crc32 = True) -> None: ...
     def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
     def write_end_of_file(self) -> None: ...
     def set_min_version(self, version: _int) -> None: ...
diff --git a/torch/_dynamo/replay_record.py b/torch/_dynamo/replay_record.py
index e9fb56f64046..e1087957b4b6 100644
--- a/torch/_dynamo/replay_record.py
+++ b/torch/_dynamo/replay_record.py
@@ -1,7 +1,7 @@
 import dataclasses
 from dataclasses import field
 from types import CellType, CodeType, ModuleType
-from typing import Any, BinaryIO, IO
+from typing import Any, IO
 from typing_extensions import Self

 from torch.utils._import_utils import import_dill
@@ -40,7 +40,7 @@ class ExecutionRecord:
         dill.dump(self, f)

     @classmethod
-    def load(cls, f: BinaryIO) -> Self:
+    def load(cls, f: IO[bytes]) -> Self:
         assert dill is not None, "replay_record requires `pip install dill`"
         return dill.load(f)
diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py
index 9a966f4f5db0..7fef0577366c 100644
--- a/torch/_inductor/__init__.py
+++ b/torch/_inductor/__init__.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import logging
 import os
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Dict, IO, List, Optional, Tuple, TYPE_CHECKING, Union

 import torch._inductor.config
 import torch.fx
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
     from torch._inductor.utils import InputType
     from torch.export import ExportedProgram
-
+    from torch.types import FileLike

 __all__ = [
     "compile",
@@ -53,7 +53,7 @@ def aoti_compile_and_package(
     _deprecated_unused_args=None,
     _deprecated_unused_kwargs=None,
     *,
-    package_path: Optional[Union[str, io.BytesIO]] = None,
+    package_path: Optional[FileLike] = None,
     inductor_configs: Optional[dict[str, Any]] = None,
 ) -> str:
     """
@@ -105,8 +105,15 @@ def aoti_compile_and_package(

     assert (
         package_path is None
-        or isinstance(package_path, io.BytesIO)
-        or (isinstance(package_path, str) and package_path.endswith(".pt2"))
+        or (
+            isinstance(package_path, (io.IOBase, IO))
+            and package_path.writable()
+            and package_path.seekable()
+        )
+        or (
+            isinstance(package_path, (str, os.PathLike))
+            and os.fspath(package_path).endswith(".pt2")
+        )
     ), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"

     inductor_configs = inductor_configs or {}
@@ -207,7 +214,7 @@ def _aoti_compile_and_package_inner(
     return package_path


-def aoti_load_package(path: Union[str, io.BytesIO]) -> Any:  # type: ignore[type-arg]
+def aoti_load_package(path: FileLike) -> Any:  # type: ignore[type-arg]
     """
     Loads the model from the PT2 package.
diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py
index ccbc269c2891..8f6761e3d197 100644
--- a/torch/_inductor/compile_worker/subproc_pool.py
+++ b/torch/_inductor/compile_worker/subproc_pool.py
@@ -13,7 +13,7 @@ import typing
 from concurrent.futures import Future, ProcessPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from enum import Enum
-from typing import Any, BinaryIO, Callable, Optional, TypeVar
+from typing import Any, Callable, IO, Optional, TypeVar
 from typing_extensions import Never, ParamSpec

 # _thread_safe_fork is needed because the subprocesses in the pool can read
@@ -43,7 +43,7 @@ def _unpack_msg(data: bytes) -> tuple[int, int]:
 msg_bytes = len(_pack_msg(0, 0))


-def _send_msg(write_pipe: BinaryIO, job_id: int, job_data: bytes = b"") -> None:
+def _send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None:
     length = len(job_data)
     write_pipe.write(_pack_msg(job_id, length))
     if length > 0:
@@ -51,7 +51,7 @@ def _send_msg(write_pipe: BinaryIO, job_id: int, job_data: bytes = b"") -> None:
     write_pipe.flush()


-def _recv_msg(read_pipe: BinaryIO) -> tuple[int, bytes]:
+def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]:
     job_id, length = _unpack_msg(read_pipe.read(msg_bytes))
     data = read_pipe.read(length) if length > 0 else b""
     return job_id, data
@@ -255,8 +255,8 @@ class SubprocMain:
         pickler: SubprocPickler,
         kind: SubprocKind,
         nprocs: int,
-        read_pipe: BinaryIO,
-        write_pipe: BinaryIO,
+        read_pipe: IO[bytes],
+        write_pipe: IO[bytes],
     ) -> None:
         self.pickler = pickler
         self.kind = kind
diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py
index 82633a17680b..a16595777f25 100644
--- a/torch/_inductor/debug.py
+++ b/torch/_inductor/debug.py
@@ -3,7 +3,6 @@ import contextlib
 import copy
 import dataclasses
 import functools
-import io
 import itertools
 import json
 import logging
@@ -25,6 +24,7 @@ from torch._dynamo.utils import get_debug_dir
 from torch.fx.graph_module import GraphModule
 from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch.fx.passes.tools_common import legalize_graph
+from torch.types import FileLike
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_map
@@ -728,7 +728,7 @@ def aot_inductor_minifier_wrapper(
     exported_program: torch.export.ExportedProgram,
     *,
     inductor_configs: dict[str, Any],
-    package_path: Optional[Union[str, io.BytesIO]] = None,
+    package_path: Optional[FileLike] = None,
 ) -> str:
     from torch._dynamo.debug_utils import AccuracyError
     from torch._dynamo.repro.aoti import dump_to_minify
diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py
index 9b9c09786e10..355db67a1280 100644
--- a/torch/_inductor/package/package.py
+++ b/torch/_inductor/package/package.py
@@ -7,7 +7,7 @@
 import subprocess
 import tempfile
 import zipfile
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, IO, Optional, Union

 import torch
 import torch._inductor
@@ -15,6 +15,7 @@ import torch.utils._pytree as pytree
 from torch._inductor import exc
 from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
 from torch.export._tree_utils import reorder_kwargs
+from torch.types import FileLike

 from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
@@ -23,8 +24,8 @@ log = logging.getLogger(__name__)


 class PT2ArchiveWriter:
-    def __init__(self, archive_path: Union[str, io.BytesIO]) -> None:
-        self.archive_path: Union[str, io.BytesIO] = archive_path
+    def __init__(self, archive_path: FileLike) -> None:
+        self.archive_path: FileLike = archive_path
         self.archive_file: Optional[zipfile.ZipFile] = None

     def __enter__(self) -> "PT2ArchiveWriter":
@@ -158,9 +159,9 @@ def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:


 def package_aoti(
-    archive_file: Union[str, io.BytesIO],
+    archive_file: FileLike,
     aoti_files: Union[list[str], dict[str, list[str]]],
-) -> Union[str, io.BytesIO]:
+) -> FileLike:
     """
     Saves the AOTInductor generated files to the PT2Archive format.
@@ -179,8 +180,13 @@ def package_aoti(
             "files. You can get this list of files through calling "
             "`torch._inductor.aot_compile(..., options={aot_inductor.package=True})`"
         )
-    assert isinstance(archive_file, io.BytesIO) or (
-        isinstance(archive_file, str) and archive_file.endswith(".pt2")
+    assert (
+        isinstance(archive_file, (io.IOBase, IO))
+        and archive_file.writable()
+        and archive_file.seekable()
+    ) or (
+        isinstance(archive_file, (str, os.PathLike))
+        and os.fspath(archive_file).endswith(".pt2")
     ), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"

     # Save using the PT2 packaging format
@@ -222,7 +228,7 @@ def package_aoti(
             file,
         )

-    if isinstance(archive_file, io.BytesIO):
+    if isinstance(archive_file, (io.IOBase, IO)):
         archive_file.seek(0)
     return archive_file
@@ -268,19 +274,21 @@ class AOTICompiledModel:
     def get_constant_fqns(self) -> list[str]:
         return self.loader.get_constant_fqns()  # type: ignore[attr-defined]

-    def __deepcopy__(self, memo: Optional[Dict[Any, Any]]) -> "AOTICompiledModel":
+    def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel":
         log.warning(
             "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied."
         )
         return AOTICompiledModel(self.loader)  # type: ignore[attr-defined]


-def load_package(path: Union[str, io.BytesIO], model_name: str = "model") -> AOTICompiledModel:  # type: ignore[type-arg]
-    assert isinstance(path, io.BytesIO) or (
-        isinstance(path, str) and path.endswith(".pt2")
+def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel:  # type: ignore[type-arg]
+    assert (
+        isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
+    ) or (
+        isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")
     ), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"

-    if isinstance(path, io.BytesIO):
+    if isinstance(path, (io.IOBase, IO)):
         with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
             # TODO(angelayi): We shouldn't need to do this -- miniz should
             # handle reading the buffer. This is just a temporary workaround
@@ -290,5 +298,6 @@ def load_package(path: Union[str, io.BytesIO], model_name: str = "model") -> AOT
             loader = torch._C._aoti.AOTIModelPackageLoader(f.name, model_name)  # type: ignore[call-arg]
             return AOTICompiledModel(loader)

+    path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
     loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name)  # type: ignore[call-arg]
     return AOTICompiledModel(loader)
diff --git a/torch/export/__init__.py b/torch/export/__init__.py
index 4bf904fa905f..9a682b104df4 100644
--- a/torch/export/__init__.py
+++ b/torch/export/__init__.py
@@ -2,7 +2,6 @@ import builtins
 import copy
 import dataclasses
 import inspect
-import io
 import os
 import sys
 import typing
@@ -27,6 +26,7 @@ import torch.utils._pytree as pytree
 from torch.fx._compatibility import compatibility
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.infra.pass_manager import PassManager
+from torch.types import FileLike
 from torch.utils._pytree import (
     FlattenFunc,
     FromDumpableContextFn,
@@ -381,7 +381,7 @@ DEFAULT_PICKLE_PROTOCOL = 2


 def save(
     ep: ExportedProgram,
-    f: Union[str, os.PathLike, io.BytesIO],
+    f: FileLike,
     *,
     extra_files: Optional[dict[str, Any]] = None,
     opset_version: Optional[dict[str, int]] = None,
@@ -399,7 +399,7 @@ def save(

     Args:
         ep (ExportedProgram): The exported program to save.
-        f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
+        f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to
             implement write and flush) or a string containing a file name.

         extra_files (Optional[Dict[str, Any]]): Map from filename to contents
@@ -464,7 +464,7 @@ def save(


 def load(
-    f: Union[str, os.PathLike, io.BytesIO],
+    f: FileLike,
     *,
     extra_files: Optional[dict[str, Any]] = None,
     expected_opset_version: Optional[dict[str, int]] = None,
@@ -479,7 +479,7 @@ def load(
     :func:`torch.export.save <torch.export.save>`.

     Args:
-        f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
+        f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to
             implement write and flush) or a string containing a file name.
         extra_files (Optional[Dict[str, Any]]): The extra filenames given in
diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py
index 8720ecf3460d..cda71e465758 100644
--- a/torch/onnx/_internal/fx/serialization.py
+++ b/torch/onnx/_internal/fx/serialization.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import logging
 import os
-from typing import TYPE_CHECKING
+from typing import IO, TYPE_CHECKING

 import torch
 from torch.onnx import _type_utils as jit_type_utils
@@ -13,6 +13,8 @@ from torch.onnx import _type_utils as jit_type_utils
 if TYPE_CHECKING:
     import onnx
+    from torch.types import FileLike
+

 log = logging.getLogger(__name__)
@@ -117,7 +119,7 @@ def save_model_with_external_data(
     basepath: str,
     model_location: str,
     initializer_location: str,
-    torch_state_dicts: tuple[dict | str | io.BytesIO, ...],
+    torch_state_dicts: tuple[dict | FileLike, ...],
     onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
     rename_initializer: bool = False,
 ) -> None:
@@ -165,7 +167,9 @@ def save_model_with_external_data(
             # Using torch.save wouldn't leverage mmap, leading to higher memory usage
             state_dict = el
         else:
-            if isinstance(el, str) and el.endswith(".safetensors"):
+            if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith(
+                ".safetensors"
+            ):
                 state_dict = _convert_safetensors_to_torch_format(el)
             else:
                 try:
@@ -173,14 +177,16 @@ def save_model_with_external_data(
                     # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
                     state_dict = torch.load(el, map_location="cpu", mmap=True)
                 except (RuntimeError, ValueError) as e:
-                    if "mmap can only be used with files saved with" in str(
-                        e
-                    ) or isinstance(el, io.BytesIO):
+                    if "mmap can only be used with files saved with" in str(e) or (
+                        isinstance(el, (io.IOBase, IO))
+                        and el.readable()
+                        and el.seekable()
+                    ):
                         log.warning(
                             "Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
                             "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
                         )
-                        if isinstance(el, io.BytesIO):
+                        if isinstance(el, (io.IOBase, IO)):
                             el.seek(0)  # torch.load from `try:` has read the file.
                         state_dict = torch.load(el, map_location="cpu")
                     else:
diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py
index 796936b1f3ed..42e346c626e3 100644
--- a/torch/package/package_exporter.py
+++ b/torch/package/package_exporter.py
@@ -3,6 +3,7 @@ import collections
 import importlib.machinery
 import io
 import linecache
+import os
 import pickletools
 import platform
 import types
@@ -12,11 +13,11 @@ from dataclasses import dataclass
 from enum import Enum
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
-from typing import Any, BinaryIO, Callable, cast, Optional, Union
+from typing import Any, Callable, cast, IO, Optional, Union

 import torch
 from torch.serialization import location_tag, normalize_storage_type
-from torch.types import Storage
+from torch.types import FileLike, Storage
 from torch.utils.hooks import RemovableHandle

 from ._digraph import DiGraph
@@ -201,10 +202,10 @@ class PackageExporter:

     def __init__(
         self,
-        f: Union[str, Path, BinaryIO],
+        f: FileLike,
         importer: Union[Importer, Sequence[Importer]] = sys_importer,
         debug: bool = False,
-    ):
+    ) -> None:
         """
         Create an exporter.
@@ -217,9 +218,9 @@ class PackageExporter:
         """
         torch._C._log_api_usage_once("torch.package.PackageExporter")
         self.debug = debug
-        if isinstance(f, (Path, str)):
-            f = str(f)
-            self.buffer: Optional[BinaryIO] = None
+        if isinstance(f, (str, os.PathLike)):
+            f = os.fspath(f)
+            self.buffer: Optional[IO[bytes]] = None
         else:  # is a byte buffer
             self.buffer = f
diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py
index 971b5398ec63..6510986e2455 100644
--- a/torch/package/package_importer.py
+++ b/torch/package/package_importer.py
@@ -10,11 +10,12 @@ import sys
 import types
 from collections.abc import Iterable
 from contextlib import contextmanager
-from typing import Any, BinaryIO, Callable, cast, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
 from weakref import WeakValueDictionary

 import torch
 from torch.serialization import _get_restore_location, _maybe_decode_ascii
+from torch.types import FileLike

 from ._directory_reader import DirectoryReader
 from ._importlib import (
@@ -84,7 +85,7 @@ class PackageImporter(Importer):

     def __init__(
         self,
-        file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
+        file_or_buffer: Union[FileLike, torch._C.PyTorchFileReader],
         module_allowed: Callable[[str], bool] = lambda module_name: True,
     ):
         """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
diff --git a/torch/serialization.py b/torch/serialization.py
index c7046049f6b2..9f70baad9b87 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -15,7 +15,7 @@ import threading
 import warnings
 from contextlib import closing, contextmanager
 from enum import Enum
-from typing import Any, BinaryIO, Callable, cast, IO, Optional, Union
+from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
 from typing_extensions import TypeAlias, TypeIs

 import torch
@@ -23,7 +23,7 @@ import torch._weights_only_unpickler as _weights_only_unpickler
 from torch._sources import get_source_lines_and_file
 from torch._utils import _import_dotted_name
 from torch.storage import _get_dtype_from_pickle_storage_type
-from torch.types import Storage
+from torch.types import FileLike, Storage


 __all__ = [
@@ -65,7 +65,6 @@ MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
 PROTOCOL_VERSION = 1001
 STORAGE_KEY_SEPARATOR = ","

-FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
 MAP_LOCATION: TypeAlias = Optional[
     Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]]
 ]
@@ -326,7 +325,7 @@ class safe_globals(_weights_only_unpickler._safe_globals):
     """


-def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> list[str]:
+def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
     """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``.

     For a given function or class ``f``, the corresponding string will be of the form
@@ -702,13 +701,16 @@ def storage_to_tensor_type(storage):
     return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))


-def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]:
+def _is_path(name_or_buffer: object) -> TypeIs[Union[str, os.PathLike]]:
     return isinstance(name_or_buffer, (str, os.PathLike))


-class _opener:
-    def __init__(self, file_like):
-        self.file_like = file_like
+T = TypeVar("T")
+
+
+class _opener(Generic[T]):
+    def __init__(self, file_like: T) -> None:
+        self.file_like: T = file_like

     def __enter__(self):
         return self.file_like
@@ -717,26 +719,26 @@ class _opener:
     def __exit__(self, *args):
         pass


-class _open_file(_opener):
-    def __init__(self, name, mode):
+class _open_file(_opener[IO[bytes]]):
+    def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None:
         super().__init__(open(name, mode))

     def __exit__(self, *args):
         self.file_like.close()


-class _open_buffer_reader(_opener):
-    def __init__(self, buffer):
+class _open_buffer_reader(_opener[IO[bytes]]):
+    def __init__(self, buffer: IO[bytes]) -> None:
         super().__init__(buffer)
         _check_seekable(buffer)


-class _open_buffer_writer(_opener):
+class _open_buffer_writer(_opener[IO[bytes]]):
     def __exit__(self, *args):
         self.file_like.flush()


-def _open_file_like(name_or_buffer, mode):
+def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
     if _is_path(name_or_buffer):
         return _open_file(name_or_buffer, mode)
     else:
@@ -748,15 +750,15 @@ def _open_file_like(name_or_buffer, mode):
     raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")


-class _open_zipfile_reader(_opener):
-    def __init__(self, name_or_buffer) -> None:
+class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
+    def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
         super().__init__(torch._C.PyTorchFileReader(name_or_buffer))


-class _open_zipfile_writer_file(_opener):
-    def __init__(self, name) -> None:
+class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
+    def __init__(self, name: str) -> None:
         self.file_stream = None
-        self.name = str(name)
+        self.name = name
         try:
             self.name.encode("ascii")
         except UnicodeEncodeError:
@@ -776,8 +778,8 @@ class _open_zipfile_writer_file(_opener):
         self.file_stream.close()


-class _open_zipfile_writer_buffer(_opener):
-    def __init__(self, buffer) -> None:
+class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
+    def __init__(self, buffer: IO[bytes]) -> None:
         if not callable(getattr(buffer, "write", None)):
             msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
             if not hasattr(buffer, "write"):
@@ -791,7 +793,7 @@ class _open_zipfile_writer_buffer(_opener):
         self.buffer.flush()


-def _open_zipfile_writer(name_or_buffer):
+def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener:
     container: type[_opener]
     if _is_path(name_or_buffer):
         container = _open_zipfile_writer_file
@@ -879,7 +881,7 @@ def _check_save_filelike(f):


 def save(
     obj: object,
-    f: FILE_LIKE,
+    f: FileLike,
     pickle_module: Any = pickle,
     pickle_protocol: int = DEFAULT_PROTOCOL,
     _use_new_zipfile_serialization: bool = True,
@@ -929,6 +931,9 @@ def save(
     _check_dill_version(pickle_module)
     _check_save_filelike(f)

+    if isinstance(f, (str, os.PathLike)):
+        f = os.fspath(f)
+
     if _use_new_zipfile_serialization:
         with _open_zipfile_writer(f) as opened_zipfile:
             _save(
@@ -1222,7 +1227,7 @@ def _save(


 def load(
-    f: FILE_LIKE,
+    f: FileLike,
     map_location: MAP_LOCATION = None,
     pickle_module: Any = None,
     *,
diff --git a/torch/types.py b/torch/types.py
index 87f7109310ee..91a411fa5e14 100644
--- a/torch/types.py
+++ b/torch/types.py
@@ -4,6 +4,7 @@
 # top-level values. The underscore variants let us refer to these
 # types. See https://github.com/python/mypy/issues/4146 for why these
 # workarounds is necessary
+import os
 from builtins import (  # noqa: F401
     bool as _bool,
     bytes as _bytes,
@@ -13,7 +14,7 @@ from builtins import (  # noqa: F401
     str as _str,
 )
 from collections.abc import Sequence
-from typing import Any, TYPE_CHECKING, Union
+from typing import Any, IO, TYPE_CHECKING, Union
 from typing_extensions import TypeAlias

 # `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType`
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
     from torch.autograd.graph import GradientEdge


-__all__ = ["Number", "Device", "Storage"]
+__all__ = ["Number", "Device", "FileLike", "Storage"]

 # Convenience aliases for common composite types that we need
 # to talk about in PyTorch
@@ -64,6 +65,8 @@ PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool]
 # Meta-type for "numeric" things; matches our docs
 Number: TypeAlias = Union[int, float, bool]

+FileLike: TypeAlias = Union[str, os.PathLike[str], IO[bytes]]
+
 # Meta-type for "device-like" things. Not to be confused with 'device' (a
 # literal device object). This nomenclature is consistent with PythonArgParser.
 # None means use the default device (typically CPU)
diff --git a/torch/utils/show_pickle.py b/torch/utils/show_pickle.py
index 66549fac2673..cd8b6c2b8ab9 100644
--- a/torch/utils/show_pickle.py
+++ b/torch/utils/show_pickle.py
@@ -6,7 +6,7 @@ import struct
 import pprint
 import zipfile
 import fnmatch
-from typing import Any, IO, BinaryIO, Union
+from typing import Any, IO

 __all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"]
@@ -119,7 +119,7 @@ def main(argv, output_stream=None):
         return 2

     fname = argv[1]
-    handle: Union[IO[bytes], BinaryIO]
+    handle: IO[bytes]
     if "@" not in fname:
         with open(fname, "rb") as handle:
             DumpUnpickler.dump(handle, output_stream)
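
## Appendix: illustrative sketches (not part of the patch)

The `BinaryIO` vs. `IO[bytes]` note above can be seen in a small standalone
sketch. Nothing below is from the patch itself; `demo_write` is a hypothetical
helper, and the comments describe typeshed's definitions of these classes:

```python
# Sketch: why `IO[bytes]` is the more permissive *parameter* annotation.
# `BinaryIO` subclasses `IO[bytes]` and widens `write` to accept
# `Union[bytes, bytearray]`, so every `BinaryIO` is an `IO[bytes]`,
# but not the other way around.
import io
from typing import IO, BinaryIO


def demo_write(stream: IO[bytes]) -> None:  # hypothetical helper
    stream.write(b"ok")  # fine: `IO[bytes].write` accepts `bytes`
    # stream.write(bytearray(b"x"))  # rejected by type checkers: only
    #                                # `BinaryIO.write` also takes `bytearray`


demo_write(io.BytesIO())  # accepted: BytesIO satisfies IO[bytes]

handle: BinaryIO = io.BytesIO()
demo_write(handle)  # accepted: BinaryIO is a subclass of IO[bytes]
```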
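The `FileLike` handling follows one pattern throughout the diff: normalize
`str`/`os.PathLike` inputs with `os.fspath` and treat everything else as a
binary stream. Below is a minimal sketch of that pattern, assuming a
hypothetical `open_file_like` helper; it is not the actual
`torch.serialization` code:

```python
import io
import os
from typing import IO, Union

from typing_extensions import TypeAlias

# Same shape as the alias the patch adds as `torch.types.FileLike`.
FileLike: TypeAlias = Union[str, "os.PathLike[str]", IO[bytes]]


def open_file_like(name_or_buffer: FileLike, mode: str) -> IO[bytes]:
    # Paths (str or os.PathLike, e.g. pathlib.Path) are normalized and opened.
    if isinstance(name_or_buffer, (str, os.PathLike)):
        return open(os.fspath(name_or_buffer), mode)
    # Runtime buffer check in the style of the patch: concrete binary streams
    # derive from io.IOBase, which `typing.IO` itself does not.
    assert isinstance(name_or_buffer, (io.IOBase, IO)), name_or_buffer
    return name_or_buffer


buf = io.BytesIO()
assert open_file_like(buf, "rb") is buf  # buffers pass straight through
```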