Fix failures when default is flipped for weights_only (#127627)

Tests on XLA shard not fixed yet but there is an issue here https://github.com/pytorch/xla/issues/7799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627
Approved by: https://github.com/albanD
ghstack dependencies: #132349
This commit is contained in:
Mikayla Gawarecki
2024-08-15 19:48:35 +00:00
committed by PyTorch MergeBot
parent c8ad5e37e8
commit d9576c9440
22 changed files with 135 additions and 78 deletions

View File

@ -25,6 +25,7 @@ from torch.serialization import (
check_module_version_greater_or_equal,
get_default_load_endianness,
LoadEndianness,
safe_globals,
set_default_load_endianness,
SourceChangeWarning,
)
@ -276,7 +277,8 @@ class SerializationMixin:
torch.save(x, f, pickle_module=dill)
f.seek(0)
with self.assertRaisesRegex(ValueError, 'supports dill >='):
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
# weights_only=False as this is legacy code that saves the model
x2 = torch.load(f, pickle_module=dill, encoding='utf-8', weights_only=False)
def test_pickle_module(self):
class ThrowingUnpickler(pickle.Unpickler):
@ -292,7 +294,8 @@ class SerializationMixin:
torch.save(x, f)
f.seek(0)
with self.assertRaisesRegex(RuntimeError, "rumpelstiltskin"):
torch.load(f, pickle_module=ThrowingModule)
# weights_only=False as True does not support custom pickle module
torch.load(f, pickle_module=ThrowingModule, weights_only=False)
f.seek(0)
z = torch.load(f)
self.assertEqual(x, z)
@ -307,11 +310,13 @@ class SerializationMixin:
with tempfile.NamedTemporaryFile() as f:
torch.save(x, f, pickle_module=dill)
f.seek(0)
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
# weights_only=False as True does not support custom pickle_module
x2 = torch.load(f, pickle_module=dill, encoding='utf-8', weights_only=False)
self.assertIsInstance(x2, type(x))
self.assertEqual(x, x2)
f.seek(0)
x3 = torch.load(f, pickle_module=dill)
# weights_only=False as True does not support custom pickle_module
x3 = torch.load(f, pickle_module=dill, weights_only=False)
self.assertIsInstance(x3, type(x))
self.assertEqual(x, x3)
@ -718,13 +723,15 @@ class SerializationMixin:
# This Pickle contains a Python 2 module with Unicode data and the
# loading should fail if the user explicitly specifies ascii encoding!
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii'))
# weights_only=False as this is legacy code that saves the model
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii', weights_only=False))
def test_load_python2_unicode_module(self):
# This Pickle contains some Unicode data!
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
with warnings.catch_warnings(record=True) as w:
self.assertIsNotNone(torch.load(path))
# weights_only=False as this is legacy code that saves the model
self.assertIsNotNone(torch.load(path, weights_only=False))
def test_load_error_msg(self):
expected_err_msg = (".*You can only torch.load from a file that is seekable. " +
@ -735,7 +742,8 @@ class SerializationMixin:
delattr(resource, "tell")
delattr(resource, "seek")
with self.assertRaisesRegex(AttributeError, expected_err_msg):
torch.load(resource)
# weights_only=False as this is legacy code that saves the model
torch.load(resource, weights_only=False)
def test_save_different_dtype_unallocated(self):
devices = ['cpu']
@ -881,12 +889,11 @@ class TestOldSerialization(TestCase, SerializationMixin):
# First check that the checkpoint can be loaded without warning about unsafe loads
checkpoint.seek(0)
with warnings.catch_warnings(record=True) as w:
loaded = torch.load(checkpoint)
# weights_only=False as this is legacy code that saves the model
loaded = torch.load(checkpoint, weights_only=False)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, FutureWarning)
self.assertTrue("You are using `torch.load` with `weights_only=False`" in str(w[0].message))
self.assertEqual(len(w), 0)
# Replace the module with different source
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
@ -894,12 +901,12 @@ class TestOldSerialization(TestCase, SerializationMixin):
module = import_module(tmpmodule_name, fname)
checkpoint.seek(0)
with warnings.catch_warnings(record=True) as w:
loaded = torch.load(checkpoint)
# weights_only=False as this is legacy code that saves the model
loaded = torch.load(checkpoint, weights_only=False)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
self.assertEqual(len(w), 2)
self.assertEqual(w[0].category, FutureWarning)
self.assertEqual(w[1].category, SourceChangeWarning)
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, SourceChangeWarning)
def test_serialization_container(self):
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
@ -924,7 +931,8 @@ class TestOldSerialization(TestCase, SerializationMixin):
a_loaded = torch.load(f)
j_loaded = pickle.load(f)
b_loaded = torch.load(f)
m_loaded = torch.load(f)
# weights_only=False as this is legacy code that saves the model
m_loaded = torch.load(f, weights_only=False)
self.assertTrue(torch.equal(a, a_loaded))
self.assertTrue(torch.equal(b, b_loaded))
self.assertTrue(m.kernel_size == m_loaded.kernel_size)
@ -1141,7 +1149,7 @@ class TestSerialization(TestCase, SerializationMixin):
torch.load(f, weights_only=True)
f.seek(0)
# safe_globals doesn't work even with allowlist
with torch.serialization.safe_globals([os.execv]):
with safe_globals([os.execv]):
with self.assertRaisesRegex(pickle.UnpicklingError, error_msg):
torch.load(f, weights_only=True)
@ -4252,7 +4260,8 @@ class TestSubclassSerialization(TestCase):
with BytesIOContext() as f:
torch.save(my_tensor, f)
f.seek(0)
new_tensor = torch.load(f)
with safe_globals([TestWrapperSubclass]):
new_tensor = torch.load(f)
self.assertIsInstance(new_tensor, TestWrapperSubclass)
self.assertEqual(new_tensor.elem, my_tensor.elem)
@ -4269,7 +4278,8 @@ class TestSubclassSerialization(TestCase):
with BytesIOContext() as f:
torch.save(my_tensor, f)
f.seek(0)
new_tensor = torch.load(f)
with safe_globals([TestGetStateSubclass]):
new_tensor = torch.load(f)
self.assertIsInstance(new_tensor, TestGetStateSubclass)
self.assertEqual(new_tensor.elem, my_tensor.elem)
@ -4306,7 +4316,8 @@ class TestSubclassSerialization(TestCase):
with BytesIOContext() as f:
torch.save(tensor, f)
f.seek(0)
tensor2 = torch.load(f)
with safe_globals([TestEmptySubclass]):
tensor2 = torch.load(f)
tensor = TestEmptySubclass()
# Ensures it runs fine
@ -4316,7 +4327,8 @@ class TestSubclassSerialization(TestCase):
with BytesIOContext() as f:
torch.save(tensor, f)
f.seek(0)
tensor2 = torch.load(f)
with safe_globals([TestEmptySubclass]):
tensor2 = torch.load(f)
@skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined")
def test_safe_globals_for_weights_only(self):
@ -4357,7 +4369,7 @@ class TestSubclassSerialization(TestCase):
def test_safe_globals_context_manager_weights_only(self):
'''
Tests torch.serialization.safe_globals context manager
Tests safe_globals context manager
'''
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
p = torch.nn.Parameter(t)
@ -4367,7 +4379,7 @@ class TestSubclassSerialization(TestCase):
torch.serialization.add_safe_globals([TestEmptySubclass])
with tempfile.NamedTemporaryFile() as f:
torch.save(sd, f)
with torch.serialization.safe_globals([TwoTensor]):
with safe_globals([TwoTensor]):
f.seek(0)
torch.load(f, weights_only=True)
self.assertTrue(torch.serialization.get_safe_globals() == [TestEmptySubclass])
@ -4385,15 +4397,16 @@ class TestSubclassSerialization(TestCase):
with TemporaryFileName() as f:
torch.save(sd, f)
sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
# make sure map_location is not propagated over multiple torch.load calls
sd_loaded = torch.load(f)
self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))
with safe_globals([TwoTensor]):
sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
# make sure map_location is not propagated over multiple torch.load calls
sd_loaded = torch.load(f)
self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))
instantiate_device_type_tests(TestBothSerialization, globals())