mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
Refactor serialization getter/setters into torch.utils.serialization.config (#143324)
Consolidate - get/set_default_load_endianness - get/set_default_mmap_options - get/set_crc32_options into one global dynamo-style config + allow global setting of mmap. The existing APIs are not removed and will get/set from the config (as they can't be removed for BC) In #143459 I add the local (argument style) config Pull Request resolved: https://github.com/pytorch/pytorch/pull/143324 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
629de988df
commit
3f63b742e6
@ -487,3 +487,29 @@ The following utility functions are related to serialization:
|
||||
.. autofunction:: get_unsafe_globals_in_checkpoint
|
||||
.. autoclass:: safe_globals
|
||||
.. autoclass:: skip_data
|
||||
|
||||
.. _serialization config:
|
||||
|
||||
Config
|
||||
------
|
||||
.. py:module:: torch.utils.serialization
|
||||
.. py:module:: torch.utils.serialization.config
|
||||
|
||||
``torch.utils.serialization.config`` provides a global config that can control the behavior of
|
||||
``torch.save`` and ``torch.load``.
|
||||
|
||||
|
||||
``torch.utils.serialization.config.save`` contains options that control the behavior of ``torch.save``.
|
||||
|
||||
* ``compute_crc32``: whether to compute and write the zip file checksum (Default : ``True``).
|
||||
See :func:`~torch.serialization.set_crc32_options`.
|
||||
|
||||
``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``.
|
||||
|
||||
* ``mmap``: See the documentation for ``mmap`` argument in :func:`torch.load`.
|
||||
This config will set the behavior of ``mmap`` for ``torch.load`` if it is not
|
||||
already explicitly passed to the ``torch.load`` call (Default : ``False``).
|
||||
* ``endianness``: See :func:`~torch.serialization.set_default_load_endianness`.
|
||||
(Default : ``torch.serialization.LoadEndianness.NATIVE``)
|
||||
* ``mmap_flags``: See :class:`~torch.serialization.set_default_mmap_options`.
|
||||
(Default : ``MAP_PRIVATE``)
|
||||
|
||||
@ -140,9 +140,6 @@ class LoadEndianness(Enum):
|
||||
BIG = 3
|
||||
|
||||
|
||||
_default_load_endian: Optional[LoadEndianness] = None
|
||||
|
||||
|
||||
def get_default_load_endianness() -> Optional[LoadEndianness]:
|
||||
"""
|
||||
Get fallback byte order for loading files
|
||||
@ -154,7 +151,9 @@ def get_default_load_endianness() -> Optional[LoadEndianness]:
|
||||
Returns:
|
||||
default_load_endian: Optional[LoadEndianness]
|
||||
"""
|
||||
return _default_load_endian
|
||||
from torch.utils.serialization import config
|
||||
|
||||
return config.load.endianness
|
||||
|
||||
|
||||
def set_default_load_endianness(endianness):
|
||||
@ -168,13 +167,11 @@ def set_default_load_endianness(endianness):
|
||||
Args:
|
||||
endianness: the new fallback byte order
|
||||
"""
|
||||
global _default_load_endian
|
||||
if not isinstance(endianness, LoadEndianness) and endianness is not None:
|
||||
raise TypeError("Invalid argument type in function set_default_load_endianness")
|
||||
_default_load_endian = endianness
|
||||
from torch.utils.serialization import config
|
||||
|
||||
|
||||
_compute_crc32: bool = True
|
||||
config.load.endianness = endianness
|
||||
|
||||
|
||||
def get_crc32_options() -> bool:
|
||||
@ -183,7 +180,9 @@ def get_crc32_options() -> bool:
|
||||
|
||||
Defaults to ``True``.
|
||||
"""
|
||||
return _compute_crc32
|
||||
from torch.utils.serialization import config
|
||||
|
||||
return config.save.compute_crc32
|
||||
|
||||
|
||||
def set_crc32_options(compute_crc32: bool):
|
||||
@ -198,14 +197,12 @@ def set_crc32_options(compute_crc32: bool):
|
||||
Args:
|
||||
compute_crc32 (bool): set crc32 compuation flag
|
||||
"""
|
||||
global _compute_crc32
|
||||
_compute_crc32 = compute_crc32
|
||||
from torch.utils.serialization import config
|
||||
|
||||
config.save.compute_crc32 = compute_crc32
|
||||
|
||||
|
||||
_default_mmap_options: int = MAP_PRIVATE
|
||||
|
||||
|
||||
def get_default_mmap_options() -> int:
|
||||
def get_default_mmap_options() -> Optional[int]:
|
||||
"""
|
||||
Get default mmap options for :func:`torch.load` with ``mmap=True``.
|
||||
|
||||
@ -215,7 +212,9 @@ def get_default_mmap_options() -> int:
|
||||
Returns:
|
||||
default_mmap_options: int
|
||||
"""
|
||||
return _default_mmap_options
|
||||
from torch.utils.serialization import config
|
||||
|
||||
return config.load.mmap_flags
|
||||
|
||||
|
||||
class set_default_mmap_options:
|
||||
@ -242,16 +241,19 @@ class set_default_mmap_options:
|
||||
"Invalid argument in function set_default_mmap_options, "
|
||||
f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
|
||||
)
|
||||
global _default_mmap_options
|
||||
self.prev = _default_mmap_options
|
||||
_default_mmap_options = flags
|
||||
# global config
|
||||
from torch.utils.serialization import config
|
||||
|
||||
self.prev = config.load.mmap_flags
|
||||
config.load.mmap_flags = flags
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
global _default_mmap_options
|
||||
_default_mmap_options = self.prev
|
||||
from torch.utils.serialization import config
|
||||
|
||||
config.load.mmap_flags = self.prev
|
||||
|
||||
|
||||
def clear_safe_globals() -> None:
|
||||
@ -768,10 +770,10 @@ class _open_zipfile_writer_file(_opener):
|
||||
# for writing out the file.
|
||||
self.file_stream = io.FileIO(self.name, mode="w")
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(self.file_stream, _compute_crc32)
|
||||
torch._C.PyTorchFileWriter(self.file_stream, get_crc32_options())
|
||||
)
|
||||
else:
|
||||
super().__init__(torch._C.PyTorchFileWriter(self.name, _compute_crc32))
|
||||
super().__init__(torch._C.PyTorchFileWriter(self.name, get_crc32_options()))
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
@ -787,7 +789,7 @@ class _open_zipfile_writer_buffer(_opener):
|
||||
raise AttributeError(msg)
|
||||
raise TypeError(msg)
|
||||
self.buffer = buffer
|
||||
super().__init__(torch._C.PyTorchFileWriter(buffer, _compute_crc32))
|
||||
super().__init__(torch._C.PyTorchFileWriter(buffer, get_crc32_options()))
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
@ -1414,7 +1416,9 @@ def load(
|
||||
|
||||
# make flipping default BC-compatible
|
||||
if mmap is None:
|
||||
mmap = False
|
||||
from torch.utils.serialization import config
|
||||
|
||||
mmap = config.load.mmap
|
||||
|
||||
_check_dill_version(pickle_module)
|
||||
|
||||
|
||||
1
torch/utils/serialization/__init__.py
Normal file
1
torch/utils/serialization/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from . import config
|
||||
22
torch/utils/serialization/config.py
Normal file
22
torch/utils/serialization/config.py
Normal file
@ -0,0 +1,22 @@
|
||||
import sys
|
||||
from typing import Optional as _Optional, TYPE_CHECKING as _TYPE_CHECKING
|
||||
|
||||
|
||||
if _TYPE_CHECKING:
|
||||
from torch.serialization import LoadEndianness as _LoadEndianess
|
||||
|
||||
from torch.utils._config_module import install_config_module as _install_config_module
|
||||
|
||||
|
||||
class load:
|
||||
mmap: bool = False
|
||||
endianness: _Optional["_LoadEndianess"] = None
|
||||
# MAP_PRIVATE = 2
|
||||
mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2
|
||||
|
||||
|
||||
class save:
|
||||
compute_crc32: bool = True
|
||||
|
||||
|
||||
_install_config_module(sys.modules[__name__])
|
||||
Reference in New Issue
Block a user