Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Make torch.serialization.set_default_mmap_options usable as a context manager (#134371)
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134371
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 0fa0ac80e4
Commit: 2ac710e667
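Based on the change below, `set_default_mmap_options` keeps working as a plain call and can now also scope the default with a `with` block. A minimal usage sketch (not part of the diff; the state dict and temp file are illustrative, and per the diff changing the default mmap options is unsupported on Windows):

```python
import mmap
import tempfile

import torch

sd = torch.nn.Linear(3, 5).state_dict()
with tempfile.NamedTemporaryFile() as f:
    torch.save(sd, f)

    # Function-style usage (unchanged): sets the process-wide default until changed again.
    torch.serialization.set_default_mmap_options(mmap.MAP_SHARED)
    shared = torch.load(f.name, mmap=True, weights_only=True)
    torch.serialization.set_default_mmap_options(mmap.MAP_PRIVATE)

    # New context-manager usage: the previous default is restored on exit.
    with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED):
        shared = torch.load(f.name, mmap=True, weights_only=True)
    assert torch.serialization.get_default_mmap_options() == mmap.MAP_PRIVATE
```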
test/test_serialization.py
@@ -4052,7 +4052,7 @@ class TestSerialization(TestCase, SerializationMixin):
     @parametrize('path_type', (str, Path))
     @parametrize('weights_only', (True, False))
     @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
-    def test_serialization_mmap_loading(self, weights_only, path_type):
+    def test_serialization_mmap_loading_options(self, weights_only, path_type):
         class DummyModel(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -4101,7 +4101,7 @@ class TestSerialization(TestCase, SerializationMixin):
             for v in result.values():
                 self.assertTrue(v.is_cuda)

-    def test_serialization_mmap_loading_options(self):
+    def test_serialization_mmap_loading(self):
         if IS_WINDOWS:
             with self.assertRaisesRegex(RuntimeError, "Changing the default mmap options is currently not supported"):
                 torch.serialization.set_default_mmap_options(2)
@@ -4111,22 +4111,36 @@ class TestSerialization(TestCase, SerializationMixin):
         with tempfile.NamedTemporaryFile() as f:
             torch.save(sd, f)
             # with MmapVisibility.MAP_PRIVATE, should not be able to modify file
-            sd_loaded = torch.load(f.name, mmap=True)
+            sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
             sd_loaded['weight'][0][0] = 0
-            sd_loaded2 = torch.load(f.name, mmap=True)
+            sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
             self.assertEqual(sd_loaded2['weight'], sd['weight'])
             # with MmapVisibility.MAP_SHARED, should be able to modify file
             torch.serialization.set_default_mmap_options(MAP_SHARED)
             try:
-                sd_loaded = torch.load(f.name, mmap=True)
+                sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
                 sd_loaded['weight'][0][0] = 0
-                sd_loaded2 = torch.load(f.name, mmap=True)
+                sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
                 self.assertNotEqual(sd_loaded2['weight'], sd['weight'])
                 self.assertEqual(sd_loaded2['weight'][0][0].item(), 0)
                 self.assertEqual(sd_loaded2['weight'], sd_loaded['weight'])
             finally:
                 torch.serialization.set_default_mmap_options(MAP_PRIVATE)

+    @unittest.skipIf(IS_WINDOWS, "mmap ctx doesn't work on Windows")
+    def test_serialization_mmap_loading_ctx(self):
+        sd = torch.nn.Linear(3, 5).state_dict()
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(sd, f)
+            with torch.serialization.set_default_mmap_options(MAP_SHARED):
+                sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
+                sd_loaded['weight'][0][0] = 0
+                sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
+                self.assertNotEqual(sd_loaded2['weight'], sd['weight'])
+                self.assertEqual(sd_loaded2['weight'][0][0].item(), 0)
+                self.assertEqual(sd_loaded2['weight'], sd_loaded['weight'])
+        self.assertTrue(torch.serialization.get_default_mmap_options() == MAP_PRIVATE)
+
     @parametrize('dtype', (torch.float8_e5m2, torch.float8_e4m3fn, torch.complex32))
     @parametrize('weights_only', (True, False))
     def test_serialization_dtype(self, dtype, weights_only):
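The comments in the test above rely on standard mmap visibility semantics: a MAP_PRIVATE mapping is copy-on-write, so writes through the loaded tensors never reach the checkpoint file, while MAP_SHARED writes through to it. A standalone sketch of that behaviour with Python's `mmap` module (POSIX only, independent of torch):

```python
import mmap
import tempfile

with tempfile.NamedTemporaryFile() as f:
    f.write(b"aaaa")
    f.flush()

    # MAP_PRIVATE: copy-on-write view, the underlying file is untouched.
    with mmap.mmap(f.fileno(), 4, flags=mmap.MAP_PRIVATE) as m:
        m[0:1] = b"b"
    f.seek(0)
    assert f.read() == b"aaaa"

    # MAP_SHARED: writes are carried through to the file.
    with mmap.mmap(f.fileno(), 4, flags=mmap.MAP_SHARED) as m:
        m[0:1] = b"b"
    f.seek(0)
    assert f.read() == b"baaa"
```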
torch/serialization.py
@@ -54,6 +54,8 @@ __all__ = [
     "LoadEndianness",
     "get_default_load_endianness",
     "set_default_load_endianness",
+    "get_default_mmap_options",
+    "set_default_mmap_options",
     "clear_safe_globals",
     "get_safe_globals",
     "add_safe_globals",
@@ -163,9 +165,9 @@ def get_default_mmap_options() -> int:
     return _default_mmap_options


-def set_default_mmap_options(flags: int):
+class set_default_mmap_options:
     """
-    Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
+    Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.

     For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
     Please open an issue if you need any other option to be added here.
@@ -176,17 +178,27 @@ def set_default_mmap_options(flags: int):
     Args:
         flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
     """
-    global _default_mmap_options
-    if IS_WINDOWS:
-        raise RuntimeError(
-            "Changing the default mmap options is currently not supported for Windows"
-        )
-    if flags != MAP_PRIVATE and flags != MAP_SHARED:
-        raise ValueError(
-            "Invalid argument in function set_default_mmap_options, "
-            f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
-        )
-    _default_mmap_options = flags
+
+    def __init__(self, flags: int) -> None:
+        if IS_WINDOWS:
+            raise RuntimeError(
+                "Changing the default mmap options is currently not supported for Windows"
+            )
+        if flags != MAP_PRIVATE and flags != MAP_SHARED:
+            raise ValueError(
+                "Invalid argument in function set_default_mmap_options, "
+                f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
+            )
+        global _default_mmap_options
+        self.prev = _default_mmap_options
+        _default_mmap_options = flags
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        global _default_mmap_options
+        _default_mmap_options = self.prev


 def clear_safe_globals() -> None:
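The class above follows a compact pattern: the module-level default is swapped in `__init__` (so a bare call still behaves like the old function), `__enter__` does nothing, and `__exit__` restores the saved value, which makes the `with` form exception-safe where the old test needed try/finally. A generic sketch of the same pattern (the verbosity names are illustrative, not part of torch):

```python
from typing import Any

_default_verbosity = 0  # module-level default, analogous to _default_mmap_options


class set_default_verbosity:
    """Plain call sets the default; used as a context manager it also restores it on exit."""

    def __init__(self, level: int) -> None:
        global _default_verbosity
        self.prev = _default_verbosity
        _default_verbosity = level  # applied immediately, like the old function API

    def __enter__(self) -> None:
        pass  # the work already happened in __init__

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _default_verbosity
        _default_verbosity = self.prev  # restore even if the body raised


set_default_verbosity(2)        # function-style: sticks until changed again
assert _default_verbosity == 2

with set_default_verbosity(5):  # context-manager style: scoped change
    assert _default_verbosity == 5
assert _default_verbosity == 2  # previous default restored
```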