mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Merge torch.cuda._UntypedStorage into torch._UntypedStorage (#75459)
Fixes #74933 Pull Request resolved: https://github.com/pytorch/pytorch/pull/75459 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
ac1837ddd3
commit
aea6e2c396
@ -880,19 +880,21 @@ class PackageExporter:
|
||||
if isinstance(obj, torch.storage._TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# remove this case
|
||||
storage = obj._storage
|
||||
untyped_storage = obj._storage
|
||||
storage_type_str = obj.pickle_storage_type()
|
||||
storage_type = getattr(torch, storage_type_str)
|
||||
dtype = obj.dtype
|
||||
storage_numel = obj.size()
|
||||
|
||||
else:
|
||||
storage = obj
|
||||
elif isinstance(obj, torch._UntypedStorage):
|
||||
untyped_storage = obj
|
||||
storage_type = normalize_storage_type(type(storage))
|
||||
dtype = torch.uint8
|
||||
storage_numel = storage.nbytes()
|
||||
else:
|
||||
raise RuntimeError(f'storage type not recognized: {type(obj)}')
|
||||
|
||||
storage = cast(Storage, storage)
|
||||
storage: Storage = cast(Storage, untyped_storage)
|
||||
location = location_tag(storage)
|
||||
|
||||
# serialize storage if not already written
|
||||
|
||||
Reference in New Issue
Block a user