Merge torch.cuda._UntypedStorage into torch._UntypedStorage (#75459)

Fixes #74933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75459
Approved by: https://github.com/ezyang
This commit is contained in:
Kurt Mohler
2022-05-19 13:54:37 +00:00
committed by PyTorch MergeBot
parent ac1837ddd3
commit aea6e2c396
32 changed files with 357 additions and 565 deletions

View File

@ -880,19 +880,21 @@ class PackageExporter:
if isinstance(obj, torch.storage._TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage
untyped_storage = obj._storage
storage_type_str = obj.pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
dtype = obj.dtype
storage_numel = obj.size()
else:
storage = obj
elif isinstance(obj, torch._UntypedStorage):
untyped_storage = obj
storage_type = normalize_storage_type(type(storage))
dtype = torch.uint8
storage_numel = storage.nbytes()
else:
raise RuntimeError(f'storage type not recognized: {type(obj)}')
storage = cast(Storage, storage)
storage: Storage = cast(Storage, untyped_storage)
location = location_tag(storage)
# serialize storage if not already written