Add _codecs.encode and builtins.bytearray to _get_allowed_globals to support bytes and bytearray serialization (#133189)

Fixes #133163

Debugged in collaboration with @hariveliki

The `byte` type is demanding the global `_codecs.encode`. That means, the following currently works:
```python
import torch

torch.save(b'hello', '/tmp/dummy.pth')

torch.serialization.add_safe_globals([_codecs.encode])
torch.load('/tmp/dummy.pth', weights_only=True)
```

Similarly, `bytearray` needs `builtins.bytearray`.

Following the `torch.loads` docs promise, both types should be supported without `add_safe_globals` as they are both primitive types:
>         weights_only: Indicates whether unpickler should be restricted to
>            loading only tensors, primitive types, dictionaries
>           and any types added via :func:`torch.serialization.add_safe_globals`.

This PR adds both `_codecs.encode` and `builtins.bytearray` to `_get_allowed_globals` and test for saving and loading of both types with and without `weights_only`.

Co-authored-by: hariveliki <98284163+hariveliki@users.noreply.github.com>
Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133189
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Luciano Bello
2024-08-13 02:20:25 +00:00
committed by PyTorch MergeBot
parent f1c439cbed
commit 1e9bedf688
2 changed files with 15 additions and 0 deletions

View File

@ -4133,6 +4133,17 @@ class TestSerialization(TestCase, SerializationMixin):
y['even'][0] = torch.tensor(-0.25, dtype=dtype)
self.assertEqual(y['x'][:2].to(dtype=torch.float32), torch.tensor([-0.25, 0.25]))
@parametrize('byte_literals', (b'byte', bytearray(b'bytearray')))
@parametrize('weights_only', (True, False))
def test_serialization_byte_literal(self, byte_literals, weights_only):
""" Tests that byte literal can be serialized.
See: https://github.com/pytorch/pytorch/issues/133163"""
with tempfile.NamedTemporaryFile() as f:
torch.save(byte_literals, f)
f.seek(0)
y = torch.load(f, weights_only=weights_only)
self.assertEqual(y, byte_literals)
@parametrize('filename', (True, False))
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
@unittest.skipIf(IS_FBCODE, "miniz version differs between fbcode and oss")