mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Prevent _legacy_load with weights_only=True (#144914)"
This reverts commit 7c3aa1da1c97812af54d41f3f0eff2ef922c0f32. Reverted https://github.com/pytorch/pytorch/pull/144914 on behalf of https://github.com/izaitsevfb due to breaking inductor on trunk ([comment](https://github.com/pytorch/pytorch/pull/144914#issuecomment-2596922781))
This commit is contained in:
@ -110,14 +110,12 @@ class TestSerialization(TestCase):
|
||||
torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file)
|
||||
torch.save(qmodule(input_tensor), expected_file)
|
||||
|
||||
# weights_only=False as file was saved in .tar format
|
||||
input_tensor = torch.load(input_file, weights_only=False)
|
||||
input_tensor = torch.load(input_file)
|
||||
# weights_only = False as sometimes get ScriptObject here
|
||||
qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False))
|
||||
qmodule_scripted = torch.jit.load(scripted_module_file)
|
||||
qmodule_traced = torch.jit.load(traced_module_file)
|
||||
# weights_only=False as file was saved in .tar format
|
||||
expected = torch.load(expected_file, weights_only=False)
|
||||
expected = torch.load(expected_file)
|
||||
self.assertEqual(qmodule(input_tensor), expected, atol=prec)
|
||||
self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
|
||||
self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)
|
||||
|
@ -227,6 +227,9 @@ class SerializationMixin:
|
||||
def test_serialization(self):
|
||||
self._test_serialization(False)
|
||||
|
||||
def test_serialization_safe(self):
|
||||
self._test_serialization(True)
|
||||
|
||||
def test_serialization_filelike(self):
|
||||
# Test serialization (load and save) with a filelike object
|
||||
b = self._test_serialization_data()
|
||||
@ -363,6 +366,9 @@ class SerializationMixin:
|
||||
def test_serialization_sparse(self):
|
||||
self._test_serialization(False)
|
||||
|
||||
def test_serialization_sparse_safe(self):
|
||||
self._test_serialization(True)
|
||||
|
||||
def test_serialization_sparse_invalid(self):
|
||||
x = torch.zeros(3, 3)
|
||||
x[1][1] = 1
|
||||
@ -508,6 +514,9 @@ class SerializationMixin:
|
||||
def test_serialization_backwards_compat(self):
|
||||
self._test_serialization_backwards_compat(False)
|
||||
|
||||
def test_serialization_backwards_compat_safe(self):
|
||||
self._test_serialization_backwards_compat(True)
|
||||
|
||||
def test_serialization_save_warnings(self):
|
||||
with warnings.catch_warnings(record=True) as warns:
|
||||
with tempfile.NamedTemporaryFile() as checkpoint:
|
||||
@ -552,8 +561,7 @@ class SerializationMixin:
|
||||
def check_map_locations(map_locations, dtype, intended_device):
|
||||
for fileobject_lambda in fileobject_lambdas:
|
||||
for map_location in map_locations:
|
||||
# weigts_only=False as the downloaded file path uses the old serialization format
|
||||
tensor = torch.load(fileobject_lambda(), map_location=map_location, weights_only=False)
|
||||
tensor = torch.load(fileobject_lambda(), map_location=map_location)
|
||||
|
||||
self.assertEqual(tensor.device, intended_device)
|
||||
self.assertEqual(tensor.dtype, dtype)
|
||||
@ -596,8 +604,7 @@ class SerializationMixin:
|
||||
|
||||
error_msg = r'Attempting to deserialize object on a CUDA device'
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
# weights_only=False as serialized is in legacy format
|
||||
_ = torch.load(buf, weights_only=False)
|
||||
_ = torch.load(buf)
|
||||
|
||||
@unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681")
|
||||
def test_serialization_filelike_api_requirements(self):
|
||||
@ -717,8 +724,7 @@ class SerializationMixin:
|
||||
b'\x00\x00\x00\x00')
|
||||
|
||||
buf = io.BytesIO(serialized)
|
||||
# serialized was saved with PyTorch 0.3.1
|
||||
(s1, s2) = torch.load(buf, weights_only=False)
|
||||
(s1, s2) = torch.load(buf)
|
||||
self.assertEqual(s1[0], 0)
|
||||
self.assertEqual(s2[0], 0)
|
||||
self.assertEqual(s1.data_ptr() + 4, s2.data_ptr())
|
||||
@ -835,24 +841,6 @@ class serialization_method:
|
||||
def __exit__(self, *args, **kwargs):
|
||||
torch.save = self.torch_save
|
||||
|
||||
|
||||
# used to set weights_only=False in _use_new_zipfile_serialization=False tests
|
||||
class load_method:
|
||||
def __init__(self, weights_only):
|
||||
self.weights_only = weights_only
|
||||
self.torch_load = torch.load
|
||||
|
||||
def __enter__(self, *args, **kwargs):
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs['weights_only'] = self.weights_only
|
||||
return self.torch_load(*args, **kwargs)
|
||||
|
||||
torch.load = wrapper
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
torch.load = self.torch_load
|
||||
|
||||
|
||||
Point = namedtuple('Point', ['x', 'y'])
|
||||
|
||||
class ClassThatUsesBuildInstruction:
|
||||
@ -889,7 +877,7 @@ class TestBothSerialization(TestCase):
|
||||
|
||||
torch.save(x, f_old, _use_new_zipfile_serialization=False)
|
||||
f_old.seek(0)
|
||||
x_old_load = torch.load(f_old, weights_only=False)
|
||||
x_old_load = torch.load(f_old, weights_only=weights_only)
|
||||
self.assertEqual(x_old_load, x_new_load)
|
||||
|
||||
with AlwaysWarnTypedStorageRemoval(True), warnings.catch_warnings(record=True) as w:
|
||||
@ -897,17 +885,6 @@ class TestBothSerialization(TestCase):
|
||||
test(f_new, f_old)
|
||||
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
|
||||
|
||||
def test_old_serialization_fails_with_weights_only(self):
|
||||
a = torch.randn(5, 5)
|
||||
with BytesIOContext() as f:
|
||||
torch.save(a, f, _use_new_zipfile_serialization=False)
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
|
||||
):
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
|
||||
class TestOldSerialization(TestCase, SerializationMixin):
|
||||
# unique_key is necessary because on Python 2.7, if a warning passed to
|
||||
@ -983,7 +960,8 @@ class TestOldSerialization(TestCase, SerializationMixin):
|
||||
self.assertEqual(i, i_loaded)
|
||||
self.assertEqual(j, j_loaded)
|
||||
|
||||
def test_serialization_offset_filelike(self):
|
||||
@parametrize('weights_only', (True, False))
|
||||
def test_serialization_offset_filelike(self, weights_only):
|
||||
a = torch.randn(5, 5)
|
||||
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
|
||||
i, j = 41, 43
|
||||
@ -995,16 +973,16 @@ class TestOldSerialization(TestCase, SerializationMixin):
|
||||
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
|
||||
f.seek(0)
|
||||
i_loaded = pickle.load(f)
|
||||
a_loaded = torch.load(f)
|
||||
a_loaded = torch.load(f, weights_only=weights_only)
|
||||
j_loaded = pickle.load(f)
|
||||
b_loaded = torch.load(f)
|
||||
b_loaded = torch.load(f, weights_only=weights_only)
|
||||
self.assertTrue(torch.equal(a, a_loaded))
|
||||
self.assertTrue(torch.equal(b, b_loaded))
|
||||
self.assertEqual(i, i_loaded)
|
||||
self.assertEqual(j, j_loaded)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=False), load_method(weights_only=False):
|
||||
with serialization_method(use_zip=False):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
|
@ -1501,10 +1501,15 @@ def load(
|
||||
"please torch.save your checkpoint with this option in order to use mmap."
|
||||
)
|
||||
if weights_only:
|
||||
raise RuntimeError(
|
||||
"Cannot use ``weights_only=True`` with files saved in the "
|
||||
".tar format used before version 1.6. " + UNSAFE_MESSAGE
|
||||
)
|
||||
try:
|
||||
return _legacy_load(
|
||||
opened_file,
|
||||
map_location,
|
||||
_weights_only_unpickler,
|
||||
**pickle_load_args,
|
||||
)
|
||||
except pickle.UnpicklingError as e:
|
||||
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
|
||||
return _legacy_load(
|
||||
opened_file, map_location, pickle_module, **pickle_load_args
|
||||
)
|
||||
|
Reference in New Issue
Block a user