mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix some bugs with zipfile serialization (#32244)
Summary: Stacked PRs * #32958 - Make zip serialization the default * **#32244 - Fix some bugs with zipfile serialization** It includes the following changes: * Split up tests so that we can test both serialization methods * Loading something within a buffer doesn't work anymore, so those tests are only on the old serialization method (it's possible but introduces a big slowdown since it requires a linear scan of the entire zipfile to find the magic number at the end) * Call `readinto` on a buffer if possible instead of `read` + a copy * Disable CRC-32 checks on read (there was some issue where miniz said the CRC was wrong but `zipinfo` and `unzip` said the zip file was fine) ](https://our.intern.facebook.com/intern/diff/19418935/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/32244 Pulled By: driazati Reviewed By: eellison Differential Revision: D19418935 fbshipit-source-id: df140854f52ecd04236225417d625374fd99f573
This commit is contained in:
committed by
Facebook Github Bot
parent
ab75d64e6e
commit
74ce3a032c
@ -1442,6 +1442,8 @@ if (NOT INTERN_BUILD_MOBILE)
|
||||
ENDIF(HAVE_MALLOC_USABLE_SIZE)
|
||||
ENDIF(UNIX)
|
||||
|
||||
ADD_DEFINITIONS(-DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS)
|
||||
|
||||
# Is __thread supported?
|
||||
IF(NOT MSVC)
|
||||
CHECK_C_SOURCE_COMPILES("static __thread int x = 1; int main() { return x; }" C_HAS_THREAD)
|
||||
|
@ -4,6 +4,7 @@ import sys
|
||||
import random
|
||||
import string
|
||||
import unittest
|
||||
import io
|
||||
try:
|
||||
import unittest.mock as mock
|
||||
except ImportError:
|
||||
@ -5804,11 +5805,6 @@ class TestNN(NNTestCase):
|
||||
|
||||
@unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1")
|
||||
def test_RNN_dropout_state(self):
|
||||
import sys
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
for p in (0, 0.1234):
|
||||
for train in (True, False):
|
||||
for cuda in (True, False):
|
||||
@ -5829,8 +5825,10 @@ class TestNN(NNTestCase):
|
||||
output1, hy1 = rnn(input, hx)
|
||||
output2, hy2 = rnn(input, hx)
|
||||
|
||||
rnn_pickle = pickle.dumps(rnn)
|
||||
rnn2 = pickle.loads(rnn_pickle)
|
||||
buf = io.BytesIO()
|
||||
rnn_pickle = torch.save(rnn, buf)
|
||||
buf.seek(0)
|
||||
rnn2 = torch.load(buf)
|
||||
rnn2.flatten_parameters()
|
||||
output3, hy3 = rnn2(input, hx)
|
||||
|
||||
|
@ -71,7 +71,7 @@ class FilelikeMock(object):
|
||||
return name in self.calls
|
||||
|
||||
|
||||
class TestSerialization(TestCase):
|
||||
class SerializationMixin(object):
|
||||
def _test_serialization_data(self):
|
||||
a = [torch.randn(5, 5).float() for i in range(2)]
|
||||
b = [a[i % 2] for i in range(4)] # 0-3
|
||||
@ -141,26 +141,6 @@ class TestSerialization(TestCase):
|
||||
|
||||
test(io.BytesIO())
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
|
||||
def test_serialization_zipfile(self):
|
||||
data = self._test_serialization_data()
|
||||
|
||||
def test(name_or_buffer):
|
||||
torch.save(data, name_or_buffer, _use_new_zipfile_serialization=True)
|
||||
|
||||
if hasattr(name_or_buffer, 'seek'):
|
||||
name_or_buffer.seek(0)
|
||||
|
||||
result = torch.load(name_or_buffer)
|
||||
self.assertEqual(result, data)
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
test(f)
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
test(f.name)
|
||||
|
||||
test(io.BytesIO())
|
||||
|
||||
def test_serialization(self):
|
||||
# Test serialization with a real file
|
||||
b = self._test_serialization_data()
|
||||
@ -235,6 +215,7 @@ class TestSerialization(TestCase):
|
||||
self.assertFalse(torch.serialization._is_zipfile(f))
|
||||
self.assertEqual(torch.load(f.name), t)
|
||||
|
||||
@unittest.skipIf(not PY3, "gzip doesn't support os.seek(0, os.SEEK_END) on Python 2")
|
||||
def test_serialization_gzip(self):
|
||||
# Test serialization with gzip file
|
||||
b = self._test_serialization_data()
|
||||
@ -248,30 +229,6 @@ class TestSerialization(TestCase):
|
||||
c = torch.load(f)
|
||||
self._test_serialization_assert(b, c)
|
||||
|
||||
def test_serialization_offset(self):
|
||||
a = torch.randn(5, 5)
|
||||
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
|
||||
m = torch.nn.Conv2d(1, 1, (1, 3))
|
||||
i, j = 41, 43
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
pickle.dump(i, f)
|
||||
torch.save(a, f)
|
||||
pickle.dump(j, f)
|
||||
torch.save(b, f)
|
||||
torch.save(m, f)
|
||||
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
|
||||
f.seek(0)
|
||||
i_loaded = pickle.load(f)
|
||||
a_loaded = torch.load(f)
|
||||
j_loaded = pickle.load(f)
|
||||
b_loaded = torch.load(f)
|
||||
m_loaded = torch.load(f)
|
||||
self.assertTrue(torch.equal(a, a_loaded))
|
||||
self.assertTrue(torch.equal(b, b_loaded))
|
||||
self.assertTrue(m.kernel_size == m_loaded.kernel_size)
|
||||
self.assertEqual(i, i_loaded)
|
||||
self.assertEqual(j, j_loaded)
|
||||
|
||||
@unittest.skipIf(
|
||||
not TEST_DILL or HAS_DILL_AT_LEAST_0_3_1,
|
||||
'"dill" not found or is correct version'
|
||||
@ -304,26 +261,7 @@ class TestSerialization(TestCase):
|
||||
self.assertIsInstance(x3, type(x))
|
||||
self.assertEqual(x, x3)
|
||||
|
||||
def test_serialization_offset_filelike(self):
|
||||
a = torch.randn(5, 5)
|
||||
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
|
||||
i, j = 41, 43
|
||||
with BytesIOContext() as f:
|
||||
pickle.dump(i, f)
|
||||
torch.save(a, f)
|
||||
pickle.dump(j, f)
|
||||
torch.save(b, f)
|
||||
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
|
||||
f.seek(0)
|
||||
i_loaded = pickle.load(f)
|
||||
a_loaded = torch.load(f)
|
||||
j_loaded = pickle.load(f)
|
||||
b_loaded = torch.load(f)
|
||||
self.assertTrue(torch.equal(a, a_loaded))
|
||||
self.assertTrue(torch.equal(b, b_loaded))
|
||||
self.assertEqual(i, i_loaded)
|
||||
self.assertEqual(j, j_loaded)
|
||||
|
||||
@unittest.skipIf(not PY3, "gzip doesn't support os.seek(0, os.SEEK_END) on Python 2")
|
||||
def test_serialization_offset_gzip(self):
|
||||
a = torch.randn(5, 5)
|
||||
i = 41
|
||||
@ -405,57 +343,6 @@ class TestSerialization(TestCase):
|
||||
x = torch.save(torch.nn.Linear(2, 3), checkpoint)
|
||||
self.assertEquals(len(warns), 0)
|
||||
|
||||
|
||||
# unique_key is necessary because on Python 2.7, if a warning passed to
|
||||
# the warning module is the same, it is not raised again.
|
||||
def _test_serialization_container(self, unique_key, filecontext_lambda):
|
||||
|
||||
tmpmodule_name = 'tmpmodule{}'.format(unique_key)
|
||||
|
||||
def import_module(name, filename):
|
||||
if sys.version_info >= (3, 5):
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location(name, filename)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
else:
|
||||
import imp
|
||||
module = imp.load_source(name, filename)
|
||||
sys.modules[module.__name__] = module
|
||||
return module
|
||||
|
||||
with filecontext_lambda() as checkpoint:
|
||||
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
|
||||
'_internal', 'data', 'network1.py')
|
||||
module = import_module(tmpmodule_name, fname)
|
||||
torch.save(module.Net(), checkpoint)
|
||||
|
||||
# First check that the checkpoint can be loaded without warnings
|
||||
checkpoint.seek(0)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
loaded = torch.load(checkpoint)
|
||||
self.assertTrue(isinstance(loaded, module.Net))
|
||||
if can_retrieve_source:
|
||||
self.assertEquals(len(w), 0)
|
||||
|
||||
# Replace the module with different source
|
||||
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
|
||||
'_internal', 'data', 'network2.py')
|
||||
module = import_module(tmpmodule_name, fname)
|
||||
checkpoint.seek(0)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
loaded = torch.load(checkpoint)
|
||||
self.assertTrue(isinstance(loaded, module.Net))
|
||||
if can_retrieve_source:
|
||||
self.assertEquals(len(w), 1)
|
||||
self.assertTrue(w[0].category, 'SourceChangeWarning')
|
||||
|
||||
def test_serialization_container(self):
|
||||
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
|
||||
|
||||
def test_serialization_container_filelike(self):
|
||||
self._test_serialization_container('filelike', BytesIOContext)
|
||||
|
||||
def test_serialization_map_location(self):
|
||||
test_file_path = download_file('https://download.pytorch.org/test_data/gpu_tensors.pt')
|
||||
|
||||
@ -630,6 +517,7 @@ class TestSerialization(TestCase):
|
||||
def test_load_python2_unicode_module(self):
|
||||
# This Pickle contains some Unicode data!
|
||||
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
self.assertIsNotNone(torch.load(path))
|
||||
|
||||
def test_load_error_msg(self):
|
||||
@ -643,5 +531,150 @@ class TestSerialization(TestCase):
|
||||
with self.assertRaisesRegex(AttributeError, expected_err_msg):
|
||||
torch.load(resource)
|
||||
|
||||
|
||||
class serialization_method(object):
|
||||
def __init__(self, use_zip):
|
||||
self.use_zip = use_zip
|
||||
self.torch_save = torch.save
|
||||
|
||||
def __enter__(self, *args, **kwargs):
|
||||
def wrapper(*args, **kwargs):
|
||||
if '_use_new_zipfile_serialization' in kwargs:
|
||||
raise RuntimeError("Cannot set method manually")
|
||||
kwargs['_use_new_zipfile_serialization'] = self.use_zip
|
||||
return self.torch_save(*args, **kwargs)
|
||||
|
||||
torch.save = wrapper
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
torch.save = self.torch_save
|
||||
|
||||
|
||||
class TestOldSerialization(TestCase, SerializationMixin):
|
||||
# unique_key is necessary because on Python 2.7, if a warning passed to
|
||||
# the warning module is the same, it is not raised again.
|
||||
def _test_serialization_container(self, unique_key, filecontext_lambda):
|
||||
|
||||
tmpmodule_name = 'tmpmodule{}'.format(unique_key)
|
||||
|
||||
def import_module(name, filename):
|
||||
if sys.version_info >= (3, 5):
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location(name, filename)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
else:
|
||||
import imp
|
||||
module = imp.load_source(name, filename)
|
||||
sys.modules[module.__name__] = module
|
||||
return module
|
||||
|
||||
with filecontext_lambda() as checkpoint:
|
||||
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
|
||||
'_internal', 'data', 'network1.py')
|
||||
module = import_module(tmpmodule_name, fname)
|
||||
torch.save(module.Net(), checkpoint)
|
||||
|
||||
# First check that the checkpoint can be loaded without warnings
|
||||
checkpoint.seek(0)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
loaded = torch.load(checkpoint)
|
||||
self.assertTrue(isinstance(loaded, module.Net))
|
||||
if can_retrieve_source:
|
||||
self.assertEquals(len(w), 0)
|
||||
|
||||
# Replace the module with different source
|
||||
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
|
||||
'_internal', 'data', 'network2.py')
|
||||
module = import_module(tmpmodule_name, fname)
|
||||
checkpoint.seek(0)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
loaded = torch.load(checkpoint)
|
||||
self.assertTrue(isinstance(loaded, module.Net))
|
||||
if can_retrieve_source:
|
||||
self.assertEquals(len(w), 1)
|
||||
self.assertTrue(w[0].category, 'SourceChangeWarning')
|
||||
|
||||
def test_serialization_container(self):
|
||||
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
|
||||
|
||||
def test_serialization_container_filelike(self):
|
||||
self._test_serialization_container('filelike', BytesIOContext)
|
||||
|
||||
def test_serialization_offset(self):
|
||||
a = torch.randn(5, 5)
|
||||
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
|
||||
m = torch.nn.Conv2d(1, 1, (1, 3))
|
||||
i, j = 41, 43
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
pickle.dump(i, f)
|
||||
torch.save(a, f)
|
||||
pickle.dump(j, f)
|
||||
torch.save(b, f)
|
||||
torch.save(m, f)
|
||||
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
|
||||
f.seek(0)
|
||||
i_loaded = pickle.load(f)
|
||||
a_loaded = torch.load(f)
|
||||
j_loaded = pickle.load(f)
|
||||
b_loaded = torch.load(f)
|
||||
m_loaded = torch.load(f)
|
||||
self.assertTrue(torch.equal(a, a_loaded))
|
||||
self.assertTrue(torch.equal(b, b_loaded))
|
||||
self.assertTrue(m.kernel_size == m_loaded.kernel_size)
|
||||
self.assertEqual(i, i_loaded)
|
||||
self.assertEqual(j, j_loaded)
|
||||
|
||||
def test_serialization_offset_filelike(self):
|
||||
a = torch.randn(5, 5)
|
||||
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
|
||||
i, j = 41, 43
|
||||
with BytesIOContext() as f:
|
||||
pickle.dump(i, f)
|
||||
torch.save(a, f)
|
||||
pickle.dump(j, f)
|
||||
torch.save(b, f)
|
||||
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
|
||||
f.seek(0)
|
||||
i_loaded = pickle.load(f)
|
||||
a_loaded = torch.load(f)
|
||||
j_loaded = pickle.load(f)
|
||||
b_loaded = torch.load(f)
|
||||
self.assertTrue(torch.equal(a, a_loaded))
|
||||
self.assertTrue(torch.equal(b, b_loaded))
|
||||
self.assertEqual(i, i_loaded)
|
||||
self.assertEqual(j, j_loaded)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=False):
|
||||
return super(TestOldSerialization, self).run(*args, **kwargs)
|
||||
|
||||
|
||||
class TestSerialization(TestCase, SerializationMixin):
|
||||
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
|
||||
def test_serialization_zipfile(self):
|
||||
data = self._test_serialization_data()
|
||||
|
||||
def test(name_or_buffer):
|
||||
torch.save(data, name_or_buffer)
|
||||
|
||||
if hasattr(name_or_buffer, 'seek'):
|
||||
name_or_buffer.seek(0)
|
||||
|
||||
result = torch.load(name_or_buffer)
|
||||
self.assertEqual(result, data)
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
test(f)
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
test(f.name)
|
||||
|
||||
test(io.BytesIO())
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=True):
|
||||
return super(TestSerialization, self).run(*args, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -11,8 +11,8 @@ import torch.backends.cuda
|
||||
import tempfile
|
||||
import unittest
|
||||
import warnings
|
||||
import pickle
|
||||
import types
|
||||
import pickle
|
||||
import textwrap
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
from torch._six import inf, nan, string_classes, istuple
|
||||
|
2
third_party/miniz-2.0.8/miniz.c
vendored
2
third_party/miniz-2.0.8/miniz.c
vendored
@ -4522,7 +4522,9 @@ void *mz_zip_reader_extract_file_to_heap(mz_zip_archive *pZip, const char *pFile
|
||||
mz_bool mz_zip_reader_extract_to_callback(mz_zip_archive *pZip, mz_uint file_index, mz_file_write_func pCallback, void *pOpaque, mz_uint flags)
|
||||
{
|
||||
int status = TINFL_STATUS_DONE;
|
||||
#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS
|
||||
mz_uint file_crc32 = MZ_CRC32_INIT;
|
||||
#endif
|
||||
mz_uint64 read_buf_size, read_buf_ofs = 0, read_buf_avail, comp_remaining, out_buf_ofs = 0, cur_file_ofs;
|
||||
mz_zip_archive_file_stat file_stat;
|
||||
void *pRead_buf = NULL;
|
||||
|
@ -215,16 +215,17 @@ PyObject * THPStorage_(writeFile)(THPStorage *self, PyObject *args)
|
||||
HANDLE_TH_ERRORS
|
||||
PyObject *file = PyTuple_GET_ITEM(args, 0);
|
||||
bool is_real_file = PyTuple_GET_ITEM(args, 1) == Py_True;
|
||||
bool save_size = PyTuple_GET_ITEM(args, 2) == Py_True;
|
||||
|
||||
if (!is_real_file) {
|
||||
THPStorage_(writeFileRaw<PyObject*>)(self->cdata, file);
|
||||
THPStorage_(writeFileRaw<PyObject*>)(self->cdata, file, save_size);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
int fd = PyObject_AsFileDescriptor(file);
|
||||
THPUtils_assert(fd != -1, "_write_file couldn't retrieve a file descriptor "
|
||||
"from given object");
|
||||
THPStorage_(writeFileRaw)(self->cdata, fd);
|
||||
THPStorage_(writeFileRaw)(self->cdata, fd, save_size);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
@ -6,8 +6,11 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#endif
|
||||
|
||||
// save_save is necessary since the old eager format saved storages as
|
||||
// [size + data], but the v1.5 eager format removes this since size is saved in
|
||||
// the filesize.
|
||||
template <class io>
|
||||
void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
|
||||
void THPStorage_(writeFileRaw)(THWStorage *self, io fd, bool save_size)
|
||||
{
|
||||
#ifdef THC_GENERIC_FILE
|
||||
c10::cuda::CUDAGuard guard(self->device());
|
||||
@ -22,6 +25,7 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
|
||||
data = (scalar_t*)cpu_data.get();
|
||||
THCudaCheck(cudaMemcpy(data, THWStorage_(data)(LIBRARY_STATE self), size * sizeof(scalar_t), cudaMemcpyDeviceToHost));
|
||||
#endif
|
||||
if (save_size) {
|
||||
if (torch::utils::THP_nativeByteOrder() ==
|
||||
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN)
|
||||
doWrite(fd, &size, sizeof(int64_t));
|
||||
@ -34,6 +38,7 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
|
||||
1);
|
||||
doWrite(fd, &nsize, sizeof(int64_t));
|
||||
}
|
||||
}
|
||||
// fast track for bytes and little endian
|
||||
if (sizeof(scalar_t) == 1 ||
|
||||
torch::utils::THP_nativeByteOrder() ==
|
||||
@ -68,8 +73,8 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
|
||||
}
|
||||
}
|
||||
|
||||
template void THPStorage_(writeFileRaw<int>)(THWStorage *self, int fd);
|
||||
template void THPStorage_(writeFileRaw<PyObject*>)(THWStorage *self, PyObject* fd);
|
||||
template void THPStorage_(writeFileRaw<int>)(THWStorage *self, int fd, bool save_size);
|
||||
template void THPStorage_(writeFileRaw<PyObject*>)(THWStorage *self, PyObject* fd, bool save_size);
|
||||
|
||||
template <class io>
|
||||
THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage)
|
||||
|
@ -3,7 +3,7 @@
|
||||
#else
|
||||
|
||||
template <class io>
|
||||
void THPStorage_(writeFileRaw)(THWStorage *self, io fd);
|
||||
void THPStorage_(writeFileRaw)(THWStorage *self, io fd, bool save_size);
|
||||
|
||||
template <class io>
|
||||
THWStorage * THPStorage_(readFileRaw)(io fd, THWStorage *storage);
|
||||
|
@ -486,19 +486,50 @@ void initJITBindings(PyObject* module) {
|
||||
BufferAdapter(const py::object& buffer) : buffer_(buffer) {
|
||||
// Jump to the end of the buffer to get its size
|
||||
auto current = buffer.attr("tell")();
|
||||
start_offset_ = py::cast<size_t>(current);
|
||||
buffer.attr("seek")(current, py::module::import("os").attr("SEEK_END"));
|
||||
size_ = py::cast<size_t>(buffer.attr("tell")());
|
||||
size_ = py::cast<size_t>(buffer.attr("tell")()) - start_offset_;
|
||||
buffer.attr("seek")(current);
|
||||
|
||||
// If we can read directly into a buffer, do that instead of an extra copy
|
||||
use_readinto_ = py::hasattr(buffer, "readinto");
|
||||
}
|
||||
|
||||
size_t size() const override {
|
||||
return size_;
|
||||
}
|
||||
|
||||
THPObjectPtr getMemview(void* buf, size_t n) const {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
THPObjectPtr memview(PyMemoryView_FromMemory(
|
||||
reinterpret_cast<char*>(buf), n, PyBUF_WRITE));
|
||||
#else
|
||||
THPObjectPtr memview(PyBuffer_FromReadWriteMemory(buf, n));
|
||||
#endif
|
||||
if (!memview) {
|
||||
throw python_error();
|
||||
}
|
||||
return memview;
|
||||
}
|
||||
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what)
|
||||
const override {
|
||||
// Seek to desired position
|
||||
buffer_.attr("seek")(pos);
|
||||
// Seek to desired position (NB: this has to be a Py_ssize_t or Python
|
||||
// throws a weird error)
|
||||
Py_ssize_t absolute_pos = start_offset_ + pos;
|
||||
buffer_.attr("seek")(absolute_pos);
|
||||
|
||||
if (use_readinto_) {
|
||||
auto memview = getMemview(buf, n);
|
||||
auto res =
|
||||
PyObject_CallMethod(buffer_.ptr(), "readinto", "O", memview.get());
|
||||
if (res) {
|
||||
int i = PyInt_AsLong(res);
|
||||
if (i > 0) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read bytes into `buf` from the buffer
|
||||
std::string bytes = py::cast<std::string>(buffer_.attr("read")(n));
|
||||
@ -511,6 +542,8 @@ void initJITBindings(PyObject* module) {
|
||||
|
||||
py::object buffer_;
|
||||
size_t size_;
|
||||
size_t start_offset_;
|
||||
bool use_readinto_;
|
||||
};
|
||||
|
||||
py::class_<PyTorchStreamReader>(m, "PyTorchFileReader")
|
||||
|
@ -49,13 +49,13 @@ def _is_zipfile(f):
|
||||
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
|
||||
# binary. Since we expect the files here to be generated by torch.save or
|
||||
# torch.jit.save, it's safe to only check the start bytes and avoid
|
||||
# collisions. See bugs.python.org/issue28494.
|
||||
# collisions and assume the zip has only 1 file.
|
||||
# See bugs.python.org/issue28494.
|
||||
|
||||
# Read the first 4 bytes of the file
|
||||
read_bytes = []
|
||||
start = f.tell()
|
||||
|
||||
f.seek(0)
|
||||
byte = f.read(1)
|
||||
while byte != "":
|
||||
read_bytes.append(byte)
|
||||
@ -64,21 +64,8 @@ def _is_zipfile(f):
|
||||
byte = f.read(1)
|
||||
f.seek(start)
|
||||
|
||||
# zip magic numbers
|
||||
magic_numbers = [
|
||||
['P', 'K', '\x03', '\x04'],
|
||||
['P', 'K', '\x05', '\x06'],
|
||||
['P', 'K', '\x07', '\x08'],
|
||||
]
|
||||
for magic_number in magic_numbers:
|
||||
match = True
|
||||
for magic_byte, read_byte in zip(magic_number, read_bytes):
|
||||
if ord(magic_byte) != ord(read_byte):
|
||||
match = False
|
||||
break
|
||||
if match:
|
||||
return True
|
||||
return False
|
||||
local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
|
||||
return read_bytes == local_header_magic_number
|
||||
|
||||
|
||||
def register_package(priority, tagger, deserializer):
|
||||
@ -458,7 +445,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol):
|
||||
pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
|
||||
f.flush()
|
||||
for key in serialized_storage_keys:
|
||||
serialized_storages[key]._write_file(f, _should_read_directly(f))
|
||||
serialized_storages[key]._write_file(f, _should_read_directly(f), True)
|
||||
|
||||
|
||||
def _save(obj, zip_file, pickle_module, pickle_protocol):
|
||||
@ -495,8 +482,17 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
|
||||
for key in sorted(serialized_storages.keys()):
|
||||
name = 'data/{}'.format(key)
|
||||
storage = serialized_storages[key]
|
||||
if storage.device.type == 'cpu':
|
||||
# If it's on the CPU we can directly copy it into the zip file
|
||||
num_bytes = storage.size() * storage.element_size()
|
||||
buf = io.BytesIO()
|
||||
zip_file.write_record(name, storage.data_ptr(), num_bytes)
|
||||
else:
|
||||
# Copy to a buffer, then serialize that
|
||||
buf = io.BytesIO()
|
||||
storage._write_file(buf, _should_read_directly(buf))
|
||||
buf_value = buf.getvalue()
|
||||
zip_file.write_record(name, buf_value, len(buf_value))
|
||||
|
||||
|
||||
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
|
||||
|
@ -1,4 +1,5 @@
|
||||
import io
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from ._utils import _type, _cuda
|
||||
@ -30,6 +31,7 @@ class _StorageBase(object):
|
||||
return new_storage
|
||||
|
||||
def __reduce__(self):
|
||||
warnings.warn("pickle support for Storage will be removed in 1.5. Use `torch.save` instead", FutureWarning)
|
||||
b = io.BytesIO()
|
||||
torch.save(self, b)
|
||||
return (_load_from_bytes, (b.getvalue(),))
|
||||
|
Reference in New Issue
Block a user