Blocklist certain modules for weights_only load (#131259)

Also bold certain text in the error message as suggested
<img width="3000" alt="Screenshot 2024-07-19 at 5 56 48 PM" src="https://github.com/user-attachments/assets/378f20c5-c6b2-4e53-8eaf-0bd26c3a6b60">

With a GLOBAL like `os.execv` the error message is now as such

```python
File "/data/users/mg1998/pytorch/torch/serialization.py", line 1256, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Trying to load unsupported GLOBAL posix.execv whose module posix is blocked.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131259
Approved by: https://github.com/malfet, https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-07-22 09:06:57 -07:00
committed by PyTorch MergeBot
parent 93ef2e53f8
commit d3556786b8
3 changed files with 40 additions and 8 deletions

View File

@ -1130,6 +1130,20 @@ class TestSerialization(TestCase, SerializationMixin):
torch.serialization.clear_safe_globals()
ClassThatUsesBuildInstruction.__setstate__ = None
def test_weights_only_safe_globals_blocklist(self):
module = 'nt' if IS_WINDOWS else 'posix'
error_msg = f"unsupported GLOBAL {module}.execv whose module {module} is blocked"
with BytesIOContext() as f:
torch.save(os.execv, f)
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError, error_msg):
torch.load(f, weights_only=True)
f.seek(0)
# safe_globals doesn't work even with allowlist
with torch.serialization.safe_globals([os.execv]):
with self.assertRaisesRegex(pickle.UnpicklingError, error_msg):
torch.load(f, weights_only=True)
@parametrize("unsafe_global", [True, False])
def test_weights_only_error(self, unsafe_global):
sd = {'t': TwoTensor(torch.randn(2), torch.randn(2))}

View File

@ -72,6 +72,15 @@ import torch
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
# modules in this list are never allowed, even if the user attempts to allowlist
# functions/classes from them
_blocklisted_modules = [
"sys",
"os",
"posix",
"nt",
]
_marked_safe_globals_list: List[Any] = []
@ -221,6 +230,10 @@ class Unpickler:
elif module in IMPORT_MAPPING:
module = IMPORT_MAPPING[module]
full_path = f"{module}.{name}"
if module in _blocklisted_modules:
raise RuntimeError(
f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
)
if full_path in _get_allowed_globals():
self.append(_get_allowed_globals()[full_path])
elif full_path in _get_user_allowed_globals():

View File

@ -1145,21 +1145,26 @@ def load(
)
def _get_wo_message(message: str) -> str:
pattern = r"GLOBAL (\S+) was not an allowed global by default."
has_unsafe_global = re.search(pattern, message) is not None
unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default."
has_unsafe_global = re.search(unsafe_global_pattern, message) is not None
blocklist_pattern = r"whose module (\S+) is blocked"
has_blocklist = re.search(blocklist_pattern, message) is not None
if has_unsafe_global:
updated_message = (
"Weights only load failed. This file can still be loaded, to do so you have two options "
"Weights only load failed. This file can still be loaded, to do so you have two options, "
"\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. "
f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
"the recommended steps in the following error message.\n\tWeightsUnpickler error: "
+ message
)
else:
updated_message = (
f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
"so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
"error: " + message
)
updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n"
if not has_blocklist:
updated_message += (
"Please file an issue with the following so that we can make "
"`weights_only=True` compatible with your use case: WeightsUnpickler error: "
)
updated_message += message
return updated_message + DOCS_MESSAGE
if weights_only is None: