mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Clean up file opening for serialization (#29221)
Summary: Stacked PRs * https://github.com/pytorch/pytorch/issues/29232 - Add zipfile serialization * https://github.com/pytorch/pytorch/issues/29228 - Expose miniz to Python * **https://github.com/pytorch/pytorch/issues/29221 - Clean up file opening for serialization** This is a small refactor to get things started for zipfile-based serialization Pull Request resolved: https://github.com/pytorch/pytorch/pull/29221 Differential Revision: D18330932 Pulled By: driazati fbshipit-source-id: ce91542faf987ae5aa6dfd322e633a0c7335e678
This commit is contained in:
committed by
Facebook Github Bot
parent
ae12630508
commit
fff4f16e45
@ -171,22 +171,40 @@ def storage_to_tensor_type(storage):
|
||||
return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
|
||||
|
||||
|
||||
def _with_file_like(f, mode, body):
|
||||
"""
|
||||
Executes a body function with a file object for f, opening
|
||||
it in 'mode' if it is a string filename.
|
||||
"""
|
||||
new_fd = False
|
||||
if isinstance(f, str) or \
|
||||
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
|
||||
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
|
||||
new_fd = True
|
||||
f = open(f, mode)
|
||||
try:
|
||||
return body(f)
|
||||
finally:
|
||||
if new_fd:
|
||||
f.close()
|
||||
class _open_file_like(object):
|
||||
def __init__(self, name_or_buffer, mode):
|
||||
self.is_path = _open_file_like.is_path(name_or_buffer)
|
||||
if self.is_path:
|
||||
self.fname = name_or_buffer
|
||||
else:
|
||||
self.buffer = name_or_buffer
|
||||
self.mode = mode
|
||||
|
||||
def open_path(self):
|
||||
return open(self.fname, self.mode)
|
||||
|
||||
def open_handle(self):
|
||||
return self.buffer
|
||||
|
||||
@staticmethod
|
||||
def is_path(name_or_buffer):
|
||||
return isinstance(name_or_buffer, str) or \
|
||||
(sys.version_info[0] == 2 and isinstance(name_or_buffer, unicode)) or \
|
||||
(sys.version_info[0] == 3 and isinstance(name_or_buffer, pathlib.Path))
|
||||
|
||||
def __enter__(self):
|
||||
if self.is_path:
|
||||
file_like = self.open_path()
|
||||
else:
|
||||
file_like = self.open_handle()
|
||||
|
||||
self.file_like = file_like
|
||||
return self.file_like
|
||||
|
||||
def __exit__(self, *args):
|
||||
if self.is_path:
|
||||
# Only close the `file_like` if it was opened by this object
|
||||
self.file_like.close()
|
||||
|
||||
|
||||
def _is_compressed_file(f):
|
||||
@ -258,7 +276,8 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
|
||||
>>> buffer = io.BytesIO()
|
||||
>>> torch.save(x, buffer)
|
||||
"""
|
||||
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
|
||||
with _open_file_like(f, 'wb') as opened_file:
|
||||
_save(obj, opened_file, pickle_module, pickle_protocol)
|
||||
|
||||
|
||||
def _save(obj, f, pickle_module, pickle_protocol):
|
||||
@ -414,21 +433,11 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
|
||||
# Load a module with 'ascii' encoding for unpickling
|
||||
>>> torch.load('module.pt', encoding='ascii')
|
||||
"""
|
||||
new_fd = False
|
||||
if isinstance(f, str) or \
|
||||
(sys.version_info[0] == 2 and isinstance(f, unicode)):
|
||||
new_fd = True
|
||||
f = open(f, 'rb')
|
||||
elif (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
|
||||
new_fd = True
|
||||
f = f.open('rb')
|
||||
try:
|
||||
if sys.version_info >= (3, 0) and 'encoding' not in pickle_load_args.keys():
|
||||
pickle_load_args['encoding'] = 'utf-8'
|
||||
return _load(f, map_location, pickle_module, **pickle_load_args)
|
||||
finally:
|
||||
if new_fd:
|
||||
f.close()
|
||||
if sys.version_info >= (3, 0) and 'encoding' not in pickle_load_args.keys():
|
||||
pickle_load_args['encoding'] = 'utf-8'
|
||||
|
||||
with _open_file_like(f, 'rb') as opened_file:
|
||||
return _load(opened_file, map_location, pickle_module, **pickle_load_args)
|
||||
|
||||
|
||||
# Register pickling support for layout instances such as
|
||||
|
Reference in New Issue
Block a user