mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Allow tensor subclasses and add torch.serialization.add_safe_globals
that allows users to allowlist classes for weights_only
load (#124331)
#### Conditions for allowlisting tensor subclasses We allow tensor subclasses types that (1) Do not override `__setstate__`, `__getattr__`, `__setattr__`, `__get__`, `__set__` or `__getattribute__` of `torch.Tensor` (`torch.Tensor` does not have a definition of `__getattr__`, `__get__` or `__set__` so we check that these are `None`) (2) Use the generic `tp_alloc` (3) Are in a module that *has been imported by the user* to be pushed onto the stack as strings by `GLOBAL` instructions, while storing the type in a dict The strings will be converted to the classes as appropriate when executing `REBUILD` with `_rebuild_from_type_v2` *Note that we use `inspect.getattr_static(sys.modules[module], name)` to get the class/function as this method claims to have no code execution. The rationale for the 3 conditions above is as follows: The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as such (note the call to `getattr`, `Tensor.__setstate__` and the call to `as_subclass` as well as the call to `_set_obj_state` which calls `setattr`)4e66aaa010/torch/_tensor.py (L57-L71)
`as_subclass` is implemented with a call to `THPVariable_NewWithVar` that will eventually call `tp_alloc` here4e66aaa010/torch/csrc/autograd/python_variable.cpp (L2053)
The `func` arg to `_rebuild_from_type_v2` for wrapper subclasses is `Tensor.rebuild_wrapper_subclass`, which will similarly call into `THPVariable_NewWithVar` and hit the above `tp_alloc` **Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling** ### How do we check something is a tensor subclass/constraints around imports In order to check whether `bla` is a tensor subclass in the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We *do not* arbitrarily import modules but will perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys[modules], name), torch.Tensor)` This PR also allowlisted `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`) ### API for allow listing This PR also added `torch.serialization.{add/get/clear}_safe_globals` that enables user to allowlist globals they have deemed safe and manipulate this list (for example they could allowlist a tensor subclass with a custom `__setstate__` if they have checked that this is safe). Next steps: - Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor` etc.) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124331 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
31ea8290e7
commit
66dc8fb7ff
@ -394,3 +394,6 @@ The following utility functions are related to serialization:
|
||||
.. autofunction:: set_default_load_endianness
|
||||
.. autofunction:: get_default_mmap_options
|
||||
.. autofunction:: set_default_mmap_options
|
||||
.. autofunction:: add_safe_globals
|
||||
.. autofunction:: clear_safe_globals
|
||||
.. autofunction:: get_safe_globals
|
||||
|
@ -15,8 +15,10 @@ import pickle
|
||||
import shutil
|
||||
import pathlib
|
||||
import platform
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from types import ModuleType
|
||||
|
||||
from torch._utils_internal import get_file_path_2
|
||||
from torch._utils import _rebuild_tensor
|
||||
@ -27,9 +29,10 @@ from torch.serialization import check_module_version_greater_or_equal, get_defau
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName,
|
||||
TestCase, IS_FBCODE, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName,
|
||||
parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest)
|
||||
parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest, skipIfTorchDynamo)
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||
from torch.testing._internal.two_tensor import TwoTensor # noqa: F401
|
||||
|
||||
if not IS_WINDOWS:
|
||||
from mmap import MAP_SHARED, MAP_PRIVATE
|
||||
@ -1038,7 +1041,7 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
self.assertIsNone(torch.load(f, weights_only=False))
|
||||
f.seek(0)
|
||||
# Safe load should assert
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported class"):
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
@parametrize('weights_only', (False, True))
|
||||
@ -4108,6 +4111,23 @@ class TestGetStateSubclass(torch.Tensor):
|
||||
class TestEmptySubclass(torch.Tensor):
|
||||
...
|
||||
|
||||
# ONLY use SubclassSpoof subclasses for the subclass spoof tests since we modify them
|
||||
# Cannot define locally in test or pickle will fail.
|
||||
class TestEmptySubclassSpoof(TestEmptySubclass):
|
||||
...
|
||||
|
||||
class TestWrapperSubclassSpoof(TestWrapperSubclass):
|
||||
...
|
||||
|
||||
class RebuildFromTypeV2Spoof(torch.Tensor):
|
||||
def __new__(cls, elem, naughty, **kwargs):
|
||||
if naughty:
|
||||
raise RuntimeError("naughty")
|
||||
return super().__new__(cls, elem)
|
||||
|
||||
def __reduce_ex__(self, protocol):
|
||||
return (torch._tensor._rebuild_from_type_v2, (RebuildFromTypeV2Spoof, torch.Tensor, (True,), {}))
|
||||
|
||||
|
||||
class TestSubclassSerialization(TestCase):
|
||||
def test_tensor_subclass_wrapper_serialization(self):
|
||||
@ -4187,6 +4207,203 @@ class TestSubclassSerialization(TestCase):
|
||||
f.seek(0)
|
||||
tensor2 = torch.load(f)
|
||||
|
||||
def _create_bad_func(self, name):
|
||||
def bad_func(self, *args, **kwargs):
|
||||
raise RuntimeError(f"running {name}")
|
||||
return bad_func
|
||||
|
||||
@parametrize("wrapper", (True, False))
|
||||
def test_tensor_subclass_method_spoofing(self, wrapper):
|
||||
'''
|
||||
This tests seeks to do the following:
|
||||
- determine which methods of a tensor subclass might be called during unpickling (weights_only=False)
|
||||
we consider these methods "risky" for weights_only
|
||||
- ensure that we ban overriding this group of methods on a tensor subclass by default (weights_only=True)
|
||||
- ensure that tensor subclass that doesn't override any of these can be unpickled (weights_only=True)
|
||||
|
||||
We achieve this by overriding all methods of a tensor subclass to raise a RuntimeError
|
||||
when called. We then try to unpickle a tensor subclass with weights_only=False and ensure that
|
||||
only the RuntimeErrors that we expect are thrown.
|
||||
|
||||
We then load with weights_only and ensure that weights_only will fail unless all the risky methods
|
||||
are not overriden by resetting the risky methods to the non-overriden version in a loop and calling load.
|
||||
The final weights_only load call when all the risky methods are no longer overriden.
|
||||
'''
|
||||
subclass = TestWrapperSubclassSpoof if wrapper else TestEmptySubclassSpoof
|
||||
t = subclass(torch.randn(2, 3))
|
||||
# To trigger setattr for the non-wrapper case
|
||||
if not wrapper:
|
||||
t.foo = 'bar'
|
||||
inp = {'weight': t}
|
||||
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(inp, f)
|
||||
loaded = torch.load(f, weights_only=True)
|
||||
self.assertEqual(loaded['weight'], inp['weight'])
|
||||
|
||||
restore_methods = dict()
|
||||
methods = [func for func in dir(subclass) if callable(getattr(subclass, func))]
|
||||
for method in methods:
|
||||
if method != "__class__":
|
||||
restore_methods[method] = getattr(subclass, method)
|
||||
setattr(subclass, method, self._create_bad_func(method))
|
||||
# These additional methods might be called during getattr or setattr
|
||||
# but are not in methods above (not defined on tensor base class)
|
||||
subclass.__get__ = self._create_bad_func("__get__")
|
||||
subclass.__set__ = self._create_bad_func("__set__")
|
||||
subclass.__getattr__ = self._create_bad_func("__getattr__")
|
||||
restore_methods["__get__"] = None
|
||||
restore_methods["__getattr__"] = None
|
||||
restore_methods["__set__"] = None
|
||||
|
||||
try:
|
||||
# Check that weights_only=False load raises the RuntimeErrors we expect
|
||||
with self.assertRaisesRegex(RuntimeError, "running __getattribute__"):
|
||||
torch.load(f, weights_only=False)
|
||||
subclass.__getattribute__ = restore_methods['__getattribute__']
|
||||
with self.assertRaisesRegex(RuntimeError, "running __setstate__"):
|
||||
torch.load(f, weights_only=False)
|
||||
subclass.__setstate__ = restore_methods['__setstate__']
|
||||
with self.assertRaisesRegex(RuntimeError, "running __setattr__"):
|
||||
torch.load(f, weights_only=False)
|
||||
subclass.__setattr__ = restore_methods['__setattr__']
|
||||
# should finally work
|
||||
torch.load(f, weights_only=False)
|
||||
|
||||
# Check that weights_only=True catches that risky methods are overriden
|
||||
subclass.__setstate__ = self._create_bad_func("__setstate__")
|
||||
subclass.__getattribute__ = self._create_bad_func("__getattribute__")
|
||||
subclass.__setattr__ = self._create_bad_func("__setattr__")
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError,
|
||||
"methods: __getattribute__=True __getattr__=True __get__=True "
|
||||
"__setattr__=True __set__=True __setstate__=True"):
|
||||
torch.load(f, weights_only=True)
|
||||
risky_methods = ['__get__', '__set__', '__getattr__', '__setattr__', '__getattribute__', '__setstate__']
|
||||
for i, meth in enumerate(risky_methods):
|
||||
setattr(subclass, meth, restore_methods[meth])
|
||||
if i != len(risky_methods) - 1:
|
||||
# When the given methods are not all back to default, load should still throw
|
||||
# but reflect which methods are no longer overriden
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, f"{meth}=False"):
|
||||
torch.load(f, weights_only=True)
|
||||
else:
|
||||
# When the given methods are all back to default, weights_only load should finally work
|
||||
loaded = torch.load(f, weights_only=True)
|
||||
finally:
|
||||
for method, func in restore_methods.items():
|
||||
setattr(subclass, method, func)
|
||||
a = subclass(torch.randn(2, 3))
|
||||
|
||||
@skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined")
|
||||
def test_safe_globals_for_weights_only(self):
|
||||
'''
|
||||
Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs
|
||||
'''
|
||||
# Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment
|
||||
global TwoTensor
|
||||
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
|
||||
p = torch.nn.Parameter(t)
|
||||
sd = OrderedDict([('t', t), ('p', p)])
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(sd, f)
|
||||
# unimport TwoTensor
|
||||
try:
|
||||
del sys.modules['torch.testing._internal.two_tensor']
|
||||
|
||||
# Loading tensor subclass with weights_only=True should fail
|
||||
# if tensor subclass has not been imported
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError,
|
||||
"expect `torch.testing._internal.two_tensor` to be present in `sys.modules`"):
|
||||
f.seek(0)
|
||||
sd = torch.load(f, weights_only=True)
|
||||
|
||||
# Loading tensor subclass with weights_only=True should work
|
||||
# if target methods are not overriden and user has imported the subclass
|
||||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
f.seek(0)
|
||||
sd = torch.load(f, weights_only=True)
|
||||
self.assertEqual(sd['t'], t)
|
||||
self.assertEqual(sd['p'], p)
|
||||
|
||||
# Loading tensor subclass with weights_only=True should fail
|
||||
# if __setstate__ is overriden
|
||||
f.seek(0)
|
||||
restore_setstate = TwoTensor.__setstate__
|
||||
try:
|
||||
TwoTensor.__setstate__ = lambda self, state: self.__dict__.update(state)
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"):
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
# Loading tensor subclass with overriden __setstate__ with weights_only=True should work
|
||||
# if the class is marked safe
|
||||
f.seek(0)
|
||||
torch.serialization.add_safe_globals([TwoTensor])
|
||||
self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor])
|
||||
sd = torch.load(f, weights_only=True)
|
||||
self.assertEqual(sd['t'], t)
|
||||
self.assertEqual(sd['p'], p)
|
||||
|
||||
# Should fail again when safe globals are cleared
|
||||
torch.serialization.clear_safe_globals()
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"):
|
||||
torch.load(f, weights_only=True)
|
||||
finally:
|
||||
TwoTensor.__setstate__ = restore_setstate
|
||||
finally:
|
||||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
|
||||
|
||||
def test_tensor_subclass_parent_module_method_spoofing(self):
|
||||
'''
|
||||
Tests that weights_only load does not call any methods of the parent module
|
||||
that contains the tensor subclass.
|
||||
|
||||
We achieve this by overriding all methods of a module we add to sys.modules to raise a RuntimeError
|
||||
when called. We then try to unpickle a tensor subclass with weights_only=True and ensure that
|
||||
no RuntimeErrors are thrown.
|
||||
'''
|
||||
# Simulates user doing `import spoof_mod` where `spoof_mod` contains `TestEmptySubclass`
|
||||
class SpoofModule(ModuleType):
|
||||
pass
|
||||
|
||||
spoof_mod = SpoofModule('bla')
|
||||
spoof_mod.TestEmptySubclass = TestEmptySubclass
|
||||
inp = {'weight': TestEmptySubclass(torch.randn(2, 3))}
|
||||
TestEmptySubclass.__module__ = 'spoof_mod'
|
||||
sys.modules['spoof_mod'] = spoof_mod
|
||||
|
||||
try:
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(inp, f)
|
||||
torch.load(f, weights_only=True)
|
||||
restore_methods = dict()
|
||||
methods = [func for func in dir(SpoofModule) if callable(getattr(SpoofModule, func))]
|
||||
for method in methods:
|
||||
if method != "__class__":
|
||||
restore_methods[method] = getattr(SpoofModule, method)
|
||||
setattr(SpoofModule, method, self._create_bad_func(method))
|
||||
SpoofModule.__get__ = self._create_bad_func("__get__")
|
||||
SpoofModule.__getattr__ = self._create_bad_func("__getattr__")
|
||||
loaded = torch.load(f, weights_only=True)
|
||||
self.assertEqual(loaded['weight'], inp['weight'])
|
||||
finally:
|
||||
TestEmptySubclass.__module__ = __name__
|
||||
del sys.modules['spoof_mod']
|
||||
|
||||
def test_rebuild_from_type_v2_spoof(self):
|
||||
t = RebuildFromTypeV2Spoof(torch.randn(2, 3), False)
|
||||
inp = {'weight': t}
|
||||
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(inp, f)
|
||||
# subclass will be pushed onto unpickler's stack as a string
|
||||
# and only gets converted to the type if it is argument 1 to _rebuild_from_type_v2
|
||||
with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
|
||||
loaded = torch.load(f, weights_only=True)
|
||||
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestBothSerialization, globals())
|
||||
instantiate_parametrized_tests(TestSubclassSerialization)
|
||||
|
@ -1196,6 +1196,7 @@ def _has_storage(x: Tensor) -> _bool: ...
|
||||
def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ...
|
||||
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
|
||||
def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, str], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ...
|
||||
def _check_tp_alloc_is_default(cls: Type) -> _bool: ...
|
||||
|
||||
# NB: There is no Capsule type in typing, see
|
||||
# https://code.activestate.com/lists/python-dev/139675/
|
||||
|
@ -9,6 +9,10 @@
|
||||
# - `torch.nn.Parameter`
|
||||
# - `collections.Counter`
|
||||
# - `collections.OrderedDict`
|
||||
# Additionally, users can use an allowlist for adding classes they have deemed as safe using
|
||||
# `_add_safe_globals()` (`torch.serialization.add_safe_globals`)
|
||||
# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`)
|
||||
# `_get_safe_globals()` (`torch.serialization.get_safe_globals`)
|
||||
|
||||
# Based of https://github.com/python/cpython/blob/main/Lib/pickle.py
|
||||
# Expected to be useful for loading PyTorch model weights
|
||||
@ -19,6 +23,7 @@
|
||||
|
||||
import functools as _functools
|
||||
from collections import Counter, OrderedDict
|
||||
from inspect import getattr_static
|
||||
from pickle import (
|
||||
APPEND,
|
||||
APPENDS,
|
||||
@ -59,11 +64,57 @@ from pickle import (
|
||||
UnpicklingError,
|
||||
)
|
||||
from struct import unpack
|
||||
from sys import maxsize
|
||||
from typing import Any, Dict, List
|
||||
from sys import maxsize, modules
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
import torch
|
||||
|
||||
_marked_safe_globals_list: List[Any] = []
|
||||
|
||||
|
||||
def _add_safe_globals(safe_globals: List[Any]):
|
||||
global _marked_safe_globals_list
|
||||
_marked_safe_globals_list += safe_globals
|
||||
|
||||
|
||||
def _get_safe_globals() -> List[Any]:
|
||||
global _marked_safe_globals_list
|
||||
return _marked_safe_globals_list
|
||||
|
||||
|
||||
def _clear_safe_globals():
|
||||
global _marked_safe_globals_list
|
||||
_marked_safe_globals_list = []
|
||||
|
||||
|
||||
# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals
|
||||
# For example if user had a script like
|
||||
# torch.load(file_a)
|
||||
# torch.serialization._add_safe_globals([torch.foo])
|
||||
# torch.load(file_b)
|
||||
# the dynamic additions to safe_globals would not be picked up by
|
||||
# _get_allowed_globals due to the lru_cache
|
||||
def _get_user_allowed_globals():
|
||||
rc: Dict[str, Any] = {}
|
||||
for f in _marked_safe_globals_list:
|
||||
rc[f"{f.__module__}.{f.__name__}"] = f
|
||||
return rc
|
||||
|
||||
|
||||
def _tensor_rebuild_functions():
|
||||
return {
|
||||
torch._utils._rebuild_parameter,
|
||||
torch._utils._rebuild_parameter_with_state,
|
||||
torch._utils._rebuild_qtensor,
|
||||
torch._utils._rebuild_tensor,
|
||||
torch._utils._rebuild_tensor_v2,
|
||||
torch._utils._rebuild_tensor_v3,
|
||||
torch._utils._rebuild_sparse_tensor,
|
||||
torch._utils._rebuild_meta_tensor_no_storage,
|
||||
torch._utils._rebuild_nested_tensor,
|
||||
torch._utils._rebuild_wrapper_subclass,
|
||||
}
|
||||
|
||||
|
||||
# Unpickling machinery
|
||||
@_functools.lru_cache(maxsize=1)
|
||||
@ -75,6 +126,7 @@ def _get_allowed_globals():
|
||||
"torch.serialization._get_layout": torch.serialization._get_layout,
|
||||
"torch.Size": torch.Size,
|
||||
"torch.Tensor": torch.Tensor,
|
||||
"torch.device": torch.device,
|
||||
}
|
||||
# dtype
|
||||
for t in torch.storage._dtype_to_storage_type_map().keys():
|
||||
@ -103,17 +155,7 @@ def _get_allowed_globals():
|
||||
]:
|
||||
rc[str(qt)] = qt
|
||||
# Rebuild functions
|
||||
for f in [
|
||||
torch._utils._rebuild_parameter,
|
||||
torch._utils._rebuild_parameter_with_state,
|
||||
torch._utils._rebuild_qtensor,
|
||||
torch._utils._rebuild_tensor,
|
||||
torch._utils._rebuild_tensor_v2,
|
||||
torch._utils._rebuild_tensor_v3,
|
||||
torch._utils._rebuild_sparse_tensor,
|
||||
torch._utils._rebuild_meta_tensor_no_storage,
|
||||
torch._utils._rebuild_nested_tensor,
|
||||
]:
|
||||
for f in _tensor_rebuild_functions():
|
||||
rc[f"torch._utils.{f.__name__}"] = f
|
||||
|
||||
# Handles Tensor Subclasses, Tensor's with attributes.
|
||||
@ -128,6 +170,11 @@ class Unpickler:
|
||||
self.readline = file.readline
|
||||
self.read = file.read
|
||||
self.memo: Dict[int, Any] = {}
|
||||
# tensor subclass types found from GLOBAL instructions that have passed the criteria
|
||||
# to be allowed as the second argument to `torch._tensor._rebuild_from_type_v2`
|
||||
# This enables rebuilding of tensor subclasses defined outside the `torch` package.
|
||||
# See [Note: Criteria for allowing out-of-core tensor subclasses] for details on the criteria.
|
||||
self.tensor_subclasses_found: Dict[str, Type] = {}
|
||||
|
||||
def load(self):
|
||||
"""Read a pickled object representation from the open file.
|
||||
@ -151,8 +198,124 @@ class Unpickler:
|
||||
full_path = f"{module}.{name}"
|
||||
if full_path in _get_allowed_globals():
|
||||
self.append(_get_allowed_globals()[full_path])
|
||||
elif full_path in _get_user_allowed_globals():
|
||||
self.append(_get_user_allowed_globals()[full_path])
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported class {full_path}")
|
||||
# The logic in this branch handles user-defined tensor subclasses.
|
||||
# We can automatically allow and raise and error for anything that is not provably safe.
|
||||
# [Note: Criteria for allowing out-of-core tensor subclasses]
|
||||
# GLOBAL '<module>.<tensor subclass>' instructions will get the class and
|
||||
# push the string (not the actual type) while adding the type to the dictionary keyed
|
||||
# by the string onto the unpickler's stack if they satisfy the following conditions:
|
||||
# (1) The <module> that defines them is in `sys.modules`
|
||||
# (we will use getattr_static to access it to ensure no code execution)
|
||||
# (2) They inherit from `torch.Tensor`
|
||||
# (2) The class is not overriding any of the `torch.Tensor` methods listed here:
|
||||
# `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, `__set__`,
|
||||
# and `tp_alloc`
|
||||
# The methods that we ban overriding were selected in a test-driven manner
|
||||
# by overriding every callable method on a tensor subclass and determinining
|
||||
# which might get called during unpickling.
|
||||
# When executing REDUCE, the string will be appropriately converted back to the type only
|
||||
# for `torch._tensor._rebuild_from_type_v2` as other use of the class could use methods
|
||||
# we didn't audit.
|
||||
if module == "__builtin__":
|
||||
raise RuntimeError(
|
||||
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
|
||||
"Please use `torch.serialization.add_safe_globals` to allowlist this global "
|
||||
"if you trust this class/function."
|
||||
)
|
||||
elif module not in modules:
|
||||
# TODO: add a link here to a doc that explains to users what we mean by trust
|
||||
raise RuntimeError(
|
||||
f"Found GLOBAL `{full_path}` instruction in the pickle file but `{full_path}` was "
|
||||
f"not in the pre-defined list of allowed globals that are considered safe by the "
|
||||
"weights_only unpickler for rebuilding state_dicts. This is the expected behavior if "
|
||||
f"`{full_path}` is a class or function that is not in the list of allowed globals "
|
||||
f"If `{full_path}` is NOT a tensor subclass, you might consider"
|
||||
"`torch.serialization.add_safe_globals` if it is appropriate. However, if it is a "
|
||||
"user-defined tensor subclass not defined in the `torch` package, this error might arise "
|
||||
f"as we expect `{module}` to be present in `sys.modules` (i.e. it "
|
||||
"must be imported in the current environment), but this was not the case. "
|
||||
f"If you intend to unpickle a tensor subclass `{full_path}` please import `{name}` from "
|
||||
f"`{module}`. Note that having this imported will *only* allow the type `{full_path}` to "
|
||||
"be passed as the second argument to `torch._tensor._rebuild_from_type_v2`, which should "
|
||||
"enable the tensor subclass to be unpickled without any arbitrary code execution as long "
|
||||
# If the user imports and these are overridden the next error will prompt them to use
|
||||
# torch.serialization.add_safe_globals.
|
||||
"a sa pre-defined list of methods called when unpickling are not overridden. In "
|
||||
"particular, the methods are `__getattr__`, `__get__`, `__getattribute__`, `__setstate__`, "
|
||||
"`__set__`, as well as the implementation of `tp_alloc`."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
class_type = getattr_static(modules[module], name)
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
"For safety during weights_only loading, we use inspect.getattr_state to "
|
||||
f"get {name} from {module}, if {module} implements the descriptor protocol, "
|
||||
"__getattr__ or __getattribute__ these will not be called."
|
||||
) from e
|
||||
# None of the objects here contain any data from the pickle so this is safe
|
||||
if isinstance(class_type, type) and issubclass(
|
||||
class_type, torch.Tensor
|
||||
):
|
||||
# getattr is called by the getattr call in `_rebuild_from_type_v2`
|
||||
custom_get_attribute = (
|
||||
class_type.__getattribute__
|
||||
is not torch.Tensor.__getattribute__
|
||||
)
|
||||
custom_get = (
|
||||
getattr_static(class_type, "__get__", None) is not None
|
||||
)
|
||||
custom_get_attr = (
|
||||
getattr_static(class_type, "__getattr__", None)
|
||||
is not None
|
||||
)
|
||||
# Tensor.__setstate__ might be called in `_rebuild_from_type_v2`
|
||||
custom_set_state = (
|
||||
class_type.__setstate__ is not torch.Tensor.__setstate__
|
||||
)
|
||||
# setattr is called in `torch._utils._set_obj_state`
|
||||
custom_set_attr = (
|
||||
class_type.__setattr__ is not object.__setattr__
|
||||
)
|
||||
custom_set = (
|
||||
getattr_static(class_type, "__set__", None) is not None
|
||||
)
|
||||
# tp_alloc is called by `Tensor._rebuild_wrapper_subclass` and `Tensor.as_subclass`
|
||||
has_custom_tp_alloc = (
|
||||
not torch._C._check_tp_alloc_is_default(class_type)
|
||||
)
|
||||
custom_methods = {
|
||||
"__getattribute__": custom_get_attribute,
|
||||
"__getattr__": custom_get_attr,
|
||||
"__get__": custom_get,
|
||||
"__setattr__": custom_set_attr,
|
||||
"__set__": custom_set,
|
||||
"__setstate__": custom_set_state,
|
||||
"tp_alloc": has_custom_tp_alloc,
|
||||
}
|
||||
if any(custom_methods.values()):
|
||||
error = ""
|
||||
for k, v in custom_methods.items():
|
||||
error += f" {k}={v}"
|
||||
raise RuntimeError(
|
||||
f"Trying to unpickle tensor subclass `{full_path}` that has defined a custom "
|
||||
f"version for one of these methods:{error}. Please check whether you trust these "
|
||||
"methods and allowlist the subclass with `torch.serialization.add_safe_globals` if so."
|
||||
)
|
||||
# push the string full_path onto the stack (in REBUILD, there is special logic to
|
||||
# access this from tensor_subclasses_found for rebuild_from_type_v2)
|
||||
self.tensor_subclasses_found[full_path] = class_type
|
||||
self.append(full_path)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
|
||||
"Please use `torch.serialization.add_safe_globals` to allowlist this global "
|
||||
"if you trust this class/function."
|
||||
)
|
||||
|
||||
elif key[0] == NEWOBJ[0]:
|
||||
args = self.stack.pop()
|
||||
cls = self.stack.pop()
|
||||
@ -162,10 +325,33 @@ class Unpickler:
|
||||
elif key[0] == REDUCE[0]:
|
||||
args = self.stack.pop()
|
||||
func = self.stack[-1]
|
||||
if func not in _get_allowed_globals().values():
|
||||
if (
|
||||
func not in _get_allowed_globals().values()
|
||||
and func not in _get_user_allowed_globals().values()
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Trying to call reduce for unrecognized function {func}"
|
||||
)
|
||||
# Special handling for tensor subclass type found in GLOBAL that is pushed
|
||||
# onto stack as str to prevent it from being used anywhere except the
|
||||
# second arg of _rebuild_from_type_v2 and within argument tuple for _rebuild_wrapper_subclass
|
||||
# _rebuild_from_type_v2 is called with args (func, type, func_args, state)
|
||||
# where both type and, when func is rebuild_wrapper_subclass, func_args[0] could be the subclass type
|
||||
# Since we pushed these subclass types onto the stack as strings, convert them to the actual
|
||||
# type here.
|
||||
if func is torch._tensor._rebuild_from_type_v2 and type(args[1]) is str:
|
||||
args_after = args[2:]
|
||||
if (
|
||||
args[0] is torch._utils._rebuild_wrapper_subclass
|
||||
and type(args[2][0]) is str
|
||||
):
|
||||
new_arg_tuple = (
|
||||
self.tensor_subclasses_found[args[2][0]],
|
||||
) + args[2][1:]
|
||||
args_after = (new_arg_tuple,) + args[3:]
|
||||
args = (
|
||||
args[:1] + (self.tensor_subclasses_found[args[1]],) + args_after
|
||||
)
|
||||
self.stack[-1] = func(*args)
|
||||
elif key[0] == BUILD[0]:
|
||||
state = self.stack.pop()
|
||||
|
@ -422,6 +422,19 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPModule_check_tp_alloc_is_default(
|
||||
PyObject* _unused,
|
||||
PyObject* cls) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK_TYPE(
|
||||
PyType_Check(cls),
|
||||
"cls must be a type (got ",
|
||||
Py_TYPE(cls)->tp_name,
|
||||
")");
|
||||
return PyBool_FromLong(Py_TYPE(cls)->tp_alloc == PyType_GenericAlloc);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
|
||||
// adds a __doc__ string to a function, similar to numpy's arr_add_docstring
|
||||
static std::vector<std::string> all_docs;
|
||||
@ -1268,6 +1281,10 @@ static PyMethodDef TorchMethods[] = { // NOLINT
|
||||
{"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
|
||||
{"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr},
|
||||
{"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr},
|
||||
{"_check_tp_alloc_is_default",
|
||||
THPModule_check_tp_alloc_is_default,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_init_names", THPModule_initNames, METH_O, nullptr},
|
||||
{"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr},
|
||||
{"_set_default_tensor_type",
|
||||
|
@ -59,6 +59,9 @@ __all__ = [
|
||||
'LoadEndianness',
|
||||
'get_default_load_endianness',
|
||||
'set_default_load_endianness',
|
||||
'clear_safe_globals',
|
||||
'get_safe_globals',
|
||||
'add_safe_globals',
|
||||
]
|
||||
|
||||
|
||||
@ -148,6 +151,27 @@ def set_default_mmap_options(flags: int):
|
||||
f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
|
||||
_default_mmap_options = flags
|
||||
|
||||
def clear_safe_globals() -> None:
|
||||
'''
|
||||
Clears the list of globals that are safe for ``weights_only`` load.
|
||||
'''
|
||||
_weights_only_unpickler._clear_safe_globals()
|
||||
|
||||
def get_safe_globals() -> List[Any]:
|
||||
'''
|
||||
Returns the list of user-added globals that are safe for ``weights_only`` load.
|
||||
'''
|
||||
return _weights_only_unpickler._get_safe_globals()
|
||||
|
||||
def add_safe_globals(safe_globals: List[Any]) -> None:
|
||||
'''
|
||||
Marks the given globals as safe for ``weights_only`` load.
|
||||
|
||||
Args:
|
||||
safe_globals (List[Any]): list of globals to mark as safe
|
||||
'''
|
||||
_weights_only_unpickler._add_safe_globals(safe_globals)
|
||||
|
||||
def _is_zipfile(f) -> bool:
|
||||
# This is a stricter implementation than zipfile.is_zipfile().
|
||||
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
|
||||
@ -952,7 +976,9 @@ def load(
|
||||
UNSAFE_MESSAGE = (
|
||||
"Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
|
||||
" will likely succeed, but it can result in arbitrary code execution."
|
||||
"Do it only if you get the file from a trusted source. WeightsUnpickler error: "
|
||||
" Do it only if you get the file from a trusted source. Alternatively, to load"
|
||||
" with `weights_only` please check the recommended steps in the following error message."
|
||||
" WeightsUnpickler error: "
|
||||
)
|
||||
# Add ability to force safe only weight loads via environment variable
|
||||
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
|
||||
|
Reference in New Issue
Block a user