Use typing.IO[bytes] instead of io.BytesIO in annotations (#144994)

Fixes #144976

Using appoach ① `IO[bytes]`, but could also try with a protocol.

## Notes:

- moved `torch.serialization.FILE_LIKE` to `torch.types.FileLike`
- Use `FileLike` annotation where it makes sense
- made sure those functions also support `os.PathLike`
- Replaced `isinstance(x, io.BytesIO)` with `isinstance(x, (io.IOBase, IO))` where appropriate.
- Replaced `BinaryIO` with `IO[bytes]` (the two ABCs are almost identical, the only difference is that `BinaryIO` allows `bytearray` input to `write`, whereas `IO[bytes]` only `bytes`)
- needed to make `torch.serialization._opener` generic to avoid LSP violations.
- skipped `torch/onnx/verification` for now (functions use `BytesIO.getvalue` which is not part of the `IO[bytes]` ABC, but it kind of seems that this is redundant, as e.g. `onnx.load` supports `str | PathLike[str] | IO[bytes]` directly...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144994
Approved by: https://github.com/ezyang, https://github.com/Skylion007
This commit is contained in:
Randolf Scholz
2025-01-27 18:08:05 +00:00
committed by PyTorch MergeBot
parent abf28982a8
commit 835e770bad
14 changed files with 120 additions and 87 deletions

View File

@ -2,7 +2,6 @@ import builtins
import copy
import dataclasses
import inspect
import io
import os
import sys
import typing
@ -27,6 +26,7 @@ import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.types import FileLike
from torch.utils._pytree import (
FlattenFunc,
FromDumpableContextFn,
@ -381,7 +381,7 @@ DEFAULT_PICKLE_PROTOCOL = 2
def save(
ep: ExportedProgram,
f: Union[str, os.PathLike, io.BytesIO],
f: FileLike,
*,
extra_files: Optional[dict[str, Any]] = None,
opset_version: Optional[dict[str, int]] = None,
@ -399,7 +399,7 @@ def save(
Args:
ep (ExportedProgram): The exported program to save.
f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to
implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): Map from filename to contents
@ -464,7 +464,7 @@ def save(
def load(
f: Union[str, os.PathLike, io.BytesIO],
f: FileLike,
*,
extra_files: Optional[dict[str, Any]] = None,
expected_opset_version: Optional[dict[str, int]] = None,
@ -479,7 +479,7 @@ def load(
:func:`torch.export.save <torch.export.save>`.
Args:
f (Union[str, os.PathLike, io.BytesIO): A file-like object (has to
f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to
implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): The extra filenames given in