Fix nn serialization errors

This commit is contained in:
Adam Paszke
2016-09-15 18:46:47 -07:00
parent 95d545e75b
commit d1fda539b7
9 changed files with 71 additions and 26 deletions

View File

@ -54,9 +54,13 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
pickle_module.dump(len(serialized_tensors), f, protocol=pickle_protocol)
for key, tensor in serialized_tensors.items():
storage = tensor.storage()
serialized_storages[storage._cdata] = storage
if storage is not None:
storage_id = storage._cdata
serialized_storages[storage_id] = storage
else:
storage_id = None
pickle_module.dump((key, type(tensor), storage._cdata), f, protocol=pickle_protocol)
pickle_module.dump((key, type(tensor), storage_id), f, protocol=pickle_protocol)
f.flush()
tensor._write_metadata(f)
@ -112,7 +116,7 @@ def load(f, pickle_module=pickle):
extract('storages', lambda f, storage_type: storage_type._new_with_file(f))
extract('tensors', lambda f, tensor_type, storage_id: \
tensor_type._new_with_metadata_file(f, deserialized_objects[storage_id]))
tensor_type._new_with_metadata_file(f, deserialized_objects.get(storage_id, None)))
pickle_file = tar.extractfile('pickle')
unpickler = pickle_module.Unpickler(pickle_file)