Clean up file opening for serialization (#29221)

Summary:
Stacked PRs
 * https://github.com/pytorch/pytorch/issues/29232 - Add zipfile serialization
 * https://github.com/pytorch/pytorch/issues/29228 - Expose miniz to Python
 * **https://github.com/pytorch/pytorch/issues/29221 - Clean up file opening for serialization**

This is a small refactor to get things started for zipfile-based serialization
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29221

Differential Revision: D18330932

Pulled By: driazati

fbshipit-source-id: ce91542faf987ae5aa6dfd322e633a0c7335e678
This commit is contained in:
Your Name
2019-11-06 18:40:10 -08:00
committed by Facebook Github Bot
parent ae12630508
commit fff4f16e45
2 changed files with 47 additions and 37 deletions

View File

@ -171,22 +171,40 @@ def storage_to_tensor_type(storage):
return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
def _with_file_like(f, mode, body):
"""
Executes a body function with a file object for f, opening
it in 'mode' if it is a string filename.
"""
new_fd = False
if isinstance(f, str) or \
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
new_fd = True
f = open(f, mode)
try:
return body(f)
finally:
if new_fd:
f.close()
class _open_file_like(object):
def __init__(self, name_or_buffer, mode):
self.is_path = _open_file_like.is_path(name_or_buffer)
if self.is_path:
self.fname = name_or_buffer
else:
self.buffer = name_or_buffer
self.mode = mode
def open_path(self):
return open(self.fname, self.mode)
def open_handle(self):
return self.buffer
@staticmethod
def is_path(name_or_buffer):
return isinstance(name_or_buffer, str) or \
(sys.version_info[0] == 2 and isinstance(name_or_buffer, unicode)) or \
(sys.version_info[0] == 3 and isinstance(name_or_buffer, pathlib.Path))
def __enter__(self):
if self.is_path:
file_like = self.open_path()
else:
file_like = self.open_handle()
self.file_like = file_like
return self.file_like
def __exit__(self, *args):
if self.is_path:
# Only close the `file_like` if it was opened by this object
self.file_like.close()
def _is_compressed_file(f):
@ -258,7 +276,8 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
>>> buffer = io.BytesIO()
>>> torch.save(x, buffer)
"""
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
with _open_file_like(f, 'wb') as opened_file:
_save(obj, opened_file, pickle_module, pickle_protocol)
def _save(obj, f, pickle_module, pickle_protocol):
@ -414,21 +433,11 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
# Load a module with 'ascii' encoding for unpickling
>>> torch.load('module.pt', encoding='ascii')
"""
new_fd = False
if isinstance(f, str) or \
(sys.version_info[0] == 2 and isinstance(f, unicode)):
new_fd = True
f = open(f, 'rb')
elif (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
new_fd = True
f = f.open('rb')
try:
if sys.version_info >= (3, 0) and 'encoding' not in pickle_load_args.keys():
pickle_load_args['encoding'] = 'utf-8'
return _load(f, map_location, pickle_module, **pickle_load_args)
finally:
if new_fd:
f.close()
if sys.version_info >= (3, 0) and 'encoding' not in pickle_load_args.keys():
pickle_load_args['encoding'] = 'utf-8'
with _open_file_like(f, 'rb') as opened_file:
return _load(opened_file, map_location, pickle_module, **pickle_load_args)
# Register pickling support for layout instances such as