Files
pytorch/torch/serialization.py

348 lines
12 KiB
Python

import difflib
import inspect
import os
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
# Pickle protocol used by save() unless the caller passes another one.
DEFAULT_PROTOCOL = 2

# Sizes in bytes of C integer types, recorded in the archive's ``sys_info``.
# NOTE(review): the ``=`` struct prefix selects *standard* sizes (not native
# ones), so these are fixed at long=4, int=4, short=2 on every platform.
LONG_SIZE = struct.calcsize('=l')
INT_SIZE = struct.calcsize('=i')
SHORT_SIZE = struct.calcsize('=h')
class SourceChangeWarning(Warning):
    """Warning emitted by load() when a saved container class's source differs
    from the source of the class that is currently importable."""
def _add_to_tar(fn, tar_file, name):
    """Add the bytes written by ``fn`` to ``tar_file`` as member ``name``.

    ``fn`` is called with a writable temporary file object. That file is then
    closed and added to the archive under ``name``. The temporary file is
    always removed afterwards, even if ``fn`` or the tar addition raises.
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        try:
            fn(tmp_file)
        finally:
            # Close (and thereby flush) before tarfile reads it back by name.
            tmp_file.close()
        tar_file.add(tmp_file.name, arcname=name)
    finally:
        if os.path.isfile(tmp_file.name):
            os.remove(tmp_file.name)
@contextmanager
def mkdtemp():
    """Context manager yielding the path of a fresh temporary directory.

    The directory and everything in it are removed on exit, including when
    the ``with`` body raises.
    """
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)
# Registered (priority, tagger, deserializer) triples, kept sorted by priority.
_package_registry = []


def register_package(priority, tagger, deserializer):
    """Register a tagger/deserializer pair used to save and restore storages.

    Args:
        priority: lower values are consulted first by location_tag() and
            default_restore_location().
        tagger: callable(storage) -> location tag string, or a falsy value if
            it does not handle the storage.
        deserializer: callable(storage, location) -> restored storage, or None
            if it does not handle the location.
    """
    _package_registry.append((priority, tagger, deserializer))
    # Sort on priority alone. Comparing whole tuples would compare the
    # callables when priorities tie, which raises TypeError on Python 3.
    # list.sort is stable, so equal priorities keep registration order.
    _package_registry.sort(key=lambda entry: entry[0])
def _cpu_tag(obj):
    """Return ``'cpu'`` when ``obj``'s class is defined in ``torch``, else None."""
    defined_in_torch = type(obj).__module__ == 'torch'
    return 'cpu' if defined_in_torch else None
def _cuda_tag(obj):
    """Return ``'cuda:<device>'`` for objects whose class is defined in
    ``torch.cuda`` (device taken from ``obj.get_device()``), else None."""
    if type(obj).__module__ != 'torch.cuda':
        return None
    device = obj.get_device()
    return 'cuda:' + str(device)
def _cpu_deserialize(obj, location):
    """Return ``obj`` unchanged for the ``'cpu'`` location, otherwise None."""
    return obj if location == 'cpu' else None
def _cuda_deserialize(obj, location):
    """Move ``obj`` to a CUDA device for ``'cuda...'`` locations.

    ``location`` is ``'cuda:<index>'``; negative indices are clamped to 0, and
    a bare ``'cuda'`` (or ``'cuda:'``) with no index also maps to device 0
    instead of failing on ``int('')``. Returns None for non-CUDA locations.
    """
    if location.startswith('cuda'):
        index_text = location[5:]
        device_id = max(int(index_text), 0) if index_text else 0
        return obj.cuda(device_id)
# Built-in backends: CPU (priority 10) is consulted before CUDA (priority 20).
for _builtin_package in ((10, _cpu_tag, _cpu_deserialize),
                         (20, _cuda_tag, _cuda_deserialize)):
    register_package(*_builtin_package)
del _builtin_package
def location_tag(storage):
    """Return the location tag for ``storage``.

    Registered taggers are tried in registry order and the first truthy tag
    is returned.

    Raises:
        RuntimeError: if no registered tagger produces a tag.
    """
    candidate_tags = (entry[1](storage) for entry in _package_registry)
    tag = next((candidate for candidate in candidate_tags if candidate), None)
    if tag is None:
        raise RuntimeError("don't know how to determine data location of " +
                           torch.typename(storage))
    return tag
def default_restore_location(storage, location):
    """Restore ``storage`` to ``location`` using the registered deserializers.

    Deserializers are tried in registry order; the first non-None result is
    returned.

    Raises:
        RuntimeError: if every deserializer returns None.
    """
    for entry in _package_registry:
        deserializer = entry[2]
        restored = deserializer(storage, location)
        if restored is None:
            continue
        return restored
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " + location + ")")
def normalize_storage_type(storage_type):
    """Return the attribute of the ``torch`` module named like ``storage_type``."""
    type_name = storage_type.__name__
    return getattr(torch, type_name)
def storage_to_tensor_type(storage):
    """Return the tensor class paired with ``storage``'s class.

    The tensor class is looked up in the module that defines the storage
    class, under the storage class name with ``Storage`` replaced by
    ``Tensor`` (e.g. ``FloatStorage`` -> ``FloatTensor``).
    """
    storage_cls = type(storage)
    owning_module = _import_dotted_name(storage_cls.__module__)
    tensor_name = storage_cls.__name__.replace('Storage', 'Tensor')
    return getattr(owning_module, tensor_name)
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to a disk file.

    Args:
        obj: saved object
        f: a file-like object (has to implement fileno that returns a file
            descriptor) or a string / path-like object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    # Treat str, Python 2 unicode and (Python 3.6+) os.PathLike objects such
    # as pathlib.Path as file names; anything else is used as a file object.
    new_fd = (isinstance(f, str) or
              (sys.version_info[0] == 2 and isinstance(f, unicode)) or  # noqa: F821
              (hasattr(os, 'PathLike') and isinstance(f, os.PathLike)))
    if new_fd:
        f = open(f, "wb")
    try:
        return _save(obj, f, pickle_module, pickle_protocol)
    finally:
        # Only close handles this function opened itself.
        if new_fd:
            f.close()
def _save(obj, f, pickle_module, pickle_protocol):
    """Serialize ``obj`` into ``f`` as an uncompressed PAX tar archive.

    Members are written in this order:

    * ``sys_info`` -- pickled dict with protocol version, endianness flag and
      integer type sizes;
    * ``pickle`` -- ``obj`` pickled with tensors/storages replaced by the
      persistent id ``str(x._cdata)`` and each nn.Container subclass (first
      occurrence) replaced by a ``(type, source_file, source)`` tuple;
    * ``tensors`` -- one record per tensor plus ``tensor._write_metadata``;
    * ``storages`` -- one record per root storage plus ``storage._write_file``,
      followed by the list of storage views.
    """
    import torch.nn as nn
    serialized_tensors = {}
    serialized_storages = {}
    serialized_container_types = {}

    def persistent_id(obj):
        # Container classes are saved together with their source so load()
        # can warn if the class definition changed since saving.
        if isinstance(obj, type) and issubclass(obj, nn.Container):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except (TypeError, IOError, OSError):
                # getsourcefile raises TypeError for built-in classes, and
                # getsource raises IOError/OSError when the source can't be
                # read (e.g. a class defined interactively). Saving the
                # source is optional, so warn instead of aborting the save.
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return (obj, source_file, source)
        if torch.is_tensor(obj):
            serialized_tensors[obj._cdata] = obj
            return str(obj._cdata)
        elif torch.is_storage(obj):
            serialized_storages[obj._cdata] = obj
            return str(obj._cdata)
        return None

    def save_tensors(f):
        pickle_module.dump(len(serialized_tensors), f, protocol=pickle_protocol)
        for key, tensor in serialized_tensors.items():
            storage = tensor.storage()
            if storage is not None:
                storage_id = storage._cdata
                # Register the backing storage so save_storages writes it.
                serialized_storages[storage_id] = storage
            else:
                storage_id = None

            pickle_module.dump((key, storage_id, type(tensor)), f,
                               protocol=pickle_protocol)
            f.flush()
            tensor._write_metadata(f)

    def save_storages(f):
        # Storages that are views into a larger root storage are not written
        # themselves; the root is written and the view is recorded as
        # (view_cdata, root_cdata, offset, size).
        storage_views = []
        storage_views_roots = {}
        for storage in serialized_storages.values():
            root, offset = storage._root_storage()
            if root is not storage:
                storage_views_roots[root._cdata] = root
                storage_views.append((storage._cdata, root._cdata, offset,
                                      storage.size()))
        for view_info in storage_views:
            del serialized_storages[view_info[0]]
        serialized_storages.update(storage_views_roots)

        pickle_module.dump(len(serialized_storages), f, protocol=pickle_protocol)
        for key, storage in serialized_storages.items():
            location = location_tag(storage)
            storage_type = normalize_storage_type(type(storage))
            pickle_module.dump((key, location, storage_type), f,
                               protocol=pickle_protocol)
            f.flush()
            storage._write_file(f)

        pickle_module.dump(storage_views, f, protocol=pickle_protocol)

    def pickle_objects(f):
        pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
        pickler.persistent_id = persistent_id
        pickler.dump(obj)

    def save_sys_info(f):
        sys_info = dict(
            protocol_version=1000,
            little_endian=sys.byteorder == 'little',
            type_sizes=dict(
                short=SHORT_SIZE,
                int=INT_SIZE,
                long=LONG_SIZE,
            ),
        )
        pickle_module.dump(sys_info, f, protocol=pickle_protocol)

    # Order matters: pickling fills serialized_tensors/serialized_storages via
    # persistent_id, and save_tensors adds tensor storages before
    # save_storages writes them.
    with closing(tarfile.open(fileobj=f, mode='w:', format=tarfile.PAX_FORMAT)) as tar:
        _add_to_tar(save_sys_info, tar, 'sys_info')
        _add_to_tar(pickle_objects, tar, 'pickle')
        _add_to_tar(save_tensors, tar, 'tensors')
        _add_to_tar(save_storages, tar, 'storages')
def load(f, map_location=None, pickle_module=pickle):
    """Loads an object saved with torch.save from a disk file.

    torch.load can dynamically remap storages to be loaded on a different device
    using the map_location argument. If it's a callable, it will be called with
    two arguments: storage and location tag. It's expected to either return a
    storage that's been moved to a different location, or None (and the location
    will be resolved using the default method). If this argument is a dict it's
    expected to be a mapping from location tags used in a file, to location
    tags of the current system.

    By default the location tags are 'cpu' for host tensors and 'cuda:device_id'
    (e.g. 'cuda:2') for cuda tensors. User extensions can register their own
    tagging and deserialization methods using register_package.

    Args:
        f: a file-like object (has to implement fileno that returns a file
            descriptor) or a string / path-like object containing a file name
        map_location: a function or a dict specifying how to remap storage locations
        pickle_module: module used for unpickling metadata and objects (has to match
            the pickle_module used to serialize file)

    Warning:
        Loading unpickles data with ``pickle_module``, which can execute
        arbitrary code. Only load files from trusted sources.
    """
    # Treat str, Python 2 unicode and (Python 3.6+) os.PathLike objects such
    # as pathlib.Path as file names; anything else is used as a file object.
    new_fd = (isinstance(f, str) or
              (sys.version_info[0] == 2 and isinstance(f, unicode)) or  # noqa: F821
              (hasattr(os, 'PathLike') and isinstance(f, os.PathLike)))
    if new_fd:
        f = open(f, 'rb')
    try:
        return _load(f, map_location, pickle_module)
    finally:
        # Only close handles this function opened itself.
        if new_fd:
            f.close()
def _load(f, map_location, pickle_module):
    """Deserialize an object written by _save from the tar stream ``f``.

    Storages are restored first (passing each through ``map_location``), then
    tensors are rebuilt on top of them, and finally the ``pickle`` member is
    unpickled with persistent ids resolved against the restored objects.
    """
    deserialized_objects = {}

    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            # Fall back only when the callable returns None, as documented.
            # A storage may be falsy (e.g. when empty) yet still be a valid
            # result, so a truthiness test would wrongly re-restore it.
            if result is None:
                result = default_restore_location(storage, location)
            return result

    def _check_container_source(container_type, source_file, original_source):
        try:
            current_source = inspect.getsource(container_type)
        except (TypeError, IOError, OSError):
            # The class's current source can't be read, so there is nothing
            # to compare against; warn instead of failing the whole load.
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be "
                          "checked for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                # Reverse diff: current -> original, so applying it reverts.
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as patch_file:
                        # file.seek() returns None on Python 2, so measure
                        # the size with tell() rather than seek's result.
                        patch_file.seek(0, 2)
                        file_size = patch_file.tell()
                        patch_file.seek(0)
                        if file_size == 0:
                            patch_file.write(lines)
                        elif file_size != len(lines) or patch_file.read() != lines:
                            # An existing, different patch file: don't clobber it.
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Container.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = ("source code of class '{}' has changed. {}"
                   .format(torch.typename(container_type), msg))
            warnings.warn(msg, SourceChangeWarning)

    def persistent_load(saved_id):
        if isinstance(saved_id, tuple):
            # Container class saved as (type, source_file, source).
            # Ignore containers that don't have any sources saved.
            if all(saved_id[1:]):
                _check_container_source(*saved_id)
            return saved_id[0]
        return deserialized_objects[int(saved_id)]

    with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
            mkdtemp() as tmpdir:

        tar.extract('storages', path=tmpdir)
        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as storages_file:
            num_storages = pickle_module.load(storages_file)
            for _ in range(num_storages):
                key, location, storage_type = pickle_module.load(storages_file)
                obj = storage_type._new_with_file(storages_file)
                obj = restore_location(obj, location)
                deserialized_objects[key] = obj

            # Views are rebuilt as slices of their (already restored) roots.
            storage_views = pickle_module.load(storages_file)
            for target_cdata, root_cdata, offset, size in storage_views:
                root = deserialized_objects[root_cdata]
                deserialized_objects[target_cdata] = root[offset:offset + size]

        tar.extract('tensors', path=tmpdir)
        with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as tensors_file:
            num_tensors = pickle_module.load(tensors_file)
            for _ in range(num_tensors):
                key, storage_id, _original_tensor_type = pickle_module.load(tensors_file)
                storage = deserialized_objects[storage_id]
                tensor_type = storage_to_tensor_type(storage)
                tensor = tensor_type._new_with_metadata_file(tensors_file, storage)
                deserialized_objects[key] = tensor

        pickle_file = tar.extractfile('pickle')
        unpickler = pickle_module.Unpickler(pickle_file)
        unpickler.persistent_load = persistent_load
        result = unpickler.load()
        return result