mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix failures when default is flipped for weights_only (#127627)
Tests on XLA shard not fixed yet but there is an issue here https://github.com/pytorch/xla/issues/7799 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627 Approved by: https://github.com/albanD ghstack dependencies: #132349
This commit is contained in:
committed by
PyTorch MergeBot
parent
c8ad5e37e8
commit
d9576c9440
@ -25,6 +25,7 @@ from torch.serialization import (
|
||||
check_module_version_greater_or_equal,
|
||||
get_default_load_endianness,
|
||||
LoadEndianness,
|
||||
safe_globals,
|
||||
set_default_load_endianness,
|
||||
SourceChangeWarning,
|
||||
)
|
||||
@ -276,7 +277,8 @@ class SerializationMixin:
|
||||
torch.save(x, f, pickle_module=dill)
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(ValueError, 'supports dill >='):
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8', weights_only=False)
|
||||
|
||||
def test_pickle_module(self):
|
||||
class ThrowingUnpickler(pickle.Unpickler):
|
||||
@ -292,7 +294,8 @@ class SerializationMixin:
|
||||
torch.save(x, f)
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "rumpelstiltskin"):
|
||||
torch.load(f, pickle_module=ThrowingModule)
|
||||
# weights_only=False as True does not support custom pickle module
|
||||
torch.load(f, pickle_module=ThrowingModule, weights_only=False)
|
||||
f.seek(0)
|
||||
z = torch.load(f)
|
||||
self.assertEqual(x, z)
|
||||
@ -307,11 +310,13 @@ class SerializationMixin:
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(x, f, pickle_module=dill)
|
||||
f.seek(0)
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
|
||||
# weights_only=False as True does not support custom pickle_module
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8', weights_only=False)
|
||||
self.assertIsInstance(x2, type(x))
|
||||
self.assertEqual(x, x2)
|
||||
f.seek(0)
|
||||
x3 = torch.load(f, pickle_module=dill)
|
||||
# weights_only=False as True does not support custom pickle_module
|
||||
x3 = torch.load(f, pickle_module=dill, weights_only=False)
|
||||
self.assertIsInstance(x3, type(x))
|
||||
self.assertEqual(x, x3)
|
||||
|
||||
@ -718,13 +723,15 @@ class SerializationMixin:
|
||||
# This Pickle contains a Python 2 module with Unicode data and the
|
||||
# loading should fail if the user explicitly specifies ascii encoding!
|
||||
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
|
||||
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii'))
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii', weights_only=False))
|
||||
|
||||
def test_load_python2_unicode_module(self):
|
||||
# This Pickle contains some Unicode data!
|
||||
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
self.assertIsNotNone(torch.load(path))
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
self.assertIsNotNone(torch.load(path, weights_only=False))
|
||||
|
||||
def test_load_error_msg(self):
|
||||
expected_err_msg = (".*You can only torch.load from a file that is seekable. " +
|
||||
@ -735,7 +742,8 @@ class SerializationMixin:
|
||||
delattr(resource, "tell")
|
||||
delattr(resource, "seek")
|
||||
with self.assertRaisesRegex(AttributeError, expected_err_msg):
|
||||
torch.load(resource)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
torch.load(resource, weights_only=False)
|
||||
|
||||
def test_save_different_dtype_unallocated(self):
|
||||
devices = ['cpu']
|
||||
@ -881,12 +889,11 @@ class TestOldSerialization(TestCase, SerializationMixin):
|
||||
# First check that the checkpoint can be loaded without warning about unsafe loads
|
||||
checkpoint.seek(0)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
loaded = torch.load(checkpoint)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded = torch.load(checkpoint, weights_only=False)
|
||||
self.assertTrue(isinstance(loaded, module.Net))
|
||||
if can_retrieve_source:
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertEqual(w[0].category, FutureWarning)
|
||||
self.assertTrue("You are using `torch.load` with `weights_only=False`" in str(w[0].message))
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
# Replace the module with different source
|
||||
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
|
||||
@ -894,12 +901,12 @@ class TestOldSerialization(TestCase, SerializationMixin):
|
||||
module = import_module(tmpmodule_name, fname)
|
||||
checkpoint.seek(0)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
loaded = torch.load(checkpoint)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded = torch.load(checkpoint, weights_only=False)
|
||||
self.assertTrue(isinstance(loaded, module.Net))
|
||||
if can_retrieve_source:
|
||||
self.assertEqual(len(w), 2)
|
||||
self.assertEqual(w[0].category, FutureWarning)
|
||||
self.assertEqual(w[1].category, SourceChangeWarning)
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertEqual(w[0].category, SourceChangeWarning)
|
||||
|
||||
def test_serialization_container(self):
|
||||
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
|
||||
@ -924,7 +931,8 @@ class TestOldSerialization(TestCase, SerializationMixin):
|
||||
a_loaded = torch.load(f)
|
||||
j_loaded = pickle.load(f)
|
||||
b_loaded = torch.load(f)
|
||||
m_loaded = torch.load(f)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
m_loaded = torch.load(f, weights_only=False)
|
||||
self.assertTrue(torch.equal(a, a_loaded))
|
||||
self.assertTrue(torch.equal(b, b_loaded))
|
||||
self.assertTrue(m.kernel_size == m_loaded.kernel_size)
|
||||
@ -1141,7 +1149,7 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
torch.load(f, weights_only=True)
|
||||
f.seek(0)
|
||||
# safe_globals doesn't work even with allowlist
|
||||
with torch.serialization.safe_globals([os.execv]):
|
||||
with safe_globals([os.execv]):
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, error_msg):
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
@ -4252,7 +4260,8 @@ class TestSubclassSerialization(TestCase):
|
||||
with BytesIOContext() as f:
|
||||
torch.save(my_tensor, f)
|
||||
f.seek(0)
|
||||
new_tensor = torch.load(f)
|
||||
with safe_globals([TestWrapperSubclass]):
|
||||
new_tensor = torch.load(f)
|
||||
|
||||
self.assertIsInstance(new_tensor, TestWrapperSubclass)
|
||||
self.assertEqual(new_tensor.elem, my_tensor.elem)
|
||||
@ -4269,7 +4278,8 @@ class TestSubclassSerialization(TestCase):
|
||||
with BytesIOContext() as f:
|
||||
torch.save(my_tensor, f)
|
||||
f.seek(0)
|
||||
new_tensor = torch.load(f)
|
||||
with safe_globals([TestGetStateSubclass]):
|
||||
new_tensor = torch.load(f)
|
||||
|
||||
self.assertIsInstance(new_tensor, TestGetStateSubclass)
|
||||
self.assertEqual(new_tensor.elem, my_tensor.elem)
|
||||
@ -4306,7 +4316,8 @@ class TestSubclassSerialization(TestCase):
|
||||
with BytesIOContext() as f:
|
||||
torch.save(tensor, f)
|
||||
f.seek(0)
|
||||
tensor2 = torch.load(f)
|
||||
with safe_globals([TestEmptySubclass]):
|
||||
tensor2 = torch.load(f)
|
||||
|
||||
tensor = TestEmptySubclass()
|
||||
# Ensures it runs fine
|
||||
@ -4316,7 +4327,8 @@ class TestSubclassSerialization(TestCase):
|
||||
with BytesIOContext() as f:
|
||||
torch.save(tensor, f)
|
||||
f.seek(0)
|
||||
tensor2 = torch.load(f)
|
||||
with safe_globals([TestEmptySubclass]):
|
||||
tensor2 = torch.load(f)
|
||||
|
||||
@skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined")
|
||||
def test_safe_globals_for_weights_only(self):
|
||||
@ -4357,7 +4369,7 @@ class TestSubclassSerialization(TestCase):
|
||||
|
||||
def test_safe_globals_context_manager_weights_only(self):
|
||||
'''
|
||||
Tests torch.serialization.safe_globals context manager
|
||||
Tests safe_globals context manager
|
||||
'''
|
||||
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
|
||||
p = torch.nn.Parameter(t)
|
||||
@ -4367,7 +4379,7 @@ class TestSubclassSerialization(TestCase):
|
||||
torch.serialization.add_safe_globals([TestEmptySubclass])
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(sd, f)
|
||||
with torch.serialization.safe_globals([TwoTensor]):
|
||||
with safe_globals([TwoTensor]):
|
||||
f.seek(0)
|
||||
torch.load(f, weights_only=True)
|
||||
self.assertTrue(torch.serialization.get_safe_globals() == [TestEmptySubclass])
|
||||
@ -4385,15 +4397,16 @@ class TestSubclassSerialization(TestCase):
|
||||
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(sd, f)
|
||||
sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
|
||||
# make sure map_location is not propagated over multiple torch.load calls
|
||||
sd_loaded = torch.load(f)
|
||||
self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
|
||||
self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
|
||||
self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))
|
||||
with safe_globals([TwoTensor]):
|
||||
sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
|
||||
# make sure map_location is not propagated over multiple torch.load calls
|
||||
sd_loaded = torch.load(f)
|
||||
self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
|
||||
self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
|
||||
self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestBothSerialization, globals())
|
||||
|
Reference in New Issue
Block a user