Make torch.serialization.set_default_mmap_options usable as a context manager (#134371)

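As the title says: `torch.serialization.set_default_mmap_options` is changed from a plain function into a class, so it can still be called directly to set the default mmap flags for `torch.load(..., mmap=True)`, and can now also be used in a `with` statement, in which case the previous default is restored on exit.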

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134371
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-08-26 19:19:32 -07:00
committed by PyTorch MergeBot
parent 0fa0ac80e4
commit 2ac710e667
2 changed files with 45 additions and 19 deletions

View File

@ -4052,7 +4052,7 @@ class TestSerialization(TestCase, SerializationMixin):
@parametrize('path_type', (str, Path))
@parametrize('weights_only', (True, False))
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
def test_serialization_mmap_loading(self, weights_only, path_type):
def test_serialization_mmap_loading_options(self, weights_only, path_type):
class DummyModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -4101,7 +4101,7 @@ class TestSerialization(TestCase, SerializationMixin):
for v in result.values():
self.assertTrue(v.is_cuda)
def test_serialization_mmap_loading_options(self):
def test_serialization_mmap_loading(self):
if IS_WINDOWS:
with self.assertRaisesRegex(RuntimeError, "Changing the default mmap options is currently not supported"):
torch.serialization.set_default_mmap_options(2)
@ -4111,22 +4111,36 @@ class TestSerialization(TestCase, SerializationMixin):
with tempfile.NamedTemporaryFile() as f:
torch.save(sd, f)
# with MmapVisibility.MAP_PRIVATE, should not be able to modify file
sd_loaded = torch.load(f.name, mmap=True)
sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
sd_loaded['weight'][0][0] = 0
sd_loaded2 = torch.load(f.name, mmap=True)
sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
self.assertEqual(sd_loaded2['weight'], sd['weight'])
# with MmapVisibility.MAP_SHARED, should be able to modify file
torch.serialization.set_default_mmap_options(MAP_SHARED)
try:
sd_loaded = torch.load(f.name, mmap=True)
sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
sd_loaded['weight'][0][0] = 0
sd_loaded2 = torch.load(f.name, mmap=True)
sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
self.assertNotEqual(sd_loaded2['weight'], sd['weight'])
self.assertEqual(sd_loaded2['weight'][0][0].item(), 0)
self.assertEqual(sd_loaded2['weight'], sd_loaded['weight'])
finally:
torch.serialization.set_default_mmap_options(MAP_PRIVATE)
@unittest.skipIf(IS_WINDOWS, "mmap ctx doesn't work on Windows")
def test_serialization_mmap_loading_ctx(self):
sd = torch.nn.Linear(3, 5).state_dict()
with tempfile.NamedTemporaryFile() as f:
torch.save(sd, f)
with torch.serialization.set_default_mmap_options(MAP_SHARED):
sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
sd_loaded['weight'][0][0] = 0
sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
self.assertNotEqual(sd_loaded2['weight'], sd['weight'])
self.assertEqual(sd_loaded2['weight'][0][0].item(), 0)
self.assertEqual(sd_loaded2['weight'], sd_loaded['weight'])
self.assertTrue(torch.serialization.get_default_mmap_options() == MAP_PRIVATE)
@parametrize('dtype', (torch.float8_e5m2, torch.float8_e4m3fn, torch.complex32))
@parametrize('weights_only', (True, False))
def test_serialization_dtype(self, dtype, weights_only):

View File

@ -54,6 +54,8 @@ __all__ = [
"LoadEndianness",
"get_default_load_endianness",
"set_default_load_endianness",
"get_default_mmap_options",
"set_default_mmap_options",
"clear_safe_globals",
"get_safe_globals",
"add_safe_globals",
@ -163,9 +165,9 @@ def get_default_mmap_options() -> int:
return _default_mmap_options
def set_default_mmap_options(flags: int):
class set_default_mmap_options:
"""
Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
Please open an issue if you need any other option to be added here.
@ -176,17 +178,27 @@ def set_default_mmap_options(flags: int):
Args:
flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
"""
global _default_mmap_options
if IS_WINDOWS:
raise RuntimeError(
"Changing the default mmap options is currently not supported for Windows"
)
if flags != MAP_PRIVATE and flags != MAP_SHARED:
raise ValueError(
"Invalid argument in function set_default_mmap_options, "
f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
)
_default_mmap_options = flags
def __init__(self, flags: int) -> None:
if IS_WINDOWS:
raise RuntimeError(
"Changing the default mmap options is currently not supported for Windows"
)
if flags != MAP_PRIVATE and flags != MAP_SHARED:
raise ValueError(
"Invalid argument in function set_default_mmap_options, "
f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
)
global _default_mmap_options
self.prev = _default_mmap_options
_default_mmap_options = flags
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _default_mmap_options
_default_mmap_options = self.prev
def clear_safe_globals() -> None: