mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add sets to list of safe objects to de-serialize (#138866)
Lists, dicts and tuples are already allowed, it's a bit weird not to exclude set from the list of basic containers. Test plan (in addition to unittest): ```python torch.save({1, 2, 3}, "foo.pt") torch.load("foo.pt", weights_only=True) ``` Fixes https://github.com/pytorch/pytorch/issues/138851 Pull Request resolved: https://github.com/pytorch/pytorch/pull/138866 Approved by: https://github.com/mikaylagawarecki Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
907f001a68
commit
b999daf7a9
@ -4526,6 +4526,14 @@ class TestSubclassSerialization(TestCase):
|
||||
finally:
|
||||
torch.serialization.clear_safe_globals()
|
||||
|
||||
def test_sets_are_loadable_with_weights_only(self):
|
||||
s = {1, 2, 3}
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(s, f)
|
||||
f.seek(0)
|
||||
l_s = torch.load(f, weights_only=True)
|
||||
self.assertEqual(l_s, s)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
|
||||
def test_tensor_subclass_map_location(self):
|
||||
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
|
||||
|
Reference in New Issue
Block a user