[BE] Do not warn when safely loading legacy dicts (#113614)
Use the same strategy as for the unsafe pickler, i.e. use a dummy `torch.serialization.StorageType` to represent legacy typed storage classes during deserialization. Add a `_dtype` property so that the same class can be used for both new and legacy format deserialization.

Parametrize `test_serialization_new_format_old_format_compat` and add a regression test to validate that loading legacy models can be done without any warnings.

Before the change:
```
% python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_safe_cpu (__main__.TestBothSerializationCPU) ... /Users/nshulga/git/pytorch/pytorch/torch/_utils.py:836: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
ok

----------------------------------------------------------------------
Ran 2 tests in 0.116s

OK
```

Without the change, but with the test updated to catch warnings:
```
% python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_weights_only_False_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU) ... FAIL

======================================================================
FAIL: test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2536, in wrapper
    method(*args, **kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/test/test_serialization.py", line 807, in test_serialization_new_format_old_format_compat
    self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
AssertionError: False is not true : Expected no warnings but got ["{message : UserWarning('TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()'), category : 'UserWarning', filename : '/Users/nshulga/git/pytorch/pytorch/torch/_utils.py', lineno : 836, line : None}"]

To execute this test, run the following from the base repo dir:
    python test/test_serialization.py -k test_serialization_new_format_old_format_compat_weights_only_True_cpu

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 2 tests in 0.109s

FAILED (failures=1)
```

Fixes the problem reported in https://github.com/pytorch/pytorch/issues/52181#issuecomment-1715738910

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113614
Approved by: https://github.com/kit1980, https://github.com/albanD
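A minimal sketch of the behaviour this change fixes (not the PR's actual regression test, which is shown in the diff below and also relies on the internal `AlwaysWarnTypedStorageRemoval` test helper): save a tensor in the legacy non-zipfile format and confirm that loading it back with `weights_only=True` no longer emits the `TypedStorage` deprecation warning.

```python
# Sketch only: reproduces the scenario from the PR description under the
# assumption that the legacy format is produced via _use_new_zipfile_serialization=False.
import tempfile
import warnings

import torch

x = torch.ones(4, 4)
with tempfile.NamedTemporaryFile() as f:
    # Write the checkpoint in the legacy (pre-zipfile) serialization format.
    torch.save(x, f, _use_new_zipfile_serialization=False)
    f.seek(0)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        y = torch.load(f, weights_only=True)
    assert torch.equal(x, y)
    # After this change, no TypedStorage deprecation warning should be recorded here.
    assert not any("TypedStorage is deprecated" in str(item.message) for item in w), w
```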
Committed by: PyTorch MergeBot
Parent: 538114db65
Commit: 1d640566d4
test/test_serialization.py:
```diff
@@ -25,7 +25,7 @@ from torch.serialization import check_module_version_greater_or_equal, get_defau
 from torch.testing._internal.common_utils import IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName, \
     TestCase, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName, \
-    parametrize, instantiate_parametrized_tests
+    parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_dtype import all_types_and_complex_and
 
@@ -786,7 +786,8 @@ class serialization_method:
 
 @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
 class TestBothSerialization(TestCase):
-    def _test_serialization_new_format_old_format_compat(self, device, weights_only):
+    @parametrize("weights_only", (True, False))
+    def test_serialization_new_format_old_format_compat(self, device, weights_only):
         x = [torch.ones(200, 200, device=device) for i in range(30)]
 
         def test(f_new, f_old):
@@ -800,14 +801,10 @@ class TestBothSerialization(TestCase):
             x_old_load = torch.load(f_old, weights_only=weights_only)
             self.assertEqual(x_old_load, x_new_load)
 
-        with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
-            test(f_new, f_old)
-
-    def test_serialization_new_format_old_format_compat(self, device):
-        self._test_serialization_new_format_old_format_compat(device, False)
-
-    def test_serialization_new_format_old_format_compat_safe(self, device):
-        self._test_serialization_new_format_old_format_compat(device, True)
+        with AlwaysWarnTypedStorageRemoval(True), warnings.catch_warnings(record=True) as w:
+            with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
+                test(f_new, f_old)
+            self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
 
 
 class TestOldSerialization(TestCase, SerializationMixin):
```
torch/_weights_only_unpickler.py:
```diff
@@ -93,7 +93,13 @@ def _get_allowed_globals():
         rc[f"{tt.__module__}.{tt.__name__}"] = tt
     # Storage classes
     for ts in torch._storage_classes:
-        rc[f"{ts.__module__}.{ts.__name__}"] = ts
+        if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
+            # Wrap legacy storage types in a dummy class
+            rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
+                ts.__name__
+            )
+        else:
+            rc[f"{ts.__module__}.{ts.__name__}"] = ts
    # Rebuild functions
    for f in [
        torch._utils._rebuild_parameter,
```
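With this hunk, the weights-only unpickler resolves legacy storage names to the dummy `StorageType` wrapper instead of the deprecated `TypedStorage` subclasses, so building the allow-list no longer touches the deprecated classes. A small illustration of the resulting mapping; `_get_allowed_globals` is an internal helper and the printed values are assumed, not taken from the PR:

```python
# Illustration only: torch._weights_only_unpickler is private API and may change.
import torch
from torch._weights_only_unpickler import _get_allowed_globals

allowed = _get_allowed_globals()
# A legacy typed storage such as torch.FloatStorage should now map to the dummy
# torch.serialization.StorageType wrapper, which only carries a dtype.
legacy = allowed["torch.FloatStorage"]
print(type(legacy).__name__, legacy.dtype)  # expected: StorageType torch.float32
```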
torch/serialization.py:
```diff
@@ -1320,7 +1320,11 @@ def _get_restore_location(map_location):
 
 class StorageType:
     def __init__(self, name):
-        self.dtype = _get_dtype_from_pickle_storage_type(name)
+        self._dtype = _get_dtype_from_pickle_storage_type(name)
+
+    @property
+    def dtype(self):
+        return self._dtype
 
     def __str__(self):
         return f'StorageType(dtype={self.dtype})'
```
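The `dtype` attribute becomes a read-only property backed by `_dtype`, so the same dummy class can stand in for legacy storage classes in both the new-format and legacy-format load paths. A brief hedged illustration of how the class behaves after this change (the expected outputs are assumptions based on the diff, not logs from the PR):

```python
# StorageType is a stand-in used during deserialization; it only records a dtype.
from torch.serialization import StorageType

st = StorageType("HalfStorage")
print(st)        # expected: StorageType(dtype=torch.float16)
print(st.dtype)  # expected: torch.float16, served by the new read-only property
```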