import difflib
import inspect
import os
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

DEFAULT_PROTOCOL = 2

LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size


class SourceChangeWarning(Warning):
    pass


def _add_to_tar(fn, tar_file, name):
    tmp_file = tempfile.NamedTemporaryFile(delete=False)
    fn(tmp_file)
    tmp_file.close()

    tar_file.add(tmp_file.name, arcname=name)
    if os.path.isfile(tmp_file.name):
        os.remove(tmp_file.name)


@contextmanager
def mkdtemp():
    path = tempfile.mkdtemp()
    yield path
    shutil.rmtree(path)


_package_registry = []


def register_package(priority, tagger, deserializer):
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


def _cpu_tag(obj):
    if type(obj).__module__ == 'torch':
        return 'cpu'


def _cuda_tag(obj):
    if type(obj).__module__ == 'torch.cuda':
        return 'cuda:' + str(obj.get_device())


def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj


def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
        device_id = max(int(location[5:]), 0)
        return obj.cuda(device_id)


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
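# Illustrative sketch (not part of the original module): a user extension could
# register its own location tag and deserializer through register_package(), as
# the load() docstring below mentions. The 'my_accel' tag and both helper
# functions here are hypothetical.
#
#     def _my_accel_tag(obj):
#         if type(obj).__module__ == 'my_accel':
#             return 'my_accel'
#
#     def _my_accel_deserialize(obj, location):
#         if location == 'my_accel':
#             return obj.my_accel()
#
#     register_package(30, _my_accel_tag, _my_accel_deserialize)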


def location_tag(storage):
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of " +
                       torch.typename(storage))


def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of " +
                       torch.typename(storage) + " (tagged with " + location + ")")


def normalize_storage_type(storage_type):
    return getattr(torch, storage_type.__name__)


def storage_to_tensor_type(storage):
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))


def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to a disk file.

    Args:
        obj: saved object
        f: a file-like object (has to implement fileno that returns a file descriptor)
            or a string containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    new_fd = False
    if isinstance(f, str):
        new_fd = True
        f = open(f, "wb")
    try:
        return _save(obj, f, pickle_module, pickle_protocol)
    finally:
        if new_fd:
            f.close()
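# Usage sketch (illustrative, not part of the original module); the file name
# is an arbitrary example:
#
#     x = torch.rand(3, 3)
#     torch.save(x, 'tensor.pt')              # save to a path
#     with open('tensor.pt', 'wb') as fh:     # or to an already-open file
#         torch.save(x, fh)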


def _save(obj, f, pickle_module, pickle_protocol):
    import torch.nn as nn
    serialized_tensors = {}
    serialized_storages = {}
    serialized_container_types = {}

    def persistent_id(obj):
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except (TypeError, IOError):
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return (obj, source_file, source)
        if torch.is_tensor(obj):
            serialized_tensors[obj._cdata] = obj
            return str(obj._cdata)
        elif torch.is_storage(obj):
            serialized_storages[obj._cdata] = obj
            return str(obj._cdata)
        return None

    def save_tensors(f):
        pickle_module.dump(len(serialized_tensors), f, protocol=pickle_protocol)
        for key, tensor in serialized_tensors.items():
            storage = tensor.storage()
            if storage is not None:
                storage_id = storage._cdata
                serialized_storages[storage_id] = storage
            else:
                storage_id = None

            pickle_module.dump((key, storage_id, type(tensor)), f,
                               protocol=pickle_protocol)
            f.flush()
            tensor._write_metadata(f)

    def save_storages(f):
        storage_views = []
        storage_views_roots = {}

        for key, storage in serialized_storages.items():
            root, offset = storage._root_storage()
            if root is not storage:
                storage_views_roots[root._cdata] = root
                storage_views.append((storage._cdata, root._cdata, offset,
                                      storage.size()))
        for view_info in storage_views:
            del serialized_storages[view_info[0]]
        serialized_storages.update(storage_views_roots)

        pickle_module.dump(len(serialized_storages), f, protocol=pickle_protocol)
        for key, storage in serialized_storages.items():
            location = location_tag(storage)
            storage_type = normalize_storage_type(type(storage))
            pickle_module.dump((key, location, storage_type), f,
                               protocol=pickle_protocol)
            f.flush()
            storage._write_file(f)

        pickle_module.dump(storage_views, f, protocol=pickle_protocol)

    def pickle_objects(f):
        pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
        pickler.persistent_id = persistent_id
        pickler.dump(obj)

    def save_sys_info(f):
        sys_info = dict(
            protocol_version=1000,
            little_endian=sys.byteorder == 'little',
            type_sizes=dict(
                short=SHORT_SIZE,
                int=INT_SIZE,
                long=LONG_SIZE,
            ),
        )
        pickle_module.dump(sys_info, f, protocol=pickle_protocol)

    with closing(tarfile.open(fileobj=f, mode='w:', format=tarfile.PAX_FORMAT)) as tar:
        _add_to_tar(save_sys_info, tar, 'sys_info')
        _add_to_tar(pickle_objects, tar, 'pickle')
        _add_to_tar(save_tensors, tar, 'tensors')
        _add_to_tar(save_storages, tar, 'storages')
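# Archive layout produced by _save (summary of the code above): an uncompressed
# PAX-format tar with four members, written in this order:
#
#     sys_info  -- pickled dict: protocol_version, endianness, C type sizes
#     pickle    -- the pickled object graph, with tensors and storages replaced
#                  by persistent ids
#     tensors   -- a count, then per tensor a (key, storage_id, type) record
#                  followed by the tensor's metadata
#     storages  -- a count, then per storage a (key, location, type) record
#                  followed by the raw storage data, and finally the list of
#                  storage views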


def load(f, map_location=None, pickle_module=pickle):
    """Loads an object saved with torch.save from a disk file.

    torch.load can dynamically remap storages to be loaded on a different device
    using the map_location argument. If it's a callable, it will be called with
    two arguments: storage and location tag. It's expected to either return a
    storage that's been moved to a different location, or None (and the location
    will be resolved using the default method). If this argument is a dict, it's
    expected to be a mapping from location tags used in a file to location
    tags of the current system.

    By default the location tags are 'cpu' for host tensors and 'cuda:device_id'
    (e.g. 'cuda:2') for cuda tensors. User extensions can register their own
    tagging and deserialization methods using register_package.

    Args:
        f: a file-like object (has to implement fileno that returns a file descriptor)
            or a string containing a file name
        map_location: a function or a dict specifying how to remap storage locations
        pickle_module: module used for unpickling metadata and objects (has to match
            the pickle_module used to serialize file)
    """
    new_fd = False
    if isinstance(f, str):
        new_fd = True
        f = open(f, 'rb')
    try:
        return _load(f, map_location, pickle_module)
    finally:
        if new_fd:
            f.close()
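# Usage sketch (illustrative, not part of the original module); file names and
# device tags are arbitrary examples:
#
#     obj = torch.load('tensor.pt')                                     # default restore
#     obj = torch.load('tensor.pt', map_location={'cuda:1': 'cuda:0'})  # dict remap
#     obj = torch.load('tensor.pt', map_location=lambda storage, loc: storage)
#
# The lambda form keeps every storage on the CPU: it returns the freshly read
# (CPU) storage unchanged, so the default 'cuda:*' restore is skipped.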


def _load(f, map_location, pickle_module):
    deserialized_objects = {}

    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if not result:
                result = default_restore_location(storage, location)
            return result

    def _check_container_source(container_type, source_file, original_source):
        current_source = inspect.getsource(container_type)
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(
                    current_source.split('\n'),
                    original_source.split('\n'),
                    source_file,
                    source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = ("source code of class '{}' has changed. {}"
                   .format(torch.typename(container_type), msg))
            warnings.warn(msg, SourceChangeWarning)
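    # Note (summary, not part of the original module): when a container class'
    # source has changed since saving, a SourceChangeWarning is emitted. Setting
    # `torch.nn.Module.dump_patches = True` before loading makes the check above
    # write a `<ClassName>.patch` file that `patch -p0` can apply to revert the
    # local changes.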

    def persistent_load(saved_id):
        if isinstance(saved_id, tuple):
            # Ignore containers that don't have any sources saved
            if all(saved_id[1:]):
                _check_container_source(*saved_id)
            return saved_id[0]
        return deserialized_objects[int(saved_id)]

    with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
            mkdtemp() as tmpdir:

        tar.extract('storages', path=tmpdir)
        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
            num_storages = pickle_module.load(f)
            for i in range(num_storages):
                args = pickle_module.load(f)
                key, location, storage_type = args
                obj = storage_type._new_with_file(f)
                obj = restore_location(obj, location)
                deserialized_objects[key] = obj

            storage_views = pickle_module.load(f)
            for target_cdata, root_cdata, offset, size in storage_views:
                root = deserialized_objects[root_cdata]
                deserialized_objects[target_cdata] = root[offset:offset + size]

        tar.extract('tensors', path=tmpdir)
        with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
            num_tensors = pickle_module.load(f)
            for i in range(num_tensors):
                args = pickle_module.load(f)
                key, storage_id, original_tensor_type = args
                storage = deserialized_objects[storage_id]
                tensor_type = storage_to_tensor_type(storage)
                tensor = tensor_type._new_with_metadata_file(f, storage)
                deserialized_objects[key] = tensor

        pickle_file = tar.extractfile('pickle')
        unpickler = pickle_module.Unpickler(pickle_file)
        unpickler.persistent_load = persistent_load
        result = unpickler.load()
        return result