Add sets to list of safe objects to de-serialize (#138866)

Lists, dicts and tuples are already allowed, it's a bit weird not to exclude set from the list of basic containers.

Test plan (in addition to unittest):
```python
torch.save({1, 2, 3}, "foo.pt")
torch.load("foo.pt", weights_only=True)
```

Fixes https://github.com/pytorch/pytorch/issues/138851

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138866
Approved by: https://github.com/mikaylagawarecki

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
Nikita Shulga
2024-10-25 05:23:08 +00:00
committed by PyTorch MergeBot
parent 907f001a68
commit b999daf7a9
2 changed files with 9 additions and 0 deletions

View File

@ -4526,6 +4526,14 @@ class TestSubclassSerialization(TestCase):
finally:
torch.serialization.clear_safe_globals()
def test_sets_are_loadable_with_weights_only(self):
s = {1, 2, 3}
with tempfile.NamedTemporaryFile() as f:
torch.save(s, f)
f.seek(0)
l_s = torch.load(f, weights_only=True)
self.assertEqual(l_s, s)
@unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
def test_tensor_subclass_map_location(self):
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))