mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add _codecs.encode
and builtins.bytearray
to _get_allowed_globals
to support bytes and bytearray serialization (#133189)
Fixes #133163 Debugged in collaboration with @hariveliki The `byte` type is demanding the global `_codecs.encode`. That means, the following currently works: ```python import torch torch.save(b'hello', '/tmp/dummy.pth') torch.serialization.add_safe_globals([_codecs.encode]) torch.load('/tmp/dummy.pth', weights_only=True) ``` Similarly, `bytearray` needs `builtins.bytearray`. Following the `torch.loads` docs promise, both types should be supported without `add_safe_globals` as they are both primitive types: > weights_only: Indicates whether unpickler should be restricted to > loading only tensors, primitive types, dictionaries > and any types added via :func:`torch.serialization.add_safe_globals`. This PR adds both `_codecs.encode` and `builtins.bytearray` to `_get_allowed_globals` and test for saving and loading of both types with and without `weights_only`. Co-authored-by: hariveliki <98284163+hariveliki@users.noreply.github.com> Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/133189 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
committed by
PyTorch MergeBot
parent
f1c439cbed
commit
1e9bedf688
@ -4133,6 +4133,17 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
y['even'][0] = torch.tensor(-0.25, dtype=dtype)
|
||||
self.assertEqual(y['x'][:2].to(dtype=torch.float32), torch.tensor([-0.25, 0.25]))
|
||||
|
||||
@parametrize('byte_literals', (b'byte', bytearray(b'bytearray')))
|
||||
@parametrize('weights_only', (True, False))
|
||||
def test_serialization_byte_literal(self, byte_literals, weights_only):
|
||||
""" Tests that byte literal can be serialized.
|
||||
See: https://github.com/pytorch/pytorch/issues/133163"""
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(byte_literals, f)
|
||||
f.seek(0)
|
||||
y = torch.load(f, weights_only=weights_only)
|
||||
self.assertEqual(y, byte_literals)
|
||||
|
||||
@parametrize('filename', (True, False))
|
||||
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
|
||||
@unittest.skipIf(IS_FBCODE, "miniz version differs between fbcode and oss")
|
||||
|
Reference in New Issue
Block a user