Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Allow remapping storages at load time and serialize data in little endian order
@@ -12,6 +12,7 @@ else:
 import pickle
 
 import torch
+from ._utils import _import_dotted_name
 
 DEFAULT_PROTOCOL = 2
 
@@ -36,8 +37,78 @@ def mkdtemp():
         shutil.rmtree(path)
 
 
+_package_registry = []
+
+
+def register_package(priority, tagger, deserializer):
+    queue_elem = (priority, tagger, deserializer)
+    _package_registry.append(queue_elem)
+    _package_registry.sort()
+
+
+def _cpu_tag(obj):
+    if type(obj).__module__ == 'torch':
+        return 'cpu'
+
+
+def _cuda_tag(obj):
+    if type(obj).__module__ == 'torch.cuda':
+        return 'cuda:' + str(obj.get_device())
+
+
+def _cpu_deserialize(obj, location):
+    if location == 'cpu':
+        return obj
+
+
+def _cuda_deserialize(obj, location):
+    if location.startswith('cuda'):
+        device_id = max(int(location[5:]), 0)
+        return obj.cuda(device_id)
+
+
+register_package(10, _cpu_tag, _cpu_deserialize)
+register_package(20, _cuda_tag, _cuda_deserialize)
+
+
+def location_tag(storage):
+    for _, tagger, _ in _package_registry:
+        location = tagger(storage)
+        if location:
+            return location
+    raise RuntimeError("don't know how to determine data location of " +
+                       torch.typename(storage))
+
+
+def default_restore_location(storage, location):
+    for _, _, fn in _package_registry:
+        result = fn(storage, location)
+        if result is not None:
+            return result
+    raise RuntimeError("don't know how to restore data location of " +
+                       torch.typename(storage) + " (tagged with " + location + ")")
+
+
+def normalize_storage_type(storage_type):
+    return getattr(torch, storage_type.__name__)
+
+
+def storage_to_tensor_type(storage):
+    storage_type = type(storage)
+    module = _import_dotted_name(storage_type.__module__)
+    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+
+
 # TODO: choose pickle protocol
 def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
+    """Saves an object to a disk file.
+
+    Args:
+        obj: saved object
+        f: a file-like object (has to implement fileno that returns a file descriptor)
+        pickle_module: module used for pickling metadata and objects
+        pickle_protocol: can be specified to override the default protocol
+    """
     serialized_tensors = {}
     serialized_storages = {}
 
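The registry added above is a sorted list scanned in priority order: location_tag asks each tagger until one returns a tag, and default_restore_location asks each deserializer until one returns a storage. A minimal sketch of how a user extension could hook in via register_package; the my_ext module and the 'mydevice' tag are hypothetical, not part of this commit:

    def _mydevice_tag(obj):
        # Tag only storages this extension recognizes; returning None lets
        # the remaining taggers in the registry try.
        if type(obj).__module__ == 'my_ext':
            return 'mydevice'

    def _mydevice_deserialize(obj, location):
        # Claim only our own tag; any other location falls through to the
        # next deserializer in priority order.
        if location == 'mydevice':
            return obj  # a real extension would move the storage to its device here

    # Priority 30 sorts after the built-in cpu (10) and cuda (20) handlers.
    register_package(30, _mydevice_tag, _mydevice_deserialize)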
@@ -60,7 +131,8 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
             else:
                 storage_id = None
 
-            pickle_module.dump((key, type(tensor), storage_id), f, protocol=pickle_protocol)
+            pickle_module.dump((key, storage_id, type(tensor)), f,
+                               protocol=pickle_protocol)
             f.flush()
             tensor._write_metadata(f)
 
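The reordered per-tensor record now matches the order in which the reader unpacks it. A hedged round-trip sketch using a plain pickle stream; the key, storage id, and type string are illustrative values only:

    import io
    import pickle

    buf = io.BytesIO()
    # what save() now writes for each tensor (before the binary metadata):
    pickle.dump((1, 42, 'torch.FloatTensor'), buf, protocol=2)
    buf.seek(0)
    # what load() unpacks, in the same order:
    key, storage_id, original_tensor_type = pickle.load(buf)
    assert (key, storage_id, original_tensor_type) == (1, 42, 'torch.FloatTensor')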
@@ -80,7 +152,10 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
 
         pickle_module.dump(len(serialized_storages), f, protocol=pickle_protocol)
         for key, storage in serialized_storages.items():
-            pickle_module.dump((key, type(storage)), f, protocol=pickle_protocol)
+            location = location_tag(storage)
+            storage_type = normalize_storage_type(type(storage))
+            pickle_module.dump((key, location, storage_type), f,
+                               protocol=pickle_protocol)
             f.flush()
             storage._write_file(f)
 
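With this change each storage header carries a location tag and a normalized (host) storage type. A hedged illustration of the values involved, assuming the helpers added above are in scope and a host FloatStorage built with the Torch API of this era:

    s = torch.FloatTensor(10).storage()
    location_tag(s)                    # 'cpu' (resolved via _cpu_tag)
    normalize_storage_type(type(s))    # torch.FloatStorage, looked up on the torch module
    # save() then pickles (key, 'cpu', torch.FloatStorage) and streams the
    # raw bytes with storage._write_file(f).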
@@ -110,26 +185,58 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
         _add_to_tar(save_storages, tar, 'storages')
 
 
-def load(f, pickle_module=pickle):
+def load(f, map_location=None, pickle_module=pickle):
+    """Loads an object saved with torch.save from a disk file.
+
+    torch.load can dynamically remap storages to be loaded on a different device
+    using the map_location argument. If it's a callable, it will be called with
+    two arguments: storage and location tag. It's expected to either return a
+    storage that's been moved to a different location, or None (and the location
+    will be resolved using the default method). If this argument is a dict it's
+    expected to be a mapping from location tags used in a file, to location
+    tags of the current system.
+
+    By default the location tags are 'cpu' for host tensors and 'cuda:device_id'
+    (e.g. 'cuda:2') for cuda tensors. User extensions can register their own
+    tagging and deserialization methods using register_package.
+
+    Args:
+        f: a file-like object (has to implement fileno that returns a file descriptor)
+        map_location: a function or a dict specifying how to remap storage locations
+        pickle_module: module used for unpickling metadata and objects (has to match
+            the pickle_module used to serialize file)
+    """
     deserialized_objects = {}
 
+    if map_location is None:
+        restore_location = default_restore_location
+    elif isinstance(map_location, dict):
+        def restore_location(storage, location):
+            location = map_location.get(location, location)
+            return default_restore_location(storage, location)
+    else:
+        def restore_location(storage, location):
+            result = map_location(storage, location)
+            if not result:
+                result = default_restore_location(storage, location)
+            return result
+
     def persistent_load(saved_id):
         return deserialized_objects[int(saved_id)]
 
     with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
             mkdtemp() as tmpdir:
 
-        def extract(f, init):
+        tar.extract('storages', path=tmpdir)
+        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
             num_storages = pickle_module.load(f)
             for i in range(num_storages):
                 args = pickle_module.load(f)
-                key, args = args[0], args[1:]
-                obj = init(*args)
+                key, location, storage_type = args
+                obj = storage_type._new_with_file(f)
+                obj = restore_location(obj, location)
                 deserialized_objects[key] = obj
 
-        tar.extract('storages', path=tmpdir)
-        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
-            extract(f, lambda storage_type: storage_type._new_with_file(f))
             storage_views = pickle_module.load(f)
             for target_cdata, root_cdata, offset, size in storage_views:
                 root = deserialized_objects[root_cdata]
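Hedged usage sketches for the new map_location parameter, assuming f is an open file produced by torch.save:

    # Dict form: storages tagged 'cuda:1' in the file are restored on 'cuda:0';
    # unlisted tags fall through to default_restore_location.
    obj = torch.load(f, map_location={'cuda:1': 'cuda:0'})

    # Callable form: returning the storage unchanged keeps everything on the
    # host; returning None would defer to the default resolution instead.
    obj = torch.load(f, map_location=lambda storage, location: storage)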
@@ -137,10 +244,17 @@ def load(f, pickle_module=pickle):
 
         tar.extract('tensors', path=tmpdir)
         with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
-            def deserialize_tensor(tensor_type, storage_id):
+            num_tensors = pickle_module.load(f)
+            for i in range(num_tensors):
+                args = pickle_module.load(f)
+                key, storage_id, original_tensor_type = args
                 storage = deserialized_objects.get(storage_id, None)
-                return tensor_type._new_with_metadata_file(f, storage)
-            extract(f, deserialize_tensor)
+                if storage:
+                    tensor_type = storage_to_tensor_type(storage)
+                    tensor = tensor_type._new_with_metadata_file(f, storage)
+                else:
+                    tensor = original_tensor_type._new_with_metadata_file(f, storage)
+                deserialized_objects[key] = tensor
 
         pickle_file = tar.extractfile('pickle')
         unpickler = pickle_module.Unpickler(pickle_file)
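Why the loader re-derives the tensor type from the restored storage: a tensor saved as torch.cuda.FloatTensor whose storage was remapped to the host must come back as a CPU tensor. A hedged sketch using the storage_to_tensor_type helper added earlier in this diff:

    cpu_storage = torch.FloatTensor(4).storage()     # a host FloatStorage
    storage_to_tensor_type(cpu_storage)              # torch.FloatTensor
    # Had the storage been restored on a GPU, the same lookup would resolve
    # to torch.cuda.FloatTensor, keeping tensor and storage types consistent.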