[BE] Do not warn when safely loading legacy dicts (#113614)

Use the same strategy as for unsafe pickler, i.e. use dummy `torch.serialization.StorageType` to represent legacy typed storage classes during deserialization. Add `_dtype` property to be able to use it for both new and legacy format deserialization.

Parametrize `test_serialization_new_format_old_format_compat`

Add regression test to validate that loading legacy modes can be done
without any warnings

Before the change:
```
% python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_safe_cpu (__main__.TestBothSerializationCPU) ... /Users/nshulga/git/pytorch/pytorch/torch/_utils.py:836: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
ok

----------------------------------------------------------------------
Ran 2 tests in 0.116s

OK
```
Without the change but update test to catch warnings:
```
 % python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_weights_only_False_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU) ... FAIL

======================================================================
FAIL: test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2536, in wrapper
    method(*args, **kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/test/test_serialization.py", line 807, in test_serialization_new_format_old_format_compat
    self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
AssertionError: False is not true : Expected no warnings but got ["{message : UserWarning('TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()'), category : 'UserWarning', filename : '/Users/nshulga/git/pytorch/pytorch/torch/_utils.py', lineno : 836, line : None}"]

To execute this test, run the following from the base repo dir:
     python test/test_serialization.py -k test_serialization_new_format_old_format_compat_weights_only_True_cpu

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 2 tests in 0.109s

FAILED (failures=1)

```

Fixes problem reported in https://github.com/pytorch/pytorch/issues/52181#issuecomment-1715738910
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113614
Approved by: https://github.com/kit1980, https://github.com/albanD
This commit is contained in:
Nikita Shulga
2023-11-14 22:09:10 +00:00
committed by PyTorch MergeBot
parent 538114db65
commit 1d640566d4
3 changed files with 19 additions and 12 deletions

View File

@ -25,7 +25,7 @@ from torch.serialization import check_module_version_greater_or_equal, get_defau
from torch.testing._internal.common_utils import IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName, \
TestCase, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName, \
parametrize, instantiate_parametrized_tests
parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_dtype import all_types_and_complex_and
@ -786,7 +786,8 @@ class serialization_method:
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
class TestBothSerialization(TestCase):
def _test_serialization_new_format_old_format_compat(self, device, weights_only):
@parametrize("weights_only", (True, False))
def test_serialization_new_format_old_format_compat(self, device, weights_only):
x = [torch.ones(200, 200, device=device) for i in range(30)]
def test(f_new, f_old):
@ -800,14 +801,10 @@ class TestBothSerialization(TestCase):
x_old_load = torch.load(f_old, weights_only=weights_only)
self.assertEqual(x_old_load, x_new_load)
with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
test(f_new, f_old)
def test_serialization_new_format_old_format_compat(self, device):
self._test_serialization_new_format_old_format_compat(device, False)
def test_serialization_new_format_old_format_compat_safe(self, device):
self._test_serialization_new_format_old_format_compat(device, True)
with AlwaysWarnTypedStorageRemoval(True), warnings.catch_warnings(record=True) as w:
with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
test(f_new, f_old)
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
class TestOldSerialization(TestCase, SerializationMixin):

View File

@ -93,7 +93,13 @@ def _get_allowed_globals():
rc[f"{tt.__module__}.{tt.__name__}"] = tt
# Storage classes
for ts in torch._storage_classes:
rc[f"{ts.__module__}.{ts.__name__}"] = ts
if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
# Wrap legacy storage types in a dummy class
rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
ts.__name__
)
else:
rc[f"{ts.__module__}.{ts.__name__}"] = ts
# Rebuild functions
for f in [
torch._utils._rebuild_parameter,

View File

@ -1320,7 +1320,11 @@ def _get_restore_location(map_location):
class StorageType:
def __init__(self, name):
self.dtype = _get_dtype_from_pickle_storage_type(name)
self._dtype = _get_dtype_from_pickle_storage_type(name)
@property
def dtype(self):
return self._dtype
def __str__(self):
return f'StorageType(dtype={self.dtype})'