Refactor serialization getter/setters into torch.utils.serialization.config (#143324)

Consolidate
- get/set_default_load_endianness
- get/set_default_mmap_options
- get/set_crc32_options

into one global dynamo-style config, and allow ``mmap`` to be set globally. The existing APIs are not removed (they must stay for backward compatibility); they now get/set the corresponding values in the config.

In #143459 I add the local (argument style) config

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143324
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-12-20 08:12:31 -08:00
committed by PyTorch MergeBot
parent 629de988df
commit 3f63b742e6
4 changed files with 78 additions and 25 deletions

View File

@ -487,3 +487,29 @@ The following utility functions are related to serialization:
.. autofunction:: get_unsafe_globals_in_checkpoint
.. autoclass:: safe_globals
.. autoclass:: skip_data
.. _serialization config:
Config
------
.. py:module:: torch.utils.serialization
.. py:module:: torch.utils.serialization.config
``torch.utils.serialization.config`` provides a global config that can control the behavior of
``torch.save`` and ``torch.load``.
``torch.utils.serialization.config.save`` contains options that control the behavior of ``torch.save``.
* ``compute_crc32``: whether to compute and write the zip file checksum (Default : ``True``).
See :func:`~torch.serialization.set_crc32_options`.
``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``.
* ``mmap``: See the documentation for ``mmap`` argument in :func:`torch.load`.
This config will set the behavior of ``mmap`` for ``torch.load`` if it is not
already explicitly passed to the ``torch.load`` call (Default : ``False``).
* ``endianness``: See :func:`~torch.serialization.set_default_load_endianness`.
(Default : ``torch.serialization.LoadEndianness.NATIVE``)
* ``mmap_flags``: See :class:`~torch.serialization.set_default_mmap_options`.
(Default : ``MAP_PRIVATE``; ``None`` on Windows, i.e. when ``sys.platform == "win32"``)

View File

@ -140,9 +140,6 @@ class LoadEndianness(Enum):
BIG = 3
_default_load_endian: Optional[LoadEndianness] = None
def get_default_load_endianness() -> Optional[LoadEndianness]:
"""
Get fallback byte order for loading files
@ -154,7 +151,9 @@ def get_default_load_endianness() -> Optional[LoadEndianness]:
Returns:
default_load_endian: Optional[LoadEndianness]
"""
return _default_load_endian
from torch.utils.serialization import config
return config.load.endianness
def set_default_load_endianness(endianness):
@ -168,13 +167,11 @@ def set_default_load_endianness(endianness):
Args:
endianness: the new fallback byte order
"""
global _default_load_endian
if not isinstance(endianness, LoadEndianness) and endianness is not None:
raise TypeError("Invalid argument type in function set_default_load_endianness")
_default_load_endian = endianness
from torch.utils.serialization import config
_compute_crc32: bool = True
config.load.endianness = endianness
def get_crc32_options() -> bool:
@ -183,7 +180,9 @@ def get_crc32_options() -> bool:
Defaults to ``True``.
"""
return _compute_crc32
from torch.utils.serialization import config
return config.save.compute_crc32
def set_crc32_options(compute_crc32: bool):
@ -198,14 +197,12 @@ def set_crc32_options(compute_crc32: bool):
Args:
compute_crc32 (bool): set crc32 computation flag
"""
global _compute_crc32
_compute_crc32 = compute_crc32
from torch.utils.serialization import config
config.save.compute_crc32 = compute_crc32
_default_mmap_options: int = MAP_PRIVATE
def get_default_mmap_options() -> int:
def get_default_mmap_options() -> Optional[int]:
"""
Get default mmap options for :func:`torch.load` with ``mmap=True``.
@ -215,7 +212,9 @@ def get_default_mmap_options() -> int:
Returns:
default_mmap_options: Optional[int]
"""
return _default_mmap_options
from torch.utils.serialization import config
return config.load.mmap_flags
class set_default_mmap_options:
@ -242,16 +241,19 @@ class set_default_mmap_options:
"Invalid argument in function set_default_mmap_options, "
f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
)
global _default_mmap_options
self.prev = _default_mmap_options
_default_mmap_options = flags
# global config
from torch.utils.serialization import config
self.prev = config.load.mmap_flags
config.load.mmap_flags = flags
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _default_mmap_options
_default_mmap_options = self.prev
from torch.utils.serialization import config
config.load.mmap_flags = self.prev
def clear_safe_globals() -> None:
@ -768,10 +770,10 @@ class _open_zipfile_writer_file(_opener):
# for writing out the file.
self.file_stream = io.FileIO(self.name, mode="w")
super().__init__(
torch._C.PyTorchFileWriter(self.file_stream, _compute_crc32)
torch._C.PyTorchFileWriter(self.file_stream, get_crc32_options())
)
else:
super().__init__(torch._C.PyTorchFileWriter(self.name, _compute_crc32))
super().__init__(torch._C.PyTorchFileWriter(self.name, get_crc32_options()))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
@ -787,7 +789,7 @@ class _open_zipfile_writer_buffer(_opener):
raise AttributeError(msg)
raise TypeError(msg)
self.buffer = buffer
super().__init__(torch._C.PyTorchFileWriter(buffer, _compute_crc32))
super().__init__(torch._C.PyTorchFileWriter(buffer, get_crc32_options()))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
@ -1414,7 +1416,9 @@ def load(
# make flipping default BC-compatible
if mmap is None:
mmap = False
from torch.utils.serialization import config
mmap = config.load.mmap
_check_dill_version(pickle_module)

View File

@ -0,0 +1 @@
from . import config

View File

@ -0,0 +1,22 @@
import sys
from typing import Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING
if _TYPE_CHECKING:
from torch.serialization import LoadEndianness as _LoadEndianess
from torch.utils._config_module import install_config_module as _install_config_module
class load:
# Global defaults consulted on the ``torch.load`` side of serialization.
#
# Value used for ``torch.load``'s ``mmap`` argument when the caller leaves it as ``None``.
mmap: bool = False
# Fallback byte order returned by ``torch.serialization.get_default_load_endianness``
# and written by ``torch.serialization.set_default_load_endianness``.
endianness: _Optional["_LoadEndianess"] = None
# MAP_PRIVATE = 2
# mmap flags returned by ``torch.serialization.get_default_mmap_options`` and
# written by ``torch.serialization.set_default_mmap_options``; ``None`` on Windows
# (``sys.platform == "win32"``), otherwise the integer 2 noted above.
mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2
class save:
# Global defaults consulted on the ``torch.save`` side of serialization.
#
# Flag passed to ``torch._C.PyTorchFileWriter`` to control zip checksum computation;
# returned by ``torch.serialization.get_crc32_options`` and written by
# ``torch.serialization.set_crc32_options``.
compute_crc32: bool = True
_install_config_module(sys.modules[__name__])