Use typing.IO[bytes] instead of io.BytesIO in annotations (#144994)

Fixes #144976

Using appoach ① `IO[bytes]`, but could also try with a protocol.

## Notes:

- moved `torch.serialization.FILE_LIKE` to `torch.types.FileLike`
- Use `FileLike` annotation where it makes sense
- made sure those functions also support `os.PathLike`
- Replaced `isinstance(x, io.BytesIO)` with `isinstance(x, (io.IOBase, IO))` where appropriate.
- Replaced `BinaryIO` with `IO[bytes]` (the two ABCs are almost identical, the only difference is that `BinaryIO` allows `bytearray` input to `write`, whereas `IO[bytes]` only `bytes`)
- needed to make `torch.serialization._opener` generic to avoid LSP violations.
- skipped `torch/onnx/verification` for now (functions use `BytesIO.getvalue` which is not part of the `IO[bytes]` ABC, but it kind of seems that this is redundant, as e.g. `onnx.load` supports `str | PathLike[str] | IO[bytes]` directly...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144994
Approved by: https://github.com/ezyang, https://github.com/Skylion007
This commit is contained in:
Randolf Scholz
2025-01-27 18:08:05 +00:00
committed by PyTorch MergeBot
parent abf28982a8
commit 835e770bad
14 changed files with 120 additions and 87 deletions

View File

@ -1028,6 +1028,7 @@
"torch.types": [
"Any",
"Device",
"FileLike",
"List",
"Number",
"Sequence",

View File

@ -8,11 +8,11 @@ from pathlib import Path
from typing import (
Any,
AnyStr,
BinaryIO,
Callable,
ContextManager,
Dict,
Generic,
IO,
Iterable,
Iterator,
List,
@ -478,20 +478,20 @@ def _load_for_lite_interpreter(
map_location: Optional[DeviceLikeType],
): ...
def _load_for_lite_interpreter_from_buffer(
buffer: BinaryIO,
buffer: IO[bytes],
map_location: Optional[DeviceLikeType],
): ...
def _export_operator_list(module: LiteScriptModule): ...
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: IO[bytes]) -> _int: ...
def _backport_for_mobile(
filename_input: Union[str, Path],
filename_output: Union[str, Path],
to_version: _int,
) -> None: ...
def _backport_for_mobile_from_buffer(
buffer: BinaryIO,
buffer: IO[bytes],
filename_output: Union[str, Path],
to_version: _int,
) -> None: ...
@ -500,13 +500,13 @@ def _backport_for_mobile_to_buffer(
to_version: _int,
) -> bytes: ...
def _backport_for_mobile_from_buffer_to_buffer(
buffer: BinaryIO,
buffer: IO[bytes],
to_version: _int,
) -> bytes: ...
def _get_model_ops_and_info(filename: Union[str, Path]): ...
def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
def _get_model_ops_and_info_from_buffer(buffer: IO[bytes]): ...
def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...
def _get_mobile_model_contained_types_from_buffer(buffer: IO[bytes]): ...
def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
def _set_graph_executor_optimize(optimize: _bool): ...
@ -730,7 +730,7 @@ def import_ir_module(
) -> ScriptModule: ...
def import_ir_module_from_buffer(
cu: CompilationUnit,
buffer: BinaryIO,
buffer: IO[bytes],
map_location: Optional[DeviceLikeType],
extra_files: Dict[str, Any],
) -> ScriptModule: ...
@ -1465,7 +1465,7 @@ class PyTorchFileReader:
@overload
def __init__(self, name: str) -> None: ...
@overload
def __init__(self, buffer: BinaryIO) -> None: ...
def __init__(self, buffer: IO[bytes]) -> None: ...
def get_record(self, name: str) -> bytes: ...
def serialization_id(self) -> str: ...
@ -1473,7 +1473,7 @@ class PyTorchFileWriter:
@overload
def __init__(self, name: str, compute_crc32 = True) -> None: ...
@overload
def __init__(self, buffer: BinaryIO, compute_crc32 = True) -> None: ...
def __init__(self, buffer: IO[bytes], compute_crc32 = True) -> None: ...
def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
def write_end_of_file(self) -> None: ...
def set_min_version(self, version: _int) -> None: ...

View File

@ -1,7 +1,7 @@
import dataclasses
from dataclasses import field
from types import CellType, CodeType, ModuleType
from typing import Any, BinaryIO, IO
from typing import Any, IO
from typing_extensions import Self
from torch.utils._import_utils import import_dill
@ -40,7 +40,7 @@ class ExecutionRecord:
dill.dump(self, f)
@classmethod
def load(cls, f: BinaryIO) -> Self:
def load(cls, f: IO[bytes]) -> Self:
assert dill is not None, "replay_record requires `pip install dill`"
return dill.load(f)

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import io
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Dict, IO, List, Optional, Tuple, TYPE_CHECKING, Union
import torch._inductor.config
import torch.fx
@ -13,7 +13,7 @@ import torch.fx
if TYPE_CHECKING:
from torch._inductor.utils import InputType
from torch.export import ExportedProgram
from torch.types import FileLike
__all__ = [
"compile",
@ -53,7 +53,7 @@ def aoti_compile_and_package(
_deprecated_unused_args=None,
_deprecated_unused_kwargs=None,
*,
package_path: Optional[Union[str, io.BytesIO]] = None,
package_path: Optional[FileLike] = None,
inductor_configs: Optional[dict[str, Any]] = None,
) -> str:
"""
@ -105,8 +105,15 @@ def aoti_compile_and_package(
assert (
package_path is None
or isinstance(package_path, io.BytesIO)
or (isinstance(package_path, str) and package_path.endswith(".pt2"))
or (
isinstance(package_path, (io.IOBase, IO))
and package_path.writable()
and package_path.seekable()
)
or (
isinstance(package_path, (str, os.PathLike))
and os.fspath(package_path).endswith(".pt2")
)
), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
inductor_configs = inductor_configs or {}
@ -207,7 +214,7 @@ def _aoti_compile_and_package_inner(
return package_path
def aoti_load_package(path: Union[str, io.BytesIO]) -> Any: # type: ignore[type-arg]
def aoti_load_package(path: FileLike) -> Any: # type: ignore[type-arg]
"""
Loads the model from the PT2 package.

View File

@ -13,7 +13,7 @@ import typing
from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from enum import Enum
from typing import Any, BinaryIO, Callable, Optional, TypeVar
from typing import Any, Callable, IO, Optional, TypeVar
from typing_extensions import Never, ParamSpec
# _thread_safe_fork is needed because the subprocesses in the pool can read
@ -43,7 +43,7 @@ def _unpack_msg(data: bytes) -> tuple[int, int]:
msg_bytes = len(_pack_msg(0, 0))
def _send_msg(write_pipe: BinaryIO, job_id: int, job_data: bytes = b"") -> None:
def _send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None:
length = len(job_data)
write_pipe.write(_pack_msg(job_id, length))
if length > 0:
@ -51,7 +51,7 @@ def _send_msg(write_pipe: BinaryIO, job_id: int, job_data: bytes = b"") -> None:
write_pipe.flush()
def _recv_msg(read_pipe: BinaryIO) -> tuple[int, bytes]:
def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]:
job_id, length = _unpack_msg(read_pipe.read(msg_bytes))
data = read_pipe.read(length) if length > 0 else b""
return job_id, data
@ -255,8 +255,8 @@ class SubprocMain:
pickler: SubprocPickler,
kind: SubprocKind,
nprocs: int,
read_pipe: BinaryIO,
write_pipe: BinaryIO,
read_pipe: IO[bytes],
write_pipe: IO[bytes],
) -> None:
self.pickler = pickler
self.kind = kind

View File

@ -3,7 +3,6 @@ import contextlib
import copy
import dataclasses
import functools
import io
import itertools
import json
import logging
@ -25,6 +24,7 @@ from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.types import FileLike
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map
@ -728,7 +728,7 @@ def aot_inductor_minifier_wrapper(
exported_program: torch.export.ExportedProgram,
*,
inductor_configs: dict[str, Any],
package_path: Optional[Union[str, io.BytesIO]] = None,
package_path: Optional[FileLike] = None,
) -> str:
from torch._dynamo.debug_utils import AccuracyError
from torch._dynamo.repro.aoti import dump_to_minify

View File

@ -7,7 +7,7 @@ import subprocess
import tempfile
import zipfile
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, IO, Optional, Union
import torch
import torch._inductor
@ -15,6 +15,7 @@ import torch.utils._pytree as pytree
from torch._inductor import exc
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
from torch.export._tree_utils import reorder_kwargs
from torch.types import FileLike
from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION
@ -23,8 +24,8 @@ log = logging.getLogger(__name__)
class PT2ArchiveWriter:
def __init__(self, archive_path: Union[str, io.BytesIO]) -> None:
self.archive_path: Union[str, io.BytesIO] = archive_path
def __init__(self, archive_path: FileLike) -> None:
self.archive_path: FileLike = archive_path
self.archive_file: Optional[zipfile.ZipFile] = None
def __enter__(self) -> "PT2ArchiveWriter":
@ -158,9 +159,9 @@ def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
def package_aoti(
archive_file: Union[str, io.BytesIO],
archive_file: FileLike,
aoti_files: Union[list[str], dict[str, list[str]]],
) -> Union[str, io.BytesIO]:
) -> FileLike:
"""
Saves the AOTInductor generated files to the PT2Archive format.
@ -179,8 +180,13 @@ def package_aoti(
"files. You can get this list of files through calling "
"`torch._inductor.aot_compile(..., options={aot_inductor.package=True})`"
)
assert isinstance(archive_file, io.BytesIO) or (
isinstance(archive_file, str) and archive_file.endswith(".pt2")
assert (
isinstance(archive_file, (io.IOBase, IO))
and archive_file.writable()
and archive_file.seekable()
) or (
isinstance(archive_file, (str, os.PathLike))
and os.fspath(archive_file).endswith(".pt2")
), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
# Save using the PT2 packaging format
@ -222,7 +228,7 @@ def package_aoti(
file,
)
if isinstance(archive_file, io.BytesIO):
if isinstance(archive_file, (io.IOBase, IO)):
archive_file.seek(0)
return archive_file
@ -268,19 +274,21 @@ class AOTICompiledModel:
def get_constant_fqns(self) -> list[str]:
return self.loader.get_constant_fqns() # type: ignore[attr-defined]
def __deepcopy__(self, memo: Optional[Dict[Any, Any]]) -> "AOTICompiledModel":
def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel":
log.warning(
"AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied."
)
return AOTICompiledModel(self.loader) # type: ignore[attr-defined]
def load_package(path: Union[str, io.BytesIO], model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
assert isinstance(path, io.BytesIO) or (
isinstance(path, str) and path.endswith(".pt2")
def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
assert (
isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
) or (
isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")
), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
if isinstance(path, io.BytesIO):
if isinstance(path, (io.IOBase, IO)):
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
# TODO(angelayi): We shouldn't need to do this -- miniz should
# handle reading the buffer. This is just a temporary workaround
@ -290,5 +298,6 @@ def load_package(path: Union[str, io.BytesIO], model_name: str = "model") -> AOT
loader = torch._C._aoti.AOTIModelPackageLoader(f.name, model_name) # type: ignore[call-arg]
return AOTICompiledModel(loader)
path = os.fspath(path) # AOTIModelPackageLoader expects (str, str)
loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name) # type: ignore[call-arg]
return AOTICompiledModel(loader)

View File

@ -2,7 +2,6 @@ import builtins
import copy
import dataclasses
import inspect
import io
import os
import sys
import typing
@ -27,6 +26,7 @@ import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.types import FileLike
from torch.utils._pytree import (
FlattenFunc,
FromDumpableContextFn,
@ -381,7 +381,7 @@ DEFAULT_PICKLE_PROTOCOL = 2
def save(
ep: ExportedProgram,
f: Union[str, os.PathLike, io.BytesIO],
f: FileLike,
*,
extra_files: Optional[dict[str, Any]] = None,
opset_version: Optional[dict[str, int]] = None,
@ -399,7 +399,7 @@ def save(
Args:
ep (ExportedProgram): The exported program to save.
f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to
implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): Map from filename to contents
@ -464,7 +464,7 @@ def save(
def load(
f: Union[str, os.PathLike, io.BytesIO],
f: FileLike,
*,
extra_files: Optional[dict[str, Any]] = None,
expected_opset_version: Optional[dict[str, int]] = None,
@ -479,7 +479,7 @@ def load(
:func:`torch.export.save <torch.export.save>`.
Args:
f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to
implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): The extra filenames given in

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import io
import logging
import os
from typing import TYPE_CHECKING
from typing import IO, TYPE_CHECKING
import torch
from torch.onnx import _type_utils as jit_type_utils
@ -13,6 +13,8 @@ from torch.onnx import _type_utils as jit_type_utils
if TYPE_CHECKING:
import onnx
from torch.types import FileLike
log = logging.getLogger(__name__)
@ -117,7 +119,7 @@ def save_model_with_external_data(
basepath: str,
model_location: str,
initializer_location: str,
torch_state_dicts: tuple[dict | str | io.BytesIO, ...],
torch_state_dicts: tuple[dict | FileLike, ...],
onnx_model: onnx.ModelProto, # type: ignore[name-defined]
rename_initializer: bool = False,
) -> None:
@ -165,7 +167,9 @@ def save_model_with_external_data(
# Using torch.save wouldn't leverage mmap, leading to higher memory usage
state_dict = el
else:
if isinstance(el, str) and el.endswith(".safetensors"):
if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith(
".safetensors"
):
state_dict = _convert_safetensors_to_torch_format(el)
else:
try:
@ -173,14 +177,16 @@ def save_model_with_external_data(
# The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
state_dict = torch.load(el, map_location="cpu", mmap=True)
except (RuntimeError, ValueError) as e:
if "mmap can only be used with files saved with" in str(
e
) or isinstance(el, io.BytesIO):
if "mmap can only be used with files saved with" in str(e) or (
isinstance(el, (io.IOBase, IO))
and el.readable()
and el.seekable()
):
log.warning(
"Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
"Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
)
if isinstance(el, io.BytesIO):
if isinstance(el, (io.IOBase, IO)):
el.seek(0) # torch.load from `try:` has read the file.
state_dict = torch.load(el, map_location="cpu")
else:

View File

@ -3,6 +3,7 @@ import collections
import importlib.machinery
import io
import linecache
import os
import pickletools
import platform
import types
@ -12,11 +13,11 @@ from dataclasses import dataclass
from enum import Enum
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import Any, BinaryIO, Callable, cast, Optional, Union
from typing import Any, Callable, cast, IO, Optional, Union
import torch
from torch.serialization import location_tag, normalize_storage_type
from torch.types import Storage
from torch.types import FileLike, Storage
from torch.utils.hooks import RemovableHandle
from ._digraph import DiGraph
@ -201,10 +202,10 @@ class PackageExporter:
def __init__(
self,
f: Union[str, Path, BinaryIO],
f: FileLike,
importer: Union[Importer, Sequence[Importer]] = sys_importer,
debug: bool = False,
):
) -> None:
"""
Create an exporter.
@ -217,9 +218,9 @@ class PackageExporter:
"""
torch._C._log_api_usage_once("torch.package.PackageExporter")
self.debug = debug
if isinstance(f, (Path, str)):
f = str(f)
self.buffer: Optional[BinaryIO] = None
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
self.buffer: Optional[IO[bytes]] = None
else: # is a byte buffer
self.buffer = f

View File

@ -10,11 +10,12 @@ import sys
import types
from collections.abc import Iterable
from contextlib import contextmanager
from typing import Any, BinaryIO, Callable, cast, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
from weakref import WeakValueDictionary
import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii
from torch.types import FileLike
from ._directory_reader import DirectoryReader
from ._importlib import (
@ -84,7 +85,7 @@ class PackageImporter(Importer):
def __init__(
self,
file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
file_or_buffer: Union[FileLike, torch._C.PyTorchFileReader],
module_allowed: Callable[[str], bool] = lambda module_name: True,
):
"""Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules

View File

@ -15,7 +15,7 @@ import threading
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from typing import Any, BinaryIO, Callable, cast, IO, Optional, Union
from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
from typing_extensions import TypeAlias, TypeIs
import torch
@ -23,7 +23,7 @@ import torch._weights_only_unpickler as _weights_only_unpickler
from torch._sources import get_source_lines_and_file
from torch._utils import _import_dotted_name
from torch.storage import _get_dtype_from_pickle_storage_type
from torch.types import Storage
from torch.types import FileLike, Storage
__all__ = [
@ -65,7 +65,6 @@ MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ","
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION: TypeAlias = Optional[
Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]]
]
@ -326,7 +325,7 @@ class safe_globals(_weights_only_unpickler._safe_globals):
"""
def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> list[str]:
def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
"""Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``.
For a given function or class ``f``, the corresponding string will be of the form
@ -702,13 +701,16 @@ def storage_to_tensor_type(storage):
return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]:
def _is_path(name_or_buffer: object) -> TypeIs[Union[str, os.PathLike]]:
return isinstance(name_or_buffer, (str, os.PathLike))
class _opener:
def __init__(self, file_like):
self.file_like = file_like
T = TypeVar("T")
class _opener(Generic[T]):
def __init__(self, file_like: T) -> None:
self.file_like: T = file_like
def __enter__(self):
return self.file_like
@ -717,26 +719,26 @@ class _opener:
pass
class _open_file(_opener):
def __init__(self, name, mode):
class _open_file(_opener[IO[bytes]]):
def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None:
super().__init__(open(name, mode))
def __exit__(self, *args):
self.file_like.close()
class _open_buffer_reader(_opener):
def __init__(self, buffer):
class _open_buffer_reader(_opener[IO[bytes]]):
def __init__(self, buffer: IO[bytes]) -> None:
super().__init__(buffer)
_check_seekable(buffer)
class _open_buffer_writer(_opener):
class _open_buffer_writer(_opener[IO[bytes]]):
def __exit__(self, *args):
self.file_like.flush()
def _open_file_like(name_or_buffer, mode):
def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
if _is_path(name_or_buffer):
return _open_file(name_or_buffer, mode)
else:
@ -748,15 +750,15 @@ def _open_file_like(name_or_buffer, mode):
raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
class _open_zipfile_reader(_opener):
def __init__(self, name_or_buffer) -> None:
class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
class _open_zipfile_writer_file(_opener):
def __init__(self, name) -> None:
class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
def __init__(self, name: str) -> None:
self.file_stream = None
self.name = str(name)
self.name = name
try:
self.name.encode("ascii")
except UnicodeEncodeError:
@ -776,8 +778,8 @@ class _open_zipfile_writer_file(_opener):
self.file_stream.close()
class _open_zipfile_writer_buffer(_opener):
def __init__(self, buffer) -> None:
class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
def __init__(self, buffer: IO[bytes]) -> None:
if not callable(getattr(buffer, "write", None)):
msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
if not hasattr(buffer, "write"):
@ -791,7 +793,7 @@ class _open_zipfile_writer_buffer(_opener):
self.buffer.flush()
def _open_zipfile_writer(name_or_buffer):
def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener:
container: type[_opener]
if _is_path(name_or_buffer):
container = _open_zipfile_writer_file
@ -879,7 +881,7 @@ def _check_save_filelike(f):
def save(
obj: object,
f: FILE_LIKE,
f: FileLike,
pickle_module: Any = pickle,
pickle_protocol: int = DEFAULT_PROTOCOL,
_use_new_zipfile_serialization: bool = True,
@ -929,6 +931,9 @@ def save(
_check_dill_version(pickle_module)
_check_save_filelike(f)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
if _use_new_zipfile_serialization:
with _open_zipfile_writer(f) as opened_zipfile:
_save(
@ -1222,7 +1227,7 @@ def _save(
def load(
f: FILE_LIKE,
f: FileLike,
map_location: MAP_LOCATION = None,
pickle_module: Any = None,
*,

View File

@ -4,6 +4,7 @@
# top-level values. The underscore variants let us refer to these
# types. See https://github.com/python/mypy/issues/4146 for why these
# workarounds is necessary
import os
from builtins import ( # noqa: F401
bool as _bool,
bytes as _bytes,
@ -13,7 +14,7 @@ from builtins import ( # noqa: F401
str as _str,
)
from collections.abc import Sequence
from typing import Any, TYPE_CHECKING, Union
from typing import Any, IO, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
# `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType`
@ -35,7 +36,7 @@ if TYPE_CHECKING:
from torch.autograd.graph import GradientEdge
__all__ = ["Number", "Device", "Storage"]
__all__ = ["Number", "Device", "FileLike", "Storage"]
# Convenience aliases for common composite types that we need
# to talk about in PyTorch
@ -64,6 +65,8 @@ PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool]
# Meta-type for "numeric" things; matches our docs
Number: TypeAlias = Union[int, float, bool]
FileLike: TypeAlias = Union[str, os.PathLike[str], IO[bytes]]
# Meta-type for "device-like" things. Not to be confused with 'device' (a
# literal device object). This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU)

View File

@ -6,7 +6,7 @@ import struct
import pprint
import zipfile
import fnmatch
from typing import Any, IO, BinaryIO, Union
from typing import Any, IO
__all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"]
@ -119,7 +119,7 @@ def main(argv, output_stream=None):
return 2
fname = argv[1]
handle: Union[IO[bytes], BinaryIO]
handle: IO[bytes]
if "@" not in fname:
with open(fname, "rb") as handle:
DumpUnpickler.dump(handle, output_stream)