Fix nn serialization errors
@@ -54,9 +54,13 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
     pickle_module.dump(len(serialized_tensors), f, protocol=pickle_protocol)
     for key, tensor in serialized_tensors.items():
         storage = tensor.storage()
-        serialized_storages[storage._cdata] = storage
-        pickle_module.dump((key, type(tensor), storage._cdata), f, protocol=pickle_protocol)
+        if storage is not None:
+            storage_id = storage._cdata
+            serialized_storages[storage_id] = storage
+        else:
+            storage_id = None
+
+        pickle_module.dump((key, type(tensor), storage_id), f, protocol=pickle_protocol)
         f.flush()
         tensor._write_metadata(f)
 
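The save-side change amounts to recording None as the storage id when a tensor has no backing storage, rather than dereferencing storage._cdata unconditionally. A minimal sketch of that pattern, using plain pickle with a hypothetical write_records helper and id() standing in for storage._cdata (not the actual torch code):

import io
import pickle

def write_records(tensors, f, serialized_storages):
    # tensors: dict of key -> (type_name, storage-or-None), a stand-in for
    # the serialized_tensors dict in save().
    pickle.dump(len(tensors), f)
    for key, (type_name, storage) in tensors.items():
        if storage is not None:
            storage_id = id(storage)            # stand-in for storage._cdata
            serialized_storages[storage_id] = storage
        else:
            storage_id = None                   # storage-less tensor: record None
        pickle.dump((key, type_name, storage_id), f)

# Usage: an entry with data and one without any storage both serialize.
buf = io.BytesIO()
storages = {}
write_records({'weight': ('FloatTensor', [1.0, 2.0]),
               'empty': ('FloatTensor', None)}, buf, storages)
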
@@ -112,7 +116,7 @@ def load(f, pickle_module=pickle):
 
     extract('storages', lambda f, storage_type: storage_type._new_with_file(f))
     extract('tensors', lambda f, tensor_type, storage_id: \
-        tensor_type._new_with_metadata_file(f, deserialized_objects[storage_id]))
+        tensor_type._new_with_metadata_file(f, deserialized_objects.get(storage_id, None)))
 
     pickle_file = tar.extractfile('pickle')
     unpickler = pickle_module.Unpickler(pickle_file)
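The load side mirrors that: looking the storage up with dict.get lets a record whose storage id is None (or missing) come back with a None storage instead of raising KeyError. A minimal sketch of that lookup, with a hypothetical resolve_storage helper (not the actual torch code):

deserialized_objects = {101: '<storage 101>'}

def resolve_storage(storage_id):
    # .get(...) returns None for a missing or None id instead of raising
    # KeyError, matching deserialized_objects.get(storage_id, None) above.
    return deserialized_objects.get(storage_id, None)

assert resolve_storage(101) == '<storage 101>'
assert resolve_storage(None) is None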