mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Blocklist certain modules for weights_only load (#131259)
Also bold certain text in the error message as suggested <img width="3000" alt="Screenshot 2024-07-19 at 5 56 48 PM" src="https://github.com/user-attachments/assets/378f20c5-c6b2-4e53-8eaf-0bd26c3a6b60"> With a GLOBAL like `os.execv` the error message is now as such ```python File "/data/users/mg1998/pytorch/torch/serialization.py", line 1256, in load raise pickle.UnpicklingError(_get_wo_message(str(e))) from None _pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. Trying to load unsupported GLOBAL posix.execv whose module posix is blocked. Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/131259 Approved by: https://github.com/malfet, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
93ef2e53f8
commit
d3556786b8
@ -1130,6 +1130,20 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
torch.serialization.clear_safe_globals()
|
||||
ClassThatUsesBuildInstruction.__setstate__ = None
|
||||
|
||||
def test_weights_only_safe_globals_blocklist(self):
|
||||
module = 'nt' if IS_WINDOWS else 'posix'
|
||||
error_msg = f"unsupported GLOBAL {module}.execv whose module {module} is blocked"
|
||||
with BytesIOContext() as f:
|
||||
torch.save(os.execv, f)
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, error_msg):
|
||||
torch.load(f, weights_only=True)
|
||||
f.seek(0)
|
||||
# safe_globals doesn't work even with allowlist
|
||||
with torch.serialization.safe_globals([os.execv]):
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, error_msg):
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
@parametrize("unsafe_global", [True, False])
|
||||
def test_weights_only_error(self, unsafe_global):
|
||||
sd = {'t': TwoTensor(torch.randn(2), torch.randn(2))}
|
||||
|
@ -72,6 +72,15 @@ import torch
|
||||
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
|
||||
|
||||
|
||||
# modules in this list are never allowed, even if the user attempts to allowlist
|
||||
# functions/classes from them
|
||||
_blocklisted_modules = [
|
||||
"sys",
|
||||
"os",
|
||||
"posix",
|
||||
"nt",
|
||||
]
|
||||
|
||||
_marked_safe_globals_list: List[Any] = []
|
||||
|
||||
|
||||
@ -221,6 +230,10 @@ class Unpickler:
|
||||
elif module in IMPORT_MAPPING:
|
||||
module = IMPORT_MAPPING[module]
|
||||
full_path = f"{module}.{name}"
|
||||
if module in _blocklisted_modules:
|
||||
raise RuntimeError(
|
||||
f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
|
||||
)
|
||||
if full_path in _get_allowed_globals():
|
||||
self.append(_get_allowed_globals()[full_path])
|
||||
elif full_path in _get_user_allowed_globals():
|
||||
|
@ -1145,21 +1145,26 @@ def load(
|
||||
)
|
||||
|
||||
def _get_wo_message(message: str) -> str:
|
||||
pattern = r"GLOBAL (\S+) was not an allowed global by default."
|
||||
has_unsafe_global = re.search(pattern, message) is not None
|
||||
unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default."
|
||||
has_unsafe_global = re.search(unsafe_global_pattern, message) is not None
|
||||
blocklist_pattern = r"whose module (\S+) is blocked"
|
||||
has_blocklist = re.search(blocklist_pattern, message) is not None
|
||||
if has_unsafe_global:
|
||||
updated_message = (
|
||||
"Weights only load failed. This file can still be loaded, to do so you have two options "
|
||||
"Weights only load failed. This file can still be loaded, to do so you have two options, "
|
||||
"\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. "
|
||||
f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
|
||||
"the recommended steps in the following error message.\n\tWeightsUnpickler error: "
|
||||
+ message
|
||||
)
|
||||
else:
|
||||
updated_message = (
|
||||
f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
|
||||
"so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
|
||||
"error: " + message
|
||||
)
|
||||
updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n"
|
||||
if not has_blocklist:
|
||||
updated_message += (
|
||||
"Please file an issue with the following so that we can make "
|
||||
"`weights_only=True` compatible with your use case: WeightsUnpickler error: "
|
||||
)
|
||||
updated_message += message
|
||||
return updated_message + DOCS_MESSAGE
|
||||
|
||||
if weights_only is None:
|
||||
|
Reference in New Issue
Block a user