mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix all RuntimeErrors during weights_only load from being erroneously reported with the weights_only message (#132349)
Caught in above PR #127627 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132349 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
0d2be06d94
commit
c8ad5e37e8
@ -377,11 +377,12 @@ class SerializationMixin:
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save({"spoofed": TensorSerializationSpoofer(x)}, f)
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"size is inconsistent with indices"):
|
||||
y = torch.load(f)
|
||||
for weights_only in (False, True):
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"size is inconsistent with indices"):
|
||||
y = torch.load(f, weights_only=weights_only)
|
||||
|
||||
def _test_serialization_sparse_compressed_invalid(self,
|
||||
conversion,
|
||||
|
@ -235,7 +235,7 @@ class Unpickler:
|
||||
module = IMPORT_MAPPING[module]
|
||||
full_path = f"{module}.{name}"
|
||||
if module in _blocklisted_modules:
|
||||
raise RuntimeError(
|
||||
raise UnpicklingError(
|
||||
f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
|
||||
)
|
||||
if full_path in _get_allowed_globals():
|
||||
@ -243,7 +243,7 @@ class Unpickler:
|
||||
elif full_path in _get_user_allowed_globals():
|
||||
self.append(_get_user_allowed_globals()[full_path])
|
||||
else:
|
||||
raise RuntimeError(
|
||||
raise UnpicklingError(
|
||||
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
|
||||
f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
|
||||
"this global if you trust this class/function."
|
||||
@ -256,7 +256,9 @@ class Unpickler:
|
||||
elif cls in _get_user_allowed_globals().values():
|
||||
self.append(cls.__new__(cls, *args))
|
||||
else:
|
||||
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
|
||||
raise UnpicklingError(
|
||||
f"Trying to instantiate unsupported class {cls}"
|
||||
)
|
||||
elif key[0] == REDUCE[0]:
|
||||
args = self.stack.pop()
|
||||
func = self.stack[-1]
|
||||
@ -264,7 +266,7 @@ class Unpickler:
|
||||
func not in _get_allowed_globals().values()
|
||||
and func not in _get_user_allowed_globals().values()
|
||||
):
|
||||
raise RuntimeError(
|
||||
raise UnpicklingError(
|
||||
f"Trying to call reduce for unrecognized function {func}"
|
||||
)
|
||||
self.stack[-1] = func(*args)
|
||||
@ -284,7 +286,7 @@ class Unpickler:
|
||||
else:
|
||||
inst.__dict__.update(state)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
raise UnpicklingError(
|
||||
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
|
||||
)
|
||||
# Stack manipulation
|
||||
@ -292,7 +294,7 @@ class Unpickler:
|
||||
item = self.stack.pop()
|
||||
list_obj = self.stack[-1]
|
||||
if type(list_obj) is not list:
|
||||
raise RuntimeError(
|
||||
raise UnpicklingError(
|
||||
f"Can only append to lists, but got {type(list_obj)}"
|
||||
)
|
||||
list_obj.append(item)
|
||||
@ -300,7 +302,7 @@ class Unpickler:
|
||||
items = self.pop_mark()
|
||||
list_obj = self.stack[-1]
|
||||
if type(list_obj) is not list:
|
||||
raise RuntimeError(
|
||||
raise UnpicklingError(
|
||||
f"Can only extend lists, but got {type(list_obj)}"
|
||||
)
|
||||
list_obj.extend(items)
|
||||
@ -350,7 +352,7 @@ class Unpickler:
|
||||
elif key[0] == BINUNICODE[0]:
|
||||
strlen = unpack("<I", read(4))[0]
|
||||
if strlen > maxsize:
|
||||
raise RuntimeError("String is too long")
|
||||
raise UnpicklingError("String is too long")
|
||||
strval = str(read(strlen), "utf-8", "surrogatepass")
|
||||
self.append(strval)
|
||||
elif key[0] == SHORT_BINSTRING[0]:
|
||||
@ -363,7 +365,7 @@ class Unpickler:
|
||||
pid = self.stack.pop()
|
||||
# Only allow persistent load of storage
|
||||
if type(pid) is not tuple and not type(pid) is not int:
|
||||
raise RuntimeError(
|
||||
raise UnpicklingError(
|
||||
f"persistent_load id must be tuple or int, but got {type(pid)}"
|
||||
)
|
||||
if (
|
||||
@ -371,7 +373,7 @@ class Unpickler:
|
||||
and len(pid) > 0
|
||||
and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
|
||||
):
|
||||
raise RuntimeError(
|
||||
raise UnpicklingError(
|
||||
f"Only persistent_load of storage is allowed, but got {pid[0]}"
|
||||
)
|
||||
self.append(self.persistent_load(pid))
|
||||
@ -401,7 +403,7 @@ class Unpickler:
|
||||
rc = self.stack.pop()
|
||||
return rc
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported operand {key[0]}")
|
||||
raise UnpicklingError(f"Unsupported operand {key[0]}")
|
||||
|
||||
# Return a list of items pushed in the stack after last MARK instruction.
|
||||
def pop_mark(self):
|
||||
|
@ -1258,7 +1258,7 @@ def load(
|
||||
overall_storage=overall_storage,
|
||||
**pickle_load_args,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
except pickle.UnpicklingError as e:
|
||||
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
|
||||
return _load(
|
||||
opened_zipfile,
|
||||
@ -1282,7 +1282,7 @@ def load(
|
||||
_weights_only_unpickler,
|
||||
**pickle_load_args,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
except pickle.UnpicklingError as e:
|
||||
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
|
||||
return _legacy_load(
|
||||
opened_file, map_location, pickle_module, **pickle_load_args
|
||||
|
Reference in New Issue
Block a user