Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Throw error when saving storages that view same data with different type (#66949)
Summary: Fixes https://github.com/pytorch/pytorch/issues/58970
cc mruberry
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66949
Reviewed By: albanD
Differential Revision: D31926323
Pulled By: anjali411
fbshipit-source-id: f6e7acc0c1968b70a94f9b0b69a32780e8e21a62
Committed by: Facebook GitHub Bot
Parent: bf60c6e71b
Commit: bc3d380ed1
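To illustrate the new behavior before the diff, here is a minimal sketch (not part of the patch) of a save call that this change now rejects. It assumes a PyTorch build that contains this commit; the tensor names and the use of Tensor.view(dtype) to alias the same memory under two dtypes are illustrative only.

import io

import torch

# Two tensors that alias the same memory but interpret it with different
# dtypes: Tensor.view(dtype) reinterprets the bytes without copying them,
# so both storages report the same data_ptr().
x = torch.ones(4, dtype=torch.float32)
y = x.view(torch.int32)

buffer = io.BytesIO()
try:
    # persistent_id now records the dtype seen for each data_ptr and raises
    # when a second, conflicting dtype shows up for the same data.
    torch.save([x, y], buffer)
except RuntimeError as err:
    # "Cannot save multiple tensors or storages that view the same data as
    # different types"
    print(err)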
@@ -387,6 +387,12 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    serialized_container_types = {}
    serialized_storages = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
@@ -412,6 +418,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._storage
                storage_dtype = obj.dtype
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
@@ -419,10 +426,19 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:

            else:
                storage = obj
                storage_dtype = storage.dtype
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = cast(Storage, storage).nbytes()

            if storage.data_ptr() in storage_dtypes:
                if storage_dtype != storage_dtypes[storage.data_ptr()]:
                    raise RuntimeError(
                        'Cannot save multiple tensors or storages that '
                        'view the same data as different types')
            else:
                storage_dtypes[storage.data_ptr()] = storage_dtype

            view_metadata: Optional[Tuple[str, int, int]]
            storage = cast(Storage, storage)

@@ -507,6 +523,12 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {}
    id_map: Dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
@@ -519,16 +541,27 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._storage
                storage_dtype = obj.dtype
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj.size()

            else:
                storage = obj
                storage_dtype = storage.dtype
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            storage = cast(Storage, storage)

            if storage.data_ptr() in storage_dtypes:
                if storage_dtype != storage_dtypes[storage.data_ptr()]:
                    raise RuntimeError(
                        'Cannot save multiple tensors or storages that '
                        'view the same data as different types')
            else:
                storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            location = location_tag(storage)
            serialized_storages[storage_key] = storage
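For reference, the dtype-tracking check that both _legacy_save and _save gain above boils down to the following standalone sketch. _check_storage_dtype is a hypothetical helper name used only for illustration; the patch inlines this logic directly in each persistent_id.

from typing import Dict

import torch


def _check_storage_dtype(storage_dtypes: Dict[int, torch.dtype],
                         data_ptr: int,
                         storage_dtype: torch.dtype) -> None:
    # Remember the first dtype seen for a given data_ptr; any later storage
    # that views the same data with a different dtype is rejected, because
    # loading such a file is not supported.
    if data_ptr in storage_dtypes:
        if storage_dtype != storage_dtypes[data_ptr]:
            raise RuntimeError(
                'Cannot save multiple tensors or storages that '
                'view the same data as different types')
    else:
        storage_dtypes[data_ptr] = storage_dtype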