mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-12 06:44:55 +08:00
Fix weights_only for BUILD instructions for user allowlisted objects with __slots__ (#138936)
Previously, handling of the `BUILD` instruction did not account for `__slots__`. **This only applies to objects allowlisted via `add_safe_globals`/`safe_globals` that use slots.**

### Background

When does pickle serialize a `BUILD` instruction? When `state` is not `None` and `state_setter` is `None` (see CPython `Lib/pickle.py` at commit c5b99f5c2c, line 765). In this case, the docs tell us that either `__setstate__` or a `__dict__` update will be performed ([link](https://github.com/python/cpython/blob/3.13/Lib/pickletools.py#L1984)).

`__reduce__`/`__reduce_ex__` are expected to return tuples of length 2 to 6, where `state` is the third item. When the user doesn't patch `__reduce__` but patches `__setstate__`/`__getstate__`, `state` will be whatever `__getstate__` returns.

Note the return type of [`__getstate__`](https://docs.python.org/3/library/pickle.html#object.__getstate__):

- For a class that has no instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and no [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is None.
- For a class that has an instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and no [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is `self.__dict__`.
- For a class that has an instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__) and [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__), the default state is a tuple consisting of two dictionaries: `self.__dict__`, and a dictionary mapping slot names to slot values. Only slots that have a value are included in the latter.
- For a class that has [`__slots__`](https://docs.python.org/3/reference/datamodel.html#object.__slots__) and no instance [`__dict__`](https://docs.python.org/3/reference/datamodel.html#object.__dict__), the default state is a tuple whose first item is None and whose second item is a dictionary mapping slot names to slot values, as described in the previous bullet.

See the handling in the pickle code: CPython `Lib/pickle.py` at commit c5b99f5c2c, lines 1846-1867.

Before this PR, we didn't account for the fact that when `__setstate__` is not defined, `state` might be a tuple, so this would fail:

```python
from dataclasses import dataclass

# Define the dataclass
@dataclass
class MyDataClass:
    __slots__ = ["x", "y"]
    x: int
    y: str

# Create an instance of the dataclass
my_data = MyDataClass(x=2, y=3)

# Save the dataclass to a file
torch.save(my_data, "my_data.pt")

with torch.serialization.safe_globals([MyDataClass]):
    loaded_my_data = torch.load("my_data.pt", weights_only=True)
# AttributeError: 'MyDataClass' object has no attribute '__dict__'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138936
Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
c2ffd41a86
commit
2a309c0997
@ -16,6 +16,7 @@ import warnings
|
||||
import zipfile
|
||||
from collections import namedtuple, OrderedDict
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
|
||||
@ -844,6 +845,17 @@ class ClassThatUsesBuildInstruction:
|
||||
# Third item, state here will cause pickle to push a BUILD instruction
|
||||
return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
|
||||
|
||||
@dataclass
class ClassThatUsesBuildInstructionAllSlots:
    """Test fixture: a dataclass whose fields are all stored in ``__slots__``.

    Used by ``test_weights_only_safe_globals_build_with_slots`` for the
    ``'all'`` case. It defines no ``__setstate__``, and every field lives in a
    slot, so instances have no ``__dict__``; per the pickle documentation the
    default pickled state for such a class is ``(None, {slot_name: value})``.
    """
    # Slots are declared by hand. NOTE(review): presumably deliberate, so that
    # no pickling hooks are generated for the class -- confirm before switching
    # to ``@dataclass(slots=True)``.
    __slots__ = ["x", "y"]
    x: int
    y: int
|
||||
|
||||
@dataclass
class ClassThatUsesBuildInstructionSomeSlots(ClassThatUsesBuildInstructionAllSlots):
    """Test fixture: a dataclass that mixes slot storage with ``__dict__``.

    Used by ``test_weights_only_safe_globals_build_with_slots`` for the
    ``'some'`` case. ``x`` and ``y`` reuse the slots inherited from
    ``ClassThatUsesBuildInstructionAllSlots``; this subclass declares no
    ``__slots__`` of its own, so instances also get a ``__dict__`` and ``c``
    is stored there. Per the pickle documentation the default pickled state
    for such a class is ``(self.__dict__, {slot_name: value})``.
    """
    # x and y are re-annotated so they stay the first two dataclass fields.
    x: int
    y: int
    # Not covered by any __slots__ declaration, so it lands in __dict__.
    c: str
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
|
||||
class TestBothSerialization(TestCase):
|
||||
@ -1142,6 +1154,25 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
torch.serialization.clear_safe_globals()
|
||||
ClassThatUsesBuildInstruction.__setstate__ = None
|
||||
|
||||
@parametrize("slots", ['some', 'all'])
def test_weights_only_safe_globals_build_with_slots(self, slots):
    """Check ``weights_only`` loading of allowlisted objects that use ``__slots__``.

    For each fixture class (all fields in slots / some fields in slots) the
    test saves an instance and then verifies that:

    1. loading with ``weights_only=True`` is rejected while the class is not
       allowlisted, and
    2. loading succeeds inside ``torch.serialization.safe_globals`` and
       yields an object equal to the one that was saved.

    Args:
        slots: ``'all'`` selects ``ClassThatUsesBuildInstructionAllSlots``;
            any other value (``'some'``) selects
            ``ClassThatUsesBuildInstructionSomeSlots``.
    """
    if slots == 'all':
        obj_cls, args = ClassThatUsesBuildInstructionAllSlots, (2, 3)
    else:
        obj_cls, args = ClassThatUsesBuildInstructionSomeSlots, (2, 3, 'foo')
    obj = obj_cls(*args)

    # pickle records the module the class was defined in. Derive it from the
    # class instead of hard-coding "__main__", so the expected message is also
    # right when this file is imported as a module rather than run as a script.
    qualified_name = f"{obj_cls.__module__}.{obj_cls.__name__}"

    with BytesIOContext() as f:
        torch.save(obj, f)

        # Without allowlisting, the class must be refused.
        f.seek(0)
        with self.assertRaisesRegex(pickle.UnpicklingError,
                                    f"GLOBAL {qualified_name} was not an allowed global by default"):
            torch.load(f, weights_only=True)

        # Once allowlisted, the BUILD instruction must restore slot values
        # (and any __dict__ entries) so the round trip is lossless.
        f.seek(0)
        with torch.serialization.safe_globals([obj_cls]):
            loaded_obj = torch.load(f, weights_only=True)
            self.assertEqual(loaded_obj, obj)
|
||||
|
||||
def test_weights_only_safe_globals_blocklist(self):
|
||||
module = 'nt' if IS_WINDOWS else 'posix'
|
||||
error_msg = f"unsupported GLOBAL {module}.execv whose module {module} is blocked"
|
||||
|
||||
@ -292,6 +292,13 @@ class Unpickler:
|
||||
elif type(inst) in _get_user_allowed_globals().values():
|
||||
if hasattr(inst, "__setstate__"):
|
||||
inst.__setstate__(state)
|
||||
elif hasattr(inst, "__slots__"):
|
||||
# if slots are defined, state will be a tuple (state, slotstate)
|
||||
state, slotstate = state
|
||||
for k, v in slotstate.items():
|
||||
setattr(inst, k, v)
|
||||
if state:
|
||||
inst.__dict__.update(state)
|
||||
else:
|
||||
inst.__dict__.update(state)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user