mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Avoid dtype mismatch error in torch.save
if storages are unallocated (#68787)
Summary: Fixes https://github.com/pytorch/pytorch/issues/58970 cc mruberry Pull Request resolved: https://github.com/pytorch/pytorch/pull/68787 Reviewed By: mruberry Differential Revision: D32617425 Pulled By: anjali411 fbshipit-source-id: fe7f2374e4ef4428346a0a202cae8e0d382e03ab
This commit is contained in:
committed by
Facebook GitHub Bot
parent
208e109dbf
commit
b69155f754
@ -431,13 +431,17 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
dtype = torch.uint8
|
||||
storage_numel = cast(Storage, storage).nbytes()
|
||||
|
||||
if storage.data_ptr() in storage_dtypes:
|
||||
if storage_dtype != storage_dtypes[storage.data_ptr()]:
|
||||
raise RuntimeError(
|
||||
'Cannot save multiple tensors or storages that '
|
||||
'view the same data as different types')
|
||||
else:
|
||||
storage_dtypes[storage.data_ptr()] = storage_dtype
|
||||
# If storage is allocated, ensure that any other saved storages
|
||||
# pointing to the same data all have the same dtype. If storage is
|
||||
# not allocated, don't perform this check
|
||||
if storage.data_ptr() != 0:
|
||||
if storage.data_ptr() in storage_dtypes:
|
||||
if storage_dtype != storage_dtypes[storage.data_ptr()]:
|
||||
raise RuntimeError(
|
||||
'Cannot save multiple tensors or storages that '
|
||||
'view the same data as different types')
|
||||
else:
|
||||
storage_dtypes[storage.data_ptr()] = storage_dtype
|
||||
|
||||
view_metadata: Optional[Tuple[str, int, int]]
|
||||
storage = cast(Storage, storage)
|
||||
@ -554,13 +558,17 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
|
||||
|
||||
storage = cast(Storage, storage)
|
||||
|
||||
if storage.data_ptr() in storage_dtypes:
|
||||
if storage_dtype != storage_dtypes[storage.data_ptr()]:
|
||||
raise RuntimeError(
|
||||
'Cannot save multiple tensors or storages that '
|
||||
'view the same data as different types')
|
||||
else:
|
||||
storage_dtypes[storage.data_ptr()] = storage_dtype
|
||||
# If storage is allocated, ensure that any other saved storages
|
||||
# pointing to the same data all have the same dtype. If storage is
|
||||
# not allocated, don't perform this check
|
||||
if storage.data_ptr() != 0:
|
||||
if storage.data_ptr() in storage_dtypes:
|
||||
if storage_dtype != storage_dtypes[storage.data_ptr()]:
|
||||
raise RuntimeError(
|
||||
'Cannot save multiple tensors or storages that '
|
||||
'view the same data as different types')
|
||||
else:
|
||||
storage_dtypes[storage.data_ptr()] = storage_dtype
|
||||
|
||||
storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
|
||||
location = location_tag(storage)
|
||||
|
Reference in New Issue
Block a user