Throw error when saving storages that view same data with different type (#66949)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/58970

cc mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66949

Reviewed By: albanD

Differential Revision: D31926323

Pulled By: anjali411

fbshipit-source-id: f6e7acc0c1968b70a94f9b0b69a32780e8e21a62
This commit is contained in:
Kurt Mohler
2021-11-16 08:41:14 -08:00
committed by Facebook GitHub Bot
parent bf60c6e71b
commit bc3d380ed1
2 changed files with 66 additions and 0 deletions

View File

@ -387,6 +387,12 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
serialized_container_types = {}
serialized_storages = {}
# Since loading storages that view the same data with different dtypes is
# not supported, we need to keep track of the dtype associated with each
# storage data_ptr and throw an error if the dtype is ever different.
# TODO: This feature could be added in the future
storage_dtypes: Dict[int, torch.dtype] = {}
def persistent_id(obj: Any) -> Optional[Tuple]:
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary protocol
@ -412,6 +418,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._storage
storage_dtype = obj.dtype
storage_type_str = obj.pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
dtype = obj.dtype
@ -419,10 +426,19 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
else:
storage = obj
storage_dtype = storage.dtype
storage_type = normalize_storage_type(type(obj))
dtype = torch.uint8
storage_numel = cast(Storage, storage).nbytes()
if storage.data_ptr() in storage_dtypes:
if storage_dtype != storage_dtypes[storage.data_ptr()]:
raise RuntimeError(
'Cannot save multiple tensors or storages that '
'view the same data as different types')
else:
storage_dtypes[storage.data_ptr()] = storage_dtype
view_metadata: Optional[Tuple[str, int, int]]
storage = cast(Storage, storage)
@ -507,6 +523,12 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
serialized_storages = {}
id_map: Dict[int, str] = {}
# Since loading storages that view the same data with different dtypes is
# not supported, we need to keep track of the dtype associated with each
# storage data_ptr and throw an error if the dtype is ever different.
# TODO: This feature could be added in the future
storage_dtypes: Dict[int, torch.dtype] = {}
def persistent_id(obj):
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary protocol
@ -519,16 +541,27 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._storage
storage_dtype = obj.dtype
storage_type_str = obj.pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
storage_numel = obj.size()
else:
storage = obj
storage_dtype = storage.dtype
storage_type = normalize_storage_type(type(obj))
storage_numel = storage.nbytes()
storage = cast(Storage, storage)
if storage.data_ptr() in storage_dtypes:
if storage_dtype != storage_dtypes[storage.data_ptr()]:
raise RuntimeError(
'Cannot save multiple tensors or storages that '
'view the same data as different types')
else:
storage_dtypes[storage.data_ptr()] = storage_dtype
storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
location = location_tag(storage)
serialized_storages[storage_key] = storage