Mirror of https://github.com/pytorch/pytorch.git
Use typing.IO[bytes] instead of io.BytesIO in annotations (#144994)
Fixes #144976

Using approach ① `IO[bytes]`, but could also try with a protocol.

## Notes:

- Moved `torch.serialization.FILE_LIKE` to `torch.types.FileLike`.
- Used the `FileLike` annotation where it makes sense.
- Made sure those functions also support `os.PathLike`.
- Replaced `isinstance(x, io.BytesIO)` with `isinstance(x, (io.IOBase, IO))` where appropriate.
- Replaced `BinaryIO` with `IO[bytes]` (the two ABCs are almost identical; the only difference is that `BinaryIO` allows `bytearray` input to `write`, whereas `IO[bytes]` allows only `bytes`).
- Needed to make `torch.serialization._opener` generic to avoid LSP violations.
- Skipped `torch/onnx/verification` for now (its functions use `BytesIO.getvalue`, which is not part of the `IO[bytes]` ABC, but this seems redundant anyway, as e.g. `onnx.load` supports `str | PathLike[str] | IO[bytes]` directly).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144994
Approved by: https://github.com/ezyang, https://github.com/Skylion007
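To make the `BinaryIO` vs. `IO[bytes]` distinction from the notes concrete, here is a minimal sketch (function names are illustrative, not from this PR): in typeshed, `BinaryIO.write` accepts `bytes | bytearray`, while `IO[bytes].write` accepts only `bytes`.

```python
from typing import IO, BinaryIO


def write_io(sink: IO[bytes]) -> None:
    sink.write(b"payload")  # OK: IO[bytes].write takes bytes
    # sink.write(bytearray(b"payload"))  # rejected by type checkers: bytes expected


def write_binaryio(sink: BinaryIO) -> None:
    sink.write(bytearray(b"payload"))  # OK: BinaryIO.write also accepts bytearray
```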
Committed by: PyTorch MergeBot
Parent: abf28982a8
Commit: 835e770bad
@@ -1028,6 +1028,7 @@
 "torch.types": [
     "Any",
     "Device",
+    "FileLike",
     "List",
     "Number",
     "Sequence",
@@ -8,11 +8,11 @@ from pathlib import Path
 from typing import (
     Any,
     AnyStr,
-    BinaryIO,
     Callable,
     ContextManager,
     Dict,
     Generic,
+    IO,
     Iterable,
     Iterator,
     List,

@@ -478,20 +478,20 @@ def _load_for_lite_interpreter(
     map_location: Optional[DeviceLikeType],
 ): ...
 def _load_for_lite_interpreter_from_buffer(
-    buffer: BinaryIO,
+    buffer: IO[bytes],
     map_location: Optional[DeviceLikeType],
 ): ...
 def _export_operator_list(module: LiteScriptModule): ...
 def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
 def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
-def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
+def _get_model_bytecode_version_from_buffer(buffer: IO[bytes]) -> _int: ...
 def _backport_for_mobile(
     filename_input: Union[str, Path],
     filename_output: Union[str, Path],
     to_version: _int,
 ) -> None: ...
 def _backport_for_mobile_from_buffer(
-    buffer: BinaryIO,
+    buffer: IO[bytes],
     filename_output: Union[str, Path],
     to_version: _int,
 ) -> None: ...

@@ -500,13 +500,13 @@ def _backport_for_mobile_to_buffer(
     to_version: _int,
 ) -> bytes: ...
 def _backport_for_mobile_from_buffer_to_buffer(
-    buffer: BinaryIO,
+    buffer: IO[bytes],
     to_version: _int,
 ) -> bytes: ...
 def _get_model_ops_and_info(filename: Union[str, Path]): ...
-def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
+def _get_model_ops_and_info_from_buffer(buffer: IO[bytes]): ...
 def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
-def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...
+def _get_mobile_model_contained_types_from_buffer(buffer: IO[bytes]): ...
 def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
 def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
 def _set_graph_executor_optimize(optimize: _bool): ...

@@ -730,7 +730,7 @@ def import_ir_module(
 ) -> ScriptModule: ...
 def import_ir_module_from_buffer(
     cu: CompilationUnit,
-    buffer: BinaryIO,
+    buffer: IO[bytes],
     map_location: Optional[DeviceLikeType],
     extra_files: Dict[str, Any],
 ) -> ScriptModule: ...

@@ -1465,7 +1465,7 @@ class PyTorchFileReader:
     @overload
     def __init__(self, name: str) -> None: ...
     @overload
-    def __init__(self, buffer: BinaryIO) -> None: ...
+    def __init__(self, buffer: IO[bytes]) -> None: ...
     def get_record(self, name: str) -> bytes: ...
     def serialization_id(self) -> str: ...

@@ -1473,7 +1473,7 @@ class PyTorchFileWriter:
     @overload
     def __init__(self, name: str, compute_crc32 = True) -> None: ...
     @overload
-    def __init__(self, buffer: BinaryIO, compute_crc32 = True) -> None: ...
+    def __init__(self, buffer: IO[bytes], compute_crc32 = True) -> None: ...
     def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
     def write_end_of_file(self) -> None: ...
     def set_min_version(self, version: _int) -> None: ...
@@ -1,7 +1,7 @@
 import dataclasses
 from dataclasses import field
 from types import CellType, CodeType, ModuleType
-from typing import Any, BinaryIO, IO
+from typing import Any, IO
 from typing_extensions import Self

 from torch.utils._import_utils import import_dill

@@ -40,7 +40,7 @@ class ExecutionRecord:
         dill.dump(self, f)

     @classmethod
-    def load(cls, f: BinaryIO) -> Self:
+    def load(cls, f: IO[bytes]) -> Self:
         assert dill is not None, "replay_record requires `pip install dill`"
         return dill.load(f)
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import logging
 import os
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Dict, IO, List, Optional, Tuple, TYPE_CHECKING, Union

 import torch._inductor.config
 import torch.fx

@@ -13,7 +13,7 @@ import torch.fx
 if TYPE_CHECKING:
     from torch._inductor.utils import InputType
     from torch.export import ExportedProgram

+    from torch.types import FileLike

 __all__ = [
     "compile",

@@ -53,7 +53,7 @@ def aoti_compile_and_package(
     _deprecated_unused_args=None,
     _deprecated_unused_kwargs=None,
     *,
-    package_path: Optional[Union[str, io.BytesIO]] = None,
+    package_path: Optional[FileLike] = None,
     inductor_configs: Optional[dict[str, Any]] = None,
 ) -> str:
     """

@@ -105,8 +105,15 @@ def aoti_compile_and_package(

     assert (
         package_path is None
-        or isinstance(package_path, io.BytesIO)
-        or (isinstance(package_path, str) and package_path.endswith(".pt2"))
+        or (
+            isinstance(package_path, (io.IOBase, IO))
+            and package_path.writable()
+            and package_path.seekable()
+        )
+        or (
+            isinstance(package_path, (str, os.PathLike))
+            and os.fspath(package_path).endswith(".pt2")
+        )
     ), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"

     inductor_configs = inductor_configs or {}
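The `isinstance(package_path, (io.IOBase, IO))` tuple in the assert above is worth unpacking. A short sketch of why both classes appear (my reading, not code from this PR): standard-library streams such as `io.BytesIO` derive from `io.IOBase` but do not inherit `typing.IO`, while user-defined file objects may subclass `typing.IO` directly without touching the `io` hierarchy. Note also that only the bare `IO` works here; the subscripted `IO[bytes]` cannot be used with `isinstance`.

```python
import io
from typing import IO

buf = io.BytesIO()
print(isinstance(buf, io.IOBase))         # True: BytesIO is an io-module stream
print(isinstance(buf, IO))                # False: BytesIO does not inherit typing.IO
print(isinstance(buf, (io.IOBase, IO)))   # True: the tuple covers both cases
print(buf.writable() and buf.seekable())  # True: satisfies the buffer check above
# isinstance(buf, IO[bytes]) would raise TypeError: subscripted generics
# cannot be used with isinstance().
```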
@@ -207,7 +214,7 @@ def _aoti_compile_and_package_inner(
     return package_path


-def aoti_load_package(path: Union[str, io.BytesIO]) -> Any:  # type: ignore[type-arg]
+def aoti_load_package(path: FileLike) -> Any:  # type: ignore[type-arg]
     """
     Loads the model from the PT2 package.
@@ -13,7 +13,7 @@ import typing
 from concurrent.futures import Future, ProcessPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from enum import Enum
-from typing import Any, BinaryIO, Callable, Optional, TypeVar
+from typing import Any, Callable, IO, Optional, TypeVar
 from typing_extensions import Never, ParamSpec

 # _thread_safe_fork is needed because the subprocesses in the pool can read

@@ -43,7 +43,7 @@ def _unpack_msg(data: bytes) -> tuple[int, int]:
 msg_bytes = len(_pack_msg(0, 0))


-def _send_msg(write_pipe: BinaryIO, job_id: int, job_data: bytes = b"") -> None:
+def _send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None:
     length = len(job_data)
     write_pipe.write(_pack_msg(job_id, length))
     if length > 0:

@@ -51,7 +51,7 @@ def _send_msg(write_pipe: BinaryIO, job_id: int, job_data: bytes = b"") -> None:
         write_pipe.flush()


-def _recv_msg(read_pipe: BinaryIO) -> tuple[int, bytes]:
+def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]:
     job_id, length = _unpack_msg(read_pipe.read(msg_bytes))
     data = read_pipe.read(length) if length > 0 else b""
     return job_id, data

@@ -255,8 +255,8 @@ class SubprocMain:
         pickler: SubprocPickler,
         kind: SubprocKind,
         nprocs: int,
-        read_pipe: BinaryIO,
-        write_pipe: BinaryIO,
+        read_pipe: IO[bytes],
+        write_pipe: IO[bytes],
     ) -> None:
         self.pickler = pickler
         self.kind = kind
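As an aside, the `_send_msg`/`_recv_msg` pair above implements simple length-prefixed framing over the worker pipes, which is exactly the kind of stream the `IO[bytes]` annotation describes. A self-contained sketch of the same pattern (the header layout is an assumption; the real `_pack_msg` format is not shown in this diff):

```python
import io
import struct
from typing import IO

HEADER = struct.Struct("!qq")  # assumed layout: (job_id, payload_length)


def send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None:
    write_pipe.write(HEADER.pack(job_id, len(job_data)))
    if job_data:
        write_pipe.write(job_data)
    write_pipe.flush()


def recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]:
    job_id, length = HEADER.unpack(read_pipe.read(HEADER.size))
    data = read_pipe.read(length) if length > 0 else b""
    return job_id, data


pipe = io.BytesIO()  # stands in for the worker pipe
send_msg(pipe, 7, b"compiled kernel")
pipe.seek(0)
print(recv_msg(pipe))  # (7, b'compiled kernel')
```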
@@ -3,7 +3,6 @@ import contextlib
 import copy
 import dataclasses
 import functools
-import io
 import itertools
 import json
 import logging

@@ -25,6 +24,7 @@ from torch._dynamo.utils import get_debug_dir
 from torch.fx.graph_module import GraphModule
 from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch.fx.passes.tools_common import legalize_graph
+from torch.types import FileLike
 from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_map

@@ -728,7 +728,7 @@ def aot_inductor_minifier_wrapper(
     exported_program: torch.export.ExportedProgram,
     *,
     inductor_configs: dict[str, Any],
-    package_path: Optional[Union[str, io.BytesIO]] = None,
+    package_path: Optional[FileLike] = None,
 ) -> str:
     from torch._dynamo.debug_utils import AccuracyError
     from torch._dynamo.repro.aoti import dump_to_minify
@@ -7,7 +7,7 @@ import subprocess
 import tempfile
 import zipfile
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, IO, Optional, Union

 import torch
 import torch._inductor

@@ -15,6 +15,7 @@ import torch.utils._pytree as pytree
 from torch._inductor import exc
 from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
 from torch.export._tree_utils import reorder_kwargs
+from torch.types import FileLike

 from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION

@@ -23,8 +24,8 @@ log = logging.getLogger(__name__)


 class PT2ArchiveWriter:
-    def __init__(self, archive_path: Union[str, io.BytesIO]) -> None:
-        self.archive_path: Union[str, io.BytesIO] = archive_path
+    def __init__(self, archive_path: FileLike) -> None:
+        self.archive_path: FileLike = archive_path
         self.archive_file: Optional[zipfile.ZipFile] = None

     def __enter__(self) -> "PT2ArchiveWriter":

@@ -158,9 +159,9 @@ def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:


 def package_aoti(
-    archive_file: Union[str, io.BytesIO],
+    archive_file: FileLike,
     aoti_files: Union[list[str], dict[str, list[str]]],
-) -> Union[str, io.BytesIO]:
+) -> FileLike:
     """
     Saves the AOTInductor generated files to the PT2Archive format.

@@ -179,8 +180,13 @@ def package_aoti(
             "files. You can get this list of files through calling "
             "`torch._inductor.aot_compile(..., options={aot_inductor.package=True})`"
         )
-    assert isinstance(archive_file, io.BytesIO) or (
-        isinstance(archive_file, str) and archive_file.endswith(".pt2")
+    assert (
+        isinstance(archive_file, (io.IOBase, IO))
+        and archive_file.writable()
+        and archive_file.seekable()
+    ) or (
+        isinstance(archive_file, (str, os.PathLike))
+        and os.fspath(archive_file).endswith(".pt2")
     ), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"

     # Save using the PT2 packaging format

@@ -222,7 +228,7 @@ def package_aoti(
                 file,
             )

-    if isinstance(archive_file, io.BytesIO):
+    if isinstance(archive_file, (io.IOBase, IO)):
         archive_file.seek(0)
     return archive_file

@@ -268,19 +274,21 @@ class AOTICompiledModel:
     def get_constant_fqns(self) -> list[str]:
         return self.loader.get_constant_fqns()  # type: ignore[attr-defined]

-    def __deepcopy__(self, memo: Optional[Dict[Any, Any]]) -> "AOTICompiledModel":
+    def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel":
         log.warning(
             "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied."
         )
         return AOTICompiledModel(self.loader)  # type: ignore[attr-defined]


-def load_package(path: Union[str, io.BytesIO], model_name: str = "model") -> AOTICompiledModel:  # type: ignore[type-arg]
-    assert isinstance(path, io.BytesIO) or (
-        isinstance(path, str) and path.endswith(".pt2")
+def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel:  # type: ignore[type-arg]
+    assert (
+        isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
+    ) or (
+        isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")
     ), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"

-    if isinstance(path, io.BytesIO):
+    if isinstance(path, (io.IOBase, IO)):
         with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
             # TODO(angelayi): We shouldn't need to do this -- miniz should
             # handle reading the buffer. This is just a temporary workaround

@@ -290,5 +298,6 @@ def load_package(path: Union[str, io.BytesIO], model_name: str = "model") -> AOT
             loader = torch._C._aoti.AOTIModelPackageLoader(f.name, model_name)  # type: ignore[call-arg]
             return AOTICompiledModel(loader)

+    path = os.fspath(path)  # AOTIModelPackageLoader expects (str, str)
     loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name)  # type: ignore[call-arg]
     return AOTICompiledModel(loader)
@@ -2,7 +2,6 @@ import builtins
 import copy
 import dataclasses
 import inspect
-import io
 import os
 import sys
 import typing

@@ -27,6 +26,7 @@ import torch.utils._pytree as pytree
 from torch.fx._compatibility import compatibility
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.infra.pass_manager import PassManager
+from torch.types import FileLike
 from torch.utils._pytree import (
     FlattenFunc,
     FromDumpableContextFn,

@@ -381,7 +381,7 @@ DEFAULT_PICKLE_PROTOCOL = 2

 def save(
     ep: ExportedProgram,
-    f: Union[str, os.PathLike, io.BytesIO],
+    f: FileLike,
     *,
     extra_files: Optional[dict[str, Any]] = None,
     opset_version: Optional[dict[str, int]] = None,

@@ -399,7 +399,7 @@ def save(
     Args:
         ep (ExportedProgram): The exported program to save.

-        f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
+        f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to
            implement write and flush) or a string containing a file name.

         extra_files (Optional[Dict[str, Any]]): Map from filename to contents

@@ -464,7 +464,7 @@ def save(


 def load(
-    f: Union[str, os.PathLike, io.BytesIO],
+    f: FileLike,
     *,
     extra_files: Optional[dict[str, Any]] = None,
     expected_opset_version: Optional[dict[str, int]] = None,

@@ -479,7 +479,7 @@ def load(
     :func:`torch.export.save <torch.export.save>`.

     Args:
-        f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
+        f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to
            implement write and flush) or a string containing a file name.

         extra_files (Optional[Dict[str, Any]]): The extra filenames given in
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import logging
 import os
-from typing import TYPE_CHECKING
+from typing import IO, TYPE_CHECKING

 import torch
 from torch.onnx import _type_utils as jit_type_utils

@@ -13,6 +13,8 @@ from torch.onnx import _type_utils as jit_type_utils
 if TYPE_CHECKING:
     import onnx

+    from torch.types import FileLike
+
 log = logging.getLogger(__name__)

@@ -117,7 +119,7 @@ def save_model_with_external_data(
     basepath: str,
     model_location: str,
     initializer_location: str,
-    torch_state_dicts: tuple[dict | str | io.BytesIO, ...],
+    torch_state_dicts: tuple[dict | FileLike, ...],
     onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
     rename_initializer: bool = False,
 ) -> None:

@@ -165,7 +167,9 @@ def save_model_with_external_data(
             # Using torch.save wouldn't leverage mmap, leading to higher memory usage
             state_dict = el
         else:
-            if isinstance(el, str) and el.endswith(".safetensors"):
+            if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith(
+                ".safetensors"
+            ):
                 state_dict = _convert_safetensors_to_torch_format(el)
             else:
                 try:

@@ -173,14 +177,16 @@ def save_model_with_external_data(
                     # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
                     state_dict = torch.load(el, map_location="cpu", mmap=True)
                 except (RuntimeError, ValueError) as e:
-                    if "mmap can only be used with files saved with" in str(
-                        e
-                    ) or isinstance(el, io.BytesIO):
+                    if "mmap can only be used with files saved with" in str(e) or (
+                        isinstance(el, (io.IOBase, IO))
+                        and el.readable()
+                        and el.seekable()
+                    ):
                         log.warning(
                             "Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
                             "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
                         )
-                        if isinstance(el, io.BytesIO):
+                        if isinstance(el, (io.IOBase, IO)):
                             el.seek(0)  # torch.load from `try:` has read the file.
                         state_dict = torch.load(el, map_location="cpu")
                     else:
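The retry logic above follows a try-mmap-first pattern: attempt a memory-mapped load, and if the checkpoint predates mmap-compatible saving (or the source is an in-memory buffer), rewind and retry without mmap. A condensed, standalone sketch of that flow (the helper name `load_state_dict_lenient` is hypothetical, and the control flow is my simplification of the hunk):

```python
import io
from typing import IO, Union

import torch


def load_state_dict_lenient(el: Union[str, IO[bytes]]) -> dict:
    try:
        # Memory-mapped load keeps peak memory low for large checkpoints.
        return torch.load(el, map_location="cpu", mmap=True)
    except (RuntimeError, ValueError) as e:
        mmap_unsupported = "mmap can only be used with files saved with" in str(e)
        if not mmap_unsupported and not isinstance(el, (io.IOBase, IO)):
            raise
        if isinstance(el, (io.IOBase, IO)):
            el.seek(0)  # the failed attempt already consumed the stream
        return torch.load(el, map_location="cpu")
```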
@@ -3,6 +3,7 @@ import collections
 import importlib.machinery
 import io
 import linecache
+import os
 import pickletools
 import platform
 import types

@@ -12,11 +13,11 @@ from dataclasses import dataclass
 from enum import Enum
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
-from typing import Any, BinaryIO, Callable, cast, Optional, Union
+from typing import Any, Callable, cast, IO, Optional, Union

 import torch
 from torch.serialization import location_tag, normalize_storage_type
-from torch.types import Storage
+from torch.types import FileLike, Storage
 from torch.utils.hooks import RemovableHandle

 from ._digraph import DiGraph

@@ -201,10 +202,10 @@ class PackageExporter:

     def __init__(
         self,
-        f: Union[str, Path, BinaryIO],
+        f: FileLike,
         importer: Union[Importer, Sequence[Importer]] = sys_importer,
         debug: bool = False,
-    ):
+    ) -> None:
         """
         Create an exporter.

@@ -217,9 +218,9 @@ class PackageExporter:
         """
         torch._C._log_api_usage_once("torch.package.PackageExporter")
         self.debug = debug
-        if isinstance(f, (Path, str)):
-            f = str(f)
-            self.buffer: Optional[BinaryIO] = None
+        if isinstance(f, (str, os.PathLike)):
+            f = os.fspath(f)
+            self.buffer: Optional[IO[bytes]] = None
         else:  # is a byte buffer
             self.buffer = f
@@ -10,11 +10,12 @@ import sys
 import types
 from collections.abc import Iterable
 from contextlib import contextmanager
-from typing import Any, BinaryIO, Callable, cast, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
 from weakref import WeakValueDictionary

 import torch
 from torch.serialization import _get_restore_location, _maybe_decode_ascii
+from torch.types import FileLike

 from ._directory_reader import DirectoryReader
 from ._importlib import (

@@ -84,7 +85,7 @@ class PackageImporter(Importer):

     def __init__(
         self,
-        file_or_buffer: Union[str, torch._C.PyTorchFileReader, os.PathLike, BinaryIO],
+        file_or_buffer: Union[FileLike, torch._C.PyTorchFileReader],
         module_allowed: Callable[[str], bool] = lambda module_name: True,
     ):
         """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
@@ -15,7 +15,7 @@ import threading
 import warnings
 from contextlib import closing, contextmanager
 from enum import Enum
-from typing import Any, BinaryIO, Callable, cast, IO, Optional, Union
+from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
 from typing_extensions import TypeAlias, TypeIs

 import torch

@@ -23,7 +23,7 @@ import torch._weights_only_unpickler as _weights_only_unpickler
 from torch._sources import get_source_lines_and_file
 from torch._utils import _import_dotted_name
 from torch.storage import _get_dtype_from_pickle_storage_type
-from torch.types import Storage
+from torch.types import FileLike, Storage


 __all__ = [

@@ -65,7 +65,6 @@ MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
 PROTOCOL_VERSION = 1001
 STORAGE_KEY_SEPARATOR = ","

-FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
 MAP_LOCATION: TypeAlias = Optional[
     Union[Callable[[Storage, str], Storage], torch.device, str, dict[str, str]]
 ]

@@ -326,7 +325,7 @@ class safe_globals(_weights_only_unpickler._safe_globals):
     """


-def get_unsafe_globals_in_checkpoint(f: FILE_LIKE) -> list[str]:
+def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
     """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``.

     For a given function or class ``f``, the corresponding string will be of the form
@@ -702,13 +701,16 @@ def storage_to_tensor_type(storage):
     return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))


-def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]:
+def _is_path(name_or_buffer: object) -> TypeIs[Union[str, os.PathLike]]:
     return isinstance(name_or_buffer, (str, os.PathLike))


-class _opener:
-    def __init__(self, file_like):
-        self.file_like = file_like
+T = TypeVar("T")
+
+
+class _opener(Generic[T]):
+    def __init__(self, file_like: T) -> None:
+        self.file_like: T = file_like

     def __enter__(self):
         return self.file_like
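Making `_opener` generic, as the hunk above does, lets each context manager advertise the precise type it yields instead of an untyped resource. A minimal standalone sketch of the pattern (with a typed `__enter__` added for clarity; the diff itself leaves `__enter__` unannotated):

```python
from typing import Generic, IO, TypeVar

T = TypeVar("T")


class _opener(Generic[T]):
    """Context manager that hands back whatever resource it was given."""

    def __init__(self, file_like: T) -> None:
        self.file_like: T = file_like

    def __enter__(self) -> T:
        return self.file_like

    def __exit__(self, *args: object) -> None:
        pass


class _open_buffer_reader(_opener[IO[bytes]]):
    """Pins T to IO[bytes]: `with _open_buffer_reader(buf) as f:` types f as IO[bytes]."""

    def __init__(self, buffer: IO[bytes]) -> None:
        super().__init__(buffer)
```

Without the type parameter, subclasses wrapping different resources (buffers vs. `PyTorchFileReader`/`PyTorchFileWriter` handles) would have to narrow the base class's attribute types, which is the kind of Liskov-substitution violation the commit message mentions.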
@@ -717,26 +719,26 @@ class _opener:
         pass


-class _open_file(_opener):
-    def __init__(self, name, mode):
+class _open_file(_opener[IO[bytes]]):
+    def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None:
         super().__init__(open(name, mode))

     def __exit__(self, *args):
         self.file_like.close()


-class _open_buffer_reader(_opener):
-    def __init__(self, buffer):
+class _open_buffer_reader(_opener[IO[bytes]]):
+    def __init__(self, buffer: IO[bytes]) -> None:
         super().__init__(buffer)
         _check_seekable(buffer)


-class _open_buffer_writer(_opener):
+class _open_buffer_writer(_opener[IO[bytes]]):
     def __exit__(self, *args):
         self.file_like.flush()


-def _open_file_like(name_or_buffer, mode):
+def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
     if _is_path(name_or_buffer):
         return _open_file(name_or_buffer, mode)
     else:

@@ -748,15 +750,15 @@ def _open_file_like(name_or_buffer, mode):
         raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")


-class _open_zipfile_reader(_opener):
-    def __init__(self, name_or_buffer) -> None:
+class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
+    def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
         super().__init__(torch._C.PyTorchFileReader(name_or_buffer))


-class _open_zipfile_writer_file(_opener):
-    def __init__(self, name) -> None:
+class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
+    def __init__(self, name: str) -> None:
         self.file_stream = None
-        self.name = str(name)
+        self.name = name
         try:
             self.name.encode("ascii")
         except UnicodeEncodeError:

@@ -776,8 +778,8 @@ class _open_zipfile_writer_file(_opener):
             self.file_stream.close()


-class _open_zipfile_writer_buffer(_opener):
-    def __init__(self, buffer) -> None:
+class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
+    def __init__(self, buffer: IO[bytes]) -> None:
         if not callable(getattr(buffer, "write", None)):
             msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
             if not hasattr(buffer, "write"):

@@ -791,7 +793,7 @@ class _open_zipfile_writer_buffer(_opener):
         self.buffer.flush()


-def _open_zipfile_writer(name_or_buffer):
+def _open_zipfile_writer(name_or_buffer: Union[str, IO[bytes]]) -> _opener:
     container: type[_opener]
     if _is_path(name_or_buffer):
         container = _open_zipfile_writer_file

@@ -879,7 +881,7 @@ def _check_save_filelike(f):

 def save(
     obj: object,
-    f: FILE_LIKE,
+    f: FileLike,
     pickle_module: Any = pickle,
     pickle_protocol: int = DEFAULT_PROTOCOL,
     _use_new_zipfile_serialization: bool = True,

@@ -929,6 +931,9 @@ def save(
     _check_dill_version(pickle_module)
     _check_save_filelike(f)

+    if isinstance(f, (str, os.PathLike)):
+        f = os.fspath(f)
+
     if _use_new_zipfile_serialization:
         with _open_zipfile_writer(f) as opened_zipfile:
             _save(

@@ -1222,7 +1227,7 @@ def _save(


 def load(
-    f: FILE_LIKE,
+    f: FileLike,
     map_location: MAP_LOCATION = None,
     pickle_module: Any = None,
     *,
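With `f: FileLike`, `torch.save` and `torch.load` are now annotated to accept a path string, an `os.PathLike` (normalized via `os.fspath` in the hunk above), or an open binary buffer. A quick usage sketch:

```python
import io
from pathlib import Path

import torch

state = {"weight": torch.zeros(3)}

torch.save(state, "model.pt")        # str
torch.save(state, Path("model.pt"))  # os.PathLike[str]

buf = io.BytesIO()                   # IO[bytes]
torch.save(state, buf)
buf.seek(0)
print(torch.load(buf, weights_only=True))
```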
@@ -4,6 +4,7 @@
 # top-level values. The underscore variants let us refer to these
 # types. See https://github.com/python/mypy/issues/4146 for why these
 # workarounds is necessary
+import os
 from builtins import (  # noqa: F401
     bool as _bool,
     bytes as _bytes,

@@ -13,7 +14,7 @@ from builtins import (  # noqa: F401
     str as _str,
 )
 from collections.abc import Sequence
-from typing import Any, TYPE_CHECKING, Union
+from typing import Any, IO, TYPE_CHECKING, Union
 from typing_extensions import TypeAlias

 # `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType`

@@ -35,7 +36,7 @@ if TYPE_CHECKING:
     from torch.autograd.graph import GradientEdge


-__all__ = ["Number", "Device", "Storage"]
+__all__ = ["Number", "Device", "FileLike", "Storage"]

 # Convenience aliases for common composite types that we need
 # to talk about in PyTorch

@@ -64,6 +65,8 @@ PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool]
 # Meta-type for "numeric" things; matches our docs
 Number: TypeAlias = Union[int, float, bool]

+FileLike: TypeAlias = Union[str, os.PathLike[str], IO[bytes]]
+
 # Meta-type for "device-like" things. Not to be confused with 'device' (a
 # literal device object). This nomenclature is consistent with PythonArgParser.
 # None means use the default device (typically CPU)
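For downstream code, the new alias defined above can be consumed with a path/buffer split like the following (a sketch; `dump_bytes` is a hypothetical helper, not part of this PR):

```python
import io
import os

from torch.types import FileLike


def dump_bytes(f: FileLike, data: bytes) -> None:
    if isinstance(f, (str, os.PathLike)):
        with open(os.fspath(f), "wb") as fh:  # path branch: open the file ourselves
            fh.write(data)
    else:
        f.write(data)  # buffer branch: an already-open IO[bytes]


dump_bytes("out.bin", b"\x00\x01")

buf = io.BytesIO()
dump_bytes(buf, b"\x00\x01")
```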
@@ -6,7 +6,7 @@ import struct
 import pprint
 import zipfile
 import fnmatch
-from typing import Any, IO, BinaryIO, Union
+from typing import Any, IO

 __all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"]

@@ -119,7 +119,7 @@ def main(argv, output_stream=None):
         return 2

     fname = argv[1]
-    handle: Union[IO[bytes], BinaryIO]
+    handle: IO[bytes]
     if "@" not in fname:
         with open(fname, "rb") as handle:
             DumpUnpickler.dump(handle, output_stream)