Allow remapping storages at load time and serialize data in little endian order

Adam Paszke
2016-10-04 11:33:00 -07:00
parent 53c65ddc6a
commit 0c9670ddf0
7 changed files with 325 additions and 15 deletions

torch/serialization.py

@@ -12,6 +12,7 @@ else:
 import pickle
 import torch
+from ._utils import _import_dotted_name
 
 DEFAULT_PROTOCOL = 2
@@ -36,8 +37,78 @@ def mkdtemp():
     shutil.rmtree(path)
 
+
+_package_registry = []
+
+
+def register_package(priority, tagger, deserializer):
+    queue_elem = (priority, tagger, deserializer)
+    _package_registry.append(queue_elem)
+    _package_registry.sort()
+
+
+def _cpu_tag(obj):
+    if type(obj).__module__ == 'torch':
+        return 'cpu'
+
+
+def _cuda_tag(obj):
+    if type(obj).__module__ == 'torch.cuda':
+        return 'cuda:' + str(obj.get_device())
+
+
+def _cpu_deserialize(obj, location):
+    if location == 'cpu':
+        return obj
+
+
+def _cuda_deserialize(obj, location):
+    if location.startswith('cuda'):
+        device_id = max(int(location[5:]), 0)
+        return obj.cuda(device_id)
+
+
+register_package(10, _cpu_tag, _cpu_deserialize)
+register_package(20, _cuda_tag, _cuda_deserialize)
+
+
+def location_tag(storage):
+    for _, tagger, _ in _package_registry:
+        location = tagger(storage)
+        if location:
+            return location
+    raise RuntimeError("don't know how to determine data location of " +
+                       torch.typename(storage))
+
+
+def default_restore_location(storage, location):
+    for _, _, fn in _package_registry:
+        result = fn(storage, location)
+        if result is not None:
+            return result
+    raise RuntimeError("don't know how to restore data location of " +
+                       torch.typename(storage) + " (tagged with " + location + ")")
+
+
+def normalize_storage_type(storage_type):
+    return getattr(torch, storage_type.__name__)
+
+
+def storage_to_tensor_type(storage):
+    storage_type = type(storage)
+    module = _import_dotted_name(storage_type.__module__)
+    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+
+
 # TODO: choose pickle protocol
 def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
     """Saves an object to a disk file.
 
     Args:
         obj: saved object
         f: a file-like object (has to implement fileno that returns a file descriptor)
         pickle_module: module used for pickling metadata and objects
         pickle_protocol: can be specified to override the default protocol
     """
     serialized_tensors = {}
     serialized_storages = {}
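The three hooks added above form a simple priority registry: location_tag asks each registered tagger (lowest priority first) to name a storage's location, and default_restore_location asks each deserializer to claim a tag, so user extensions can slot in alongside the built-in cpu/cuda handlers via register_package. A minimal, torch-free sketch of the same pattern; the fake storage class and the registered hooks are illustrative only:

    # Minimal sketch of the priority registry; runs without torch.
    _registry = []

    def register_package(priority, tagger, deserializer):
        _registry.append((priority, tagger, deserializer))
        _registry.sort()  # lower priority numbers are consulted first

    def location_tag(storage):
        for _, tagger, _ in _registry:
            tag = tagger(storage)
            if tag:
                return tag
        raise RuntimeError("don't know how to determine data location")

    def restore_location(storage, location):
        for _, _, fn in _registry:
            result = fn(storage, location)
            if result is not None:
                return result
        raise RuntimeError("don't know how to restore " + location)

    class FakeStorage(object):
        pass

    # A registration analogous to the built-in cpu hooks.
    register_package(10, lambda obj: 'cpu',
                     lambda obj, loc: obj if loc == 'cpu' else None)

    s = FakeStorage()
    assert location_tag(s) == 'cpu'
    assert restore_location(s, 'cpu') is s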
@@ -60,7 +131,8 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
             else:
                 storage_id = None
-            pickle_module.dump((key, type(tensor), storage_id), f, protocol=pickle_protocol)
+            pickle_module.dump((key, storage_id, type(tensor)), f,
+                               protocol=pickle_protocol)
             f.flush()
             tensor._write_metadata(f)
@@ -80,7 +152,10 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
         pickle_module.dump(len(serialized_storages), f, protocol=pickle_protocol)
         for key, storage in serialized_storages.items():
-            pickle_module.dump((key, type(storage)), f, protocol=pickle_protocol)
+            location = location_tag(storage)
+            storage_type = normalize_storage_type(type(storage))
+            pickle_module.dump((key, location, storage_type), f,
+                               protocol=pickle_protocol)
             f.flush()
             storage._write_file(f)
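With this change each storage record in the 'storages' tar member carries its location tag and a normalized type in a pickled header, followed by the raw data, which per this commit is written in little-endian order. A rough sketch of that header-then-payload layout using plain pickle; the key, tag, type name, and payload bytes are invented for illustration and are not the exact on-disk schema:

    import io
    import pickle

    buf = io.BytesIO()
    # Header: (key, location tag, normalized storage type name).
    pickle.dump((1, 'cuda:0', 'FloatStorage'), buf, protocol=2)
    # Raw little-endian storage bytes follow the header, e.g. float32 1.0:
    buf.write(b'\x00\x00\x80?')

    buf.seek(0)
    key, location, storage_type = pickle.load(buf)
    data = buf.read()
    assert (location, data) == ('cuda:0', b'\x00\x00\x80?')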
@@ -110,26 +185,58 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
         _add_to_tar(save_storages, tar, 'storages')
 
 
-def load(f, pickle_module=pickle):
+def load(f, map_location=None, pickle_module=pickle):
     """Loads an object saved with torch.save from a disk file.
 
+    torch.load can dynamically remap storages to be loaded on a different device
+    using the map_location argument. If it's a callable, it will be called with
+    two arguments: storage and location tag. It's expected to either return a
+    storage that's been moved to a different location, or None (and the location
+    will be resolved using the default method). If this argument is a dict, it's
+    expected to be a mapping from location tags used in a file to location
+    tags of the current system.
+
+    By default the location tags are 'cpu' for host tensors and 'cuda:device_id'
+    (e.g. 'cuda:2') for cuda tensors. User extensions can register their own
+    tagging and deserialization methods using register_package.
+
     Args:
         f: a file-like object (has to implement fileno that returns a file descriptor)
+        map_location: a function or a dict specifying how to remap storage locations
         pickle_module: module used for unpickling metadata and objects (has to match
            the pickle_module used to serialize file)
     """
     deserialized_objects = {}
 
+    if map_location is None:
+        restore_location = default_restore_location
+    elif isinstance(map_location, dict):
+        def restore_location(storage, location):
+            location = map_location.get(location, location)
+            return default_restore_location(storage, location)
+    else:
+        def restore_location(storage, location):
+            result = map_location(storage, location)
+            if not result:
+                result = default_restore_location(storage, location)
+            return result
+
     def persistent_load(saved_id):
         return deserialized_objects[int(saved_id)]
 
     with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
             mkdtemp() as tmpdir:
 
-        def extract(f, init):
+        tar.extract('storages', path=tmpdir)
+        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
             num_storages = pickle_module.load(f)
             for i in range(num_storages):
                 args = pickle_module.load(f)
-                key, args = args[0], args[1:]
-                obj = init(*args)
+                key, location, storage_type = args
+                obj = storage_type._new_with_file(f)
+                obj = restore_location(obj, location)
                 deserialized_objects[key] = obj
 
-        tar.extract('storages', path=tmpdir)
-        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
-            extract(f, lambda storage_type: storage_type._new_with_file(f))
             storage_views = pickle_module.load(f)
             for target_cdata, root_cdata, offset, size in storage_views:
                 root = deserialized_objects[root_cdata]
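In practice map_location is what makes a GPU checkpoint loadable on different hardware. Both forms described in the docstring, sketched below; the file name is a placeholder:

    import torch

    # Dict form: storages tagged 'cuda:1' in the file are loaded onto 'cuda:0';
    # every other tag falls through to the default handling.
    with open('checkpoint.pt', 'rb') as f:
        obj = torch.load(f, map_location={'cuda:1': 'cuda:0'})

    # Callable form: keep every storage on the CPU it was deserialized onto by
    # returning it unchanged; returning None would defer to
    # default_restore_location instead.
    with open('checkpoint.pt', 'rb') as f:
        obj = torch.load(f, map_location=lambda storage, location: storage)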
@@ -137,10 +244,17 @@ def load(f, pickle_module=pickle):
         tar.extract('tensors', path=tmpdir)
         with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
-            def deserialize_tensor(tensor_type, storage_id):
+            num_tensors = pickle_module.load(f)
+            for i in range(num_tensors):
+                args = pickle_module.load(f)
+                key, storage_id, original_tensor_type = args
                 storage = deserialized_objects.get(storage_id, None)
-                return tensor_type._new_with_metadata_file(f, storage)
-            extract(f, deserialize_tensor)
+                if storage:
+                    tensor_type = storage_to_tensor_type(storage)
+                    tensor = tensor_type._new_with_metadata_file(f, storage)
+                else:
+                    tensor = original_tensor_type._new_with_metadata_file(f, storage)
+                deserialized_objects[key] = tensor
 
         pickle_file = tar.extractfile('pickle')
         unpickler = pickle_module.Unpickler(pickle_file)
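Re-deriving the tensor type from the restored storage is what makes remapping end-to-end: a tensor written as torch.cuda.FloatTensor whose storage was mapped to 'cpu' comes back as torch.FloatTensor, and original_tensor_type is only a fallback for tensors without a storage. The renaming done by storage_to_tensor_type is mechanical, as this sketch of the name mapping shows:

    # Sketch of the Storage -> Tensor renaming: the storage's module picks the
    # device, its class name picks the scalar type.
    def tensor_type_name(storage_module, storage_class_name):
        return storage_module + '.' + storage_class_name.replace('Storage', 'Tensor')

    assert tensor_type_name('torch', 'FloatStorage') == 'torch.FloatTensor'
    assert tensor_type_name('torch.cuda', 'DoubleStorage') == 'torch.cuda.DoubleTensor'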