Fix all RuntimeErrors during weights_only load from being erroneously reported with the weights_only message (#132349)

Caught in above PR #127627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132349
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-08-15 19:48:35 +00:00
committed by PyTorch MergeBot
parent 0d2be06d94
commit c8ad5e37e8
3 changed files with 21 additions and 18 deletions

View File

@ -377,11 +377,12 @@ class SerializationMixin:
with tempfile.NamedTemporaryFile() as f:
torch.save({"spoofed": TensorSerializationSpoofer(x)}, f)
f.seek(0)
with self.assertRaisesRegex(
RuntimeError,
"size is inconsistent with indices"):
y = torch.load(f)
for weights_only in (False, True):
f.seek(0)
with self.assertRaisesRegex(
RuntimeError,
"size is inconsistent with indices"):
y = torch.load(f, weights_only=weights_only)
def _test_serialization_sparse_compressed_invalid(self,
conversion,

View File

@ -235,7 +235,7 @@ class Unpickler:
module = IMPORT_MAPPING[module]
full_path = f"{module}.{name}"
if module in _blocklisted_modules:
raise RuntimeError(
raise UnpicklingError(
f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
)
if full_path in _get_allowed_globals():
@ -243,7 +243,7 @@ class Unpickler:
elif full_path in _get_user_allowed_globals():
self.append(_get_user_allowed_globals()[full_path])
else:
raise RuntimeError(
raise UnpicklingError(
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
"this global if you trust this class/function."
@ -256,7 +256,9 @@ class Unpickler:
elif cls in _get_user_allowed_globals().values():
self.append(cls.__new__(cls, *args))
else:
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
raise UnpicklingError(
f"Trying to instantiate unsupported class {cls}"
)
elif key[0] == REDUCE[0]:
args = self.stack.pop()
func = self.stack[-1]
@ -264,7 +266,7 @@ class Unpickler:
func not in _get_allowed_globals().values()
and func not in _get_user_allowed_globals().values()
):
raise RuntimeError(
raise UnpicklingError(
f"Trying to call reduce for unrecognized function {func}"
)
self.stack[-1] = func(*args)
@ -284,7 +286,7 @@ class Unpickler:
else:
inst.__dict__.update(state)
else:
raise RuntimeError(
raise UnpicklingError(
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
)
# Stack manipulation
@ -292,7 +294,7 @@ class Unpickler:
item = self.stack.pop()
list_obj = self.stack[-1]
if type(list_obj) is not list:
raise RuntimeError(
raise UnpicklingError(
f"Can only append to lists, but got {type(list_obj)}"
)
list_obj.append(item)
@ -300,7 +302,7 @@ class Unpickler:
items = self.pop_mark()
list_obj = self.stack[-1]
if type(list_obj) is not list:
raise RuntimeError(
raise UnpicklingError(
f"Can only extend lists, but got {type(list_obj)}"
)
list_obj.extend(items)
@ -350,7 +352,7 @@ class Unpickler:
elif key[0] == BINUNICODE[0]:
strlen = unpack("<I", read(4))[0]
if strlen > maxsize:
raise RuntimeError("String is too long")
raise UnpicklingError("String is too long")
strval = str(read(strlen), "utf-8", "surrogatepass")
self.append(strval)
elif key[0] == SHORT_BINSTRING[0]:
@ -363,7 +365,7 @@ class Unpickler:
pid = self.stack.pop()
# Only allow persistent load of storage
if type(pid) is not tuple and not type(pid) is not int:
raise RuntimeError(
raise UnpicklingError(
f"persistent_load id must be tuple or int, but got {type(pid)}"
)
if (
@ -371,7 +373,7 @@ class Unpickler:
and len(pid) > 0
and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
):
raise RuntimeError(
raise UnpicklingError(
f"Only persistent_load of storage is allowed, but got {pid[0]}"
)
self.append(self.persistent_load(pid))
@ -401,7 +403,7 @@ class Unpickler:
rc = self.stack.pop()
return rc
else:
raise RuntimeError(f"Unsupported operand {key[0]}")
raise UnpicklingError(f"Unsupported operand {key[0]}")
# Return a list of items pushed in the stack after last MARK instruction.
def pop_mark(self):

View File

@ -1258,7 +1258,7 @@ def load(
overall_storage=overall_storage,
**pickle_load_args,
)
except RuntimeError as e:
except pickle.UnpicklingError as e:
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
return _load(
opened_zipfile,
@ -1282,7 +1282,7 @@ def load(
_weights_only_unpickler,
**pickle_load_args,
)
except RuntimeError as e:
except pickle.UnpicklingError as e:
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
return _legacy_load(
opened_file, map_location, pickle_module, **pickle_load_args