Fix some bugs with zipfile serialization (#32244)

Summary:
Stacked PRs
 * #32958 - Make zip serialization the default
 * **#32244 - Fix some bugs with zipfile serialization**

It includes the following changes:
* Split up tests so that we can test both serialization methods
    * Loading a checkpoint that sits at an offset inside a larger buffer (e.g. after other pickled data) no longer works with the zip format, so those tests run only against the old serialization method (supporting it is possible, but it would introduce a big slowdown since it requires a linear scan of the entire zipfile to find the magic number at the end)
* Call `readinto` on a buffer if possible instead of `read` + a copy
* Disable CRC-32 checks on read (there was some issue where miniz said the CRC was wrong but `zipinfo` and `unzip` said the zip file was fine)
[D19418935](https://our.intern.facebook.com/intern/diff/19418935/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32244

Pulled By: driazati

Reviewed By: eellison

Differential Revision: D19418935

fbshipit-source-id: df140854f52ecd04236225417d625374fd99f573
This commit is contained in:
davidriazati
2020-02-05 15:30:21 -08:00
committed by Facebook Github Bot
parent ab75d64e6e
commit 74ce3a032c
11 changed files with 238 additions and 166 deletions

View File

@ -1442,6 +1442,8 @@ if (NOT INTERN_BUILD_MOBILE)
ENDIF(HAVE_MALLOC_USABLE_SIZE)
ENDIF(UNIX)
ADD_DEFINITIONS(-DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS)
# Is __thread supported?
IF(NOT MSVC)
CHECK_C_SOURCE_COMPILES("static __thread int x = 1; int main() { return x; }" C_HAS_THREAD)

View File

@ -4,6 +4,7 @@ import sys
import random
import string
import unittest
import io
try:
import unittest.mock as mock
except ImportError:
@ -5804,11 +5805,6 @@ class TestNN(NNTestCase):
@unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1")
def test_RNN_dropout_state(self):
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
for p in (0, 0.1234):
for train in (True, False):
for cuda in (True, False):
@ -5829,8 +5825,10 @@ class TestNN(NNTestCase):
output1, hy1 = rnn(input, hx)
output2, hy2 = rnn(input, hx)
rnn_pickle = pickle.dumps(rnn)
rnn2 = pickle.loads(rnn_pickle)
buf = io.BytesIO()
rnn_pickle = torch.save(rnn, buf)
buf.seek(0)
rnn2 = torch.load(buf)
rnn2.flatten_parameters()
output3, hy3 = rnn2(input, hx)

View File

@ -71,7 +71,7 @@ class FilelikeMock(object):
return name in self.calls
class TestSerialization(TestCase):
class SerializationMixin(object):
def _test_serialization_data(self):
a = [torch.randn(5, 5).float() for i in range(2)]
b = [a[i % 2] for i in range(4)] # 0-3
@ -141,26 +141,6 @@ class TestSerialization(TestCase):
test(io.BytesIO())
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
def test_serialization_zipfile(self):
data = self._test_serialization_data()
def test(name_or_buffer):
torch.save(data, name_or_buffer, _use_new_zipfile_serialization=True)
if hasattr(name_or_buffer, 'seek'):
name_or_buffer.seek(0)
result = torch.load(name_or_buffer)
self.assertEqual(result, data)
with tempfile.NamedTemporaryFile() as f:
test(f)
with tempfile.NamedTemporaryFile() as f:
test(f.name)
test(io.BytesIO())
def test_serialization(self):
# Test serialization with a real file
b = self._test_serialization_data()
@ -235,6 +215,7 @@ class TestSerialization(TestCase):
self.assertFalse(torch.serialization._is_zipfile(f))
self.assertEqual(torch.load(f.name), t)
@unittest.skipIf(not PY3, "gzip doesn't support os.seek(0, os.SEEK_END) on Python 2")
def test_serialization_gzip(self):
# Test serialization with gzip file
b = self._test_serialization_data()
@ -248,30 +229,6 @@ class TestSerialization(TestCase):
c = torch.load(f)
self._test_serialization_assert(b, c)
def test_serialization_offset(self):
a = torch.randn(5, 5)
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
m = torch.nn.Conv2d(1, 1, (1, 3))
i, j = 41, 43
with tempfile.NamedTemporaryFile() as f:
pickle.dump(i, f)
torch.save(a, f)
pickle.dump(j, f)
torch.save(b, f)
torch.save(m, f)
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
f.seek(0)
i_loaded = pickle.load(f)
a_loaded = torch.load(f)
j_loaded = pickle.load(f)
b_loaded = torch.load(f)
m_loaded = torch.load(f)
self.assertTrue(torch.equal(a, a_loaded))
self.assertTrue(torch.equal(b, b_loaded))
self.assertTrue(m.kernel_size == m_loaded.kernel_size)
self.assertEqual(i, i_loaded)
self.assertEqual(j, j_loaded)
@unittest.skipIf(
not TEST_DILL or HAS_DILL_AT_LEAST_0_3_1,
'"dill" not found or is correct version'
@ -304,26 +261,7 @@ class TestSerialization(TestCase):
self.assertIsInstance(x3, type(x))
self.assertEqual(x, x3)
def test_serialization_offset_filelike(self):
a = torch.randn(5, 5)
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
i, j = 41, 43
with BytesIOContext() as f:
pickle.dump(i, f)
torch.save(a, f)
pickle.dump(j, f)
torch.save(b, f)
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
f.seek(0)
i_loaded = pickle.load(f)
a_loaded = torch.load(f)
j_loaded = pickle.load(f)
b_loaded = torch.load(f)
self.assertTrue(torch.equal(a, a_loaded))
self.assertTrue(torch.equal(b, b_loaded))
self.assertEqual(i, i_loaded)
self.assertEqual(j, j_loaded)
@unittest.skipIf(not PY3, "gzip doesn't support os.seek(0, os.SEEK_END) on Python 2")
def test_serialization_offset_gzip(self):
a = torch.randn(5, 5)
i = 41
@ -405,57 +343,6 @@ class TestSerialization(TestCase):
x = torch.save(torch.nn.Linear(2, 3), checkpoint)
self.assertEquals(len(warns), 0)
# unique_key is necessary because on Python 2.7, if a warning passed to
# the warning module is the same, it is not raised again.
def _test_serialization_container(self, unique_key, filecontext_lambda):
tmpmodule_name = 'tmpmodule{}'.format(unique_key)
def import_module(name, filename):
if sys.version_info >= (3, 5):
import importlib.util
spec = importlib.util.spec_from_file_location(name, filename)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
else:
import imp
module = imp.load_source(name, filename)
sys.modules[module.__name__] = module
return module
with filecontext_lambda() as checkpoint:
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
'_internal', 'data', 'network1.py')
module = import_module(tmpmodule_name, fname)
torch.save(module.Net(), checkpoint)
# First check that the checkpoint can be loaded without warnings
checkpoint.seek(0)
with warnings.catch_warnings(record=True) as w:
loaded = torch.load(checkpoint)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
self.assertEquals(len(w), 0)
# Replace the module with different source
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
'_internal', 'data', 'network2.py')
module = import_module(tmpmodule_name, fname)
checkpoint.seek(0)
with warnings.catch_warnings(record=True) as w:
loaded = torch.load(checkpoint)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
self.assertEquals(len(w), 1)
self.assertTrue(w[0].category, 'SourceChangeWarning')
def test_serialization_container(self):
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
def test_serialization_container_filelike(self):
self._test_serialization_container('filelike', BytesIOContext)
def test_serialization_map_location(self):
test_file_path = download_file('https://download.pytorch.org/test_data/gpu_tensors.pt')
@ -630,7 +517,8 @@ class TestSerialization(TestCase):
def test_load_python2_unicode_module(self):
# This Pickle contains some Unicode data!
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
self.assertIsNotNone(torch.load(path))
with warnings.catch_warnings(record=True) as w:
self.assertIsNotNone(torch.load(path))
def test_load_error_msg(self):
expected_err_msg = (".*You can only torch.load from a file that is seekable. " +
@ -643,5 +531,150 @@ class TestSerialization(TestCase):
with self.assertRaisesRegex(AttributeError, expected_err_msg):
torch.load(resource)
class serialization_method(object):
    """Context manager that pins ``torch.save`` to one serialization format.

    While the context is active, ``torch.save`` is replaced by a wrapper that
    forces the ``_use_new_zipfile_serialization`` keyword to ``use_zip``. The
    function that was installed as ``torch.save`` when the manager was
    constructed is put back on exit, including when the body raises.

    Args:
        use_zip: value forced for ``_use_new_zipfile_serialization`` on every
            ``torch.save`` call made inside the context.

    Raises (from the wrapper, at call time):
        RuntimeError: if a caller passes ``_use_new_zipfile_serialization``
            explicitly while the context is active.
    """

    def __init__(self, use_zip):
        self.use_zip = use_zip
        # Captured here (not in __enter__) so __exit__ restores exactly the
        # function that was in place when the manager was created.
        self.torch_save = torch.save

    def __enter__(self, *args, **kwargs):
        def wrapper(*args, **kwargs):
            # The whole point of the context is to choose the format, so an
            # explicit choice by the caller is treated as a test bug.
            if '_use_new_zipfile_serialization' in kwargs:
                raise RuntimeError("Cannot set method manually")
            kwargs['_use_new_zipfile_serialization'] = self.use_zip
            return self.torch_save(*args, **kwargs)

        torch.save = wrapper
        # Return the manager so ``with serialization_method(...) as ctx`` binds
        # a usable object instead of None.
        return self

    def __exit__(self, *args, **kwargs):
        # Returns None (falsy), so exceptions raised in the body propagate.
        torch.save = self.torch_save
class TestOldSerialization(TestCase, SerializationMixin):
    """Runs the shared serialization tests against the legacy (non-zip) format.

    ``run`` wraps every test in ``serialization_method(use_zip=False)``, so all
    ``torch.save`` calls made by the tests below and by ``SerializationMixin``
    use the old format. The tests defined directly on this class rely on
    loading a checkpoint from the middle of a larger stream, which is only
    exercised for the old format.
    """

    # unique_key is necessary because on Python 2.7, if a warning passed to
    # the warning module is the same, it is not raised again.
    def _test_serialization_container(self, unique_key, filecontext_lambda):
        """Save a module-defined container, then reload it before and after
        swapping the defining module's source.

        Args:
            unique_key: suffix used to build a distinct temporary module name.
            filecontext_lambda: zero-argument callable returning a context
                manager that yields a seekable file-like checkpoint target.
        """
        tmpmodule_name = 'tmpmodule{}'.format(unique_key)

        def import_module(name, filename):
            # importlib.util is used on 3.5+, the older `imp` loader otherwise.
            if sys.version_info >= (3, 5):
                import importlib.util
                spec = importlib.util.spec_from_file_location(name, filename)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            else:
                import imp
                module = imp.load_source(name, filename)
            sys.modules[module.__name__] = module
            return module

        with filecontext_lambda() as checkpoint:
            fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
                                    '_internal', 'data', 'network1.py')
            module = import_module(tmpmodule_name, fname)
            torch.save(module.Net(), checkpoint)

            # First check that the checkpoint can be loaded without warnings
            checkpoint.seek(0)
            with warnings.catch_warnings(record=True) as w:
                loaded = torch.load(checkpoint)
                self.assertIsInstance(loaded, module.Net)
                if can_retrieve_source:
                    self.assertEqual(len(w), 0)

            # Replace the module with different source
            fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
                                    '_internal', 'data', 'network2.py')
            module = import_module(tmpmodule_name, fname)
            checkpoint.seek(0)
            with warnings.catch_warnings(record=True) as w:
                loaded = torch.load(checkpoint)
                self.assertIsInstance(loaded, module.Net)
                if can_retrieve_source:
                    self.assertEqual(len(w), 1)
                    # Compare the category's name; passing the expected name as
                    # the second argument of assertTrue would only set the
                    # failure message and never check anything.
                    self.assertEqual(w[0].category.__name__, 'SourceChangeWarning')

    def test_serialization_container(self):
        self._test_serialization_container('file', tempfile.NamedTemporaryFile)

    def test_serialization_container_filelike(self):
        self._test_serialization_container('filelike', BytesIOContext)

    def test_serialization_offset(self):
        """Interleave pickles and checkpoints in one real file past the 2 GiB mark."""
        a = torch.randn(5, 5)
        # 1024 * 1024 * 512 float32 values == 2 GiB, pushing later records
        # beyond the offset asserted below.
        b = torch.randn(1024, 1024, 512, dtype=torch.float32)
        m = torch.nn.Conv2d(1, 1, (1, 3))
        i, j = 41, 43
        with tempfile.NamedTemporaryFile() as f:
            pickle.dump(i, f)
            torch.save(a, f)
            pickle.dump(j, f)
            torch.save(b, f)
            torch.save(m, f)
            self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
            f.seek(0)
            i_loaded = pickle.load(f)
            a_loaded = torch.load(f)
            j_loaded = pickle.load(f)
            b_loaded = torch.load(f)
            m_loaded = torch.load(f)
        self.assertTrue(torch.equal(a, a_loaded))
        self.assertTrue(torch.equal(b, b_loaded))
        self.assertTrue(m.kernel_size == m_loaded.kernel_size)
        self.assertEqual(i, i_loaded)
        self.assertEqual(j, j_loaded)

    def test_serialization_offset_filelike(self):
        """Same interleaving as test_serialization_offset, on an in-memory stream."""
        a = torch.randn(5, 5)
        b = torch.randn(1024, 1024, 512, dtype=torch.float32)
        i, j = 41, 43
        with BytesIOContext() as f:
            pickle.dump(i, f)
            torch.save(a, f)
            pickle.dump(j, f)
            torch.save(b, f)
            self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
            f.seek(0)
            i_loaded = pickle.load(f)
            a_loaded = torch.load(f)
            j_loaded = pickle.load(f)
            b_loaded = torch.load(f)
        self.assertTrue(torch.equal(a, a_loaded))
        self.assertTrue(torch.equal(b, b_loaded))
        self.assertEqual(i, i_loaded)
        self.assertEqual(j, j_loaded)

    def run(self, *args, **kwargs):
        # Force the legacy format for the duration of each test run.
        with serialization_method(use_zip=False):
            return super(TestOldSerialization, self).run(*args, **kwargs)
class TestSerialization(TestCase, SerializationMixin):
    """Runs the shared serialization tests with the zipfile format forced on."""

    def run(self, *args, **kwargs):
        # Each test executes while torch.save is pinned to the zip format.
        zip_only = serialization_method(use_zip=True)
        with zip_only:
            return super(TestSerialization, self).run(*args, **kwargs)

    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
    def test_serialization_zipfile(self):
        """Round-trip the shared test data through a file object, a path and a BytesIO."""
        expected = self._test_serialization_data()

        def roundtrip(target):
            torch.save(expected, target)
            # File-like targets must be rewound before loading; paths have no seek.
            if hasattr(target, 'seek'):
                target.seek(0)
            self.assertEqual(torch.load(target), expected)

        with tempfile.NamedTemporaryFile() as handle:
            roundtrip(handle)

        with tempfile.NamedTemporaryFile() as handle:
            roundtrip(handle.name)

        roundtrip(io.BytesIO())
if __name__ == '__main__':
run_tests()

View File

@ -11,8 +11,8 @@ import torch.backends.cuda
import tempfile
import unittest
import warnings
import pickle
import types
import pickle
import textwrap
from torch.utils.dlpack import from_dlpack, to_dlpack
from torch._six import inf, nan, string_classes, istuple

View File

@ -4522,7 +4522,9 @@ void *mz_zip_reader_extract_file_to_heap(mz_zip_archive *pZip, const char *pFile
mz_bool mz_zip_reader_extract_to_callback(mz_zip_archive *pZip, mz_uint file_index, mz_file_write_func pCallback, void *pOpaque, mz_uint flags)
{
int status = TINFL_STATUS_DONE;
#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS
mz_uint file_crc32 = MZ_CRC32_INIT;
#endif
mz_uint64 read_buf_size, read_buf_ofs = 0, read_buf_avail, comp_remaining, out_buf_ofs = 0, cur_file_ofs;
mz_zip_archive_file_stat file_stat;
void *pRead_buf = NULL;

View File

@ -215,16 +215,17 @@ PyObject * THPStorage_(writeFile)(THPStorage *self, PyObject *args)
HANDLE_TH_ERRORS
PyObject *file = PyTuple_GET_ITEM(args, 0);
bool is_real_file = PyTuple_GET_ITEM(args, 1) == Py_True;
bool save_size = PyTuple_GET_ITEM(args, 2) == Py_True;
if (!is_real_file) {
THPStorage_(writeFileRaw<PyObject*>)(self->cdata, file);
THPStorage_(writeFileRaw<PyObject*>)(self->cdata, file, save_size);
Py_RETURN_NONE;
}
int fd = PyObject_AsFileDescriptor(file);
THPUtils_assert(fd != -1, "_write_file couldn't retrieve a file descriptor "
"from given object");
THPStorage_(writeFileRaw)(self->cdata, fd);
THPStorage_(writeFileRaw)(self->cdata, fd, save_size);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

View File

@ -6,8 +6,11 @@
#include <c10/cuda/CUDAGuard.h>
#endif
// save_size is necessary since the old eager format saved storages as
// [size + data], but the v1.5 eager format removes this since size is saved in
// the filesize.
template <class io>
void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
void THPStorage_(writeFileRaw)(THWStorage *self, io fd, bool save_size)
{
#ifdef THC_GENERIC_FILE
c10::cuda::CUDAGuard guard(self->device());
@ -22,17 +25,19 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
data = (scalar_t*)cpu_data.get();
THCudaCheck(cudaMemcpy(data, THWStorage_(data)(LIBRARY_STATE self), size * sizeof(scalar_t), cudaMemcpyDeviceToHost));
#endif
if (torch::utils::THP_nativeByteOrder() ==
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN)
doWrite(fd, &size, sizeof(int64_t));
else {
int64_t nsize; // convert big endian cpu to little endian storage
torch::utils::THP_encodeInt64Buffer(
(uint8_t*)&nsize,
(const int64_t*)&size,
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
1);
doWrite(fd, &nsize, sizeof(int64_t));
if (save_size) {
if (torch::utils::THP_nativeByteOrder() ==
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN)
doWrite(fd, &size, sizeof(int64_t));
else {
int64_t nsize; // convert big endian cpu to little endian storage
torch::utils::THP_encodeInt64Buffer(
(uint8_t*)&nsize,
(const int64_t*)&size,
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
1);
doWrite(fd, &nsize, sizeof(int64_t));
}
}
// fast track for bytes and little endian
if (sizeof(scalar_t) == 1 ||
@ -68,8 +73,8 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
}
}
template void THPStorage_(writeFileRaw<int>)(THWStorage *self, int fd);
template void THPStorage_(writeFileRaw<PyObject*>)(THWStorage *self, PyObject* fd);
template void THPStorage_(writeFileRaw<int>)(THWStorage *self, int fd, bool save_size);
template void THPStorage_(writeFileRaw<PyObject*>)(THWStorage *self, PyObject* fd, bool save_size);
template <class io>
THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage)

View File

@ -3,7 +3,7 @@
#else
template <class io>
void THPStorage_(writeFileRaw)(THWStorage *self, io fd);
void THPStorage_(writeFileRaw)(THWStorage *self, io fd, bool save_size);
template <class io>
THWStorage * THPStorage_(readFileRaw)(io fd, THWStorage *storage);

View File

@ -486,19 +486,50 @@ void initJITBindings(PyObject* module) {
BufferAdapter(const py::object& buffer) : buffer_(buffer) {
// Jump to the end of the buffer to get its size
auto current = buffer.attr("tell")();
start_offset_ = py::cast<size_t>(current);
buffer.attr("seek")(current, py::module::import("os").attr("SEEK_END"));
size_ = py::cast<size_t>(buffer.attr("tell")());
size_ = py::cast<size_t>(buffer.attr("tell")()) - start_offset_;
buffer.attr("seek")(current);
// If we can read directly into a buffer, do that instead of an extra copy
use_readinto_ = py::hasattr(buffer, "readinto");
}
size_t size() const override {
return size_;
}
// Wrap the raw destination buffer `buf` (`n` bytes) in a Python object that
// exposes it as writable memory: a memoryview on Python 3, a read-write
// buffer object on Python 2. Throws python_error if the object cannot be
// created.
THPObjectPtr getMemview(void* buf, size_t n) const {
#if PY_MAJOR_VERSION >= 3
THPObjectPtr memview(PyMemoryView_FromMemory(
reinterpret_cast<char*>(buf), n, PyBUF_WRITE));
#else
THPObjectPtr memview(PyBuffer_FromReadWriteMemory(buf, n));
#endif
// A null result means the C API call failed; surface it as a C++ exception.
if (!memview) {
throw python_error();
}
return memview;
}
size_t read(uint64_t pos, void* buf, size_t n, const char* what)
const override {
// Seek to desired position
buffer_.attr("seek")(pos);
// Seek to desired position (NB: this has to be a Py_ssize_t or Python
// throws a weird error)
Py_ssize_t absolute_pos = start_offset_ + pos;
buffer_.attr("seek")(absolute_pos);
if (use_readinto_) {
auto memview = getMemview(buf, n);
auto res =
PyObject_CallMethod(buffer_.ptr(), "readinto", "O", memview.get());
if (res) {
int i = PyInt_AsLong(res);
if (i > 0) {
return i;
}
}
}
// Read bytes into `buf` from the buffer
std::string bytes = py::cast<std::string>(buffer_.attr("read")(n));
@ -511,6 +542,8 @@ void initJITBindings(PyObject* module) {
py::object buffer_;
size_t size_;
size_t start_offset_;
bool use_readinto_;
};
py::class_<PyTorchStreamReader>(m, "PyTorchFileReader")

View File

@ -49,13 +49,13 @@ def _is_zipfile(f):
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
# binary. Since we expect the files here to be generated by torch.save or
# torch.jit.save, it's safe to only check the start bytes and avoid
# collisions. See bugs.python.org/issue28494.
# collisions and assume the zip has only 1 file.
# See bugs.python.org/issue28494.
# Read the first 4 bytes of the file
read_bytes = []
start = f.tell()
f.seek(0)
byte = f.read(1)
while byte != "":
read_bytes.append(byte)
@ -64,21 +64,8 @@ def _is_zipfile(f):
byte = f.read(1)
f.seek(start)
# zip magic numbers
magic_numbers = [
['P', 'K', '\x03', '\x04'],
['P', 'K', '\x05', '\x06'],
['P', 'K', '\x07', '\x08'],
]
for magic_number in magic_numbers:
match = True
for magic_byte, read_byte in zip(magic_number, read_bytes):
if ord(magic_byte) != ord(read_byte):
match = False
break
if match:
return True
return False
local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
return read_bytes == local_header_magic_number
def register_package(priority, tagger, deserializer):
@ -458,7 +445,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol):
pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
f.flush()
for key in serialized_storage_keys:
serialized_storages[key]._write_file(f, _should_read_directly(f))
serialized_storages[key]._write_file(f, _should_read_directly(f), True)
def _save(obj, zip_file, pickle_module, pickle_protocol):
@ -495,8 +482,17 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
for key in sorted(serialized_storages.keys()):
name = 'data/{}'.format(key)
storage = serialized_storages[key]
num_bytes = storage.size() * storage.element_size()
zip_file.write_record(name, storage.data_ptr(), num_bytes)
if storage.device.type == 'cpu':
# If it's on the CPU we can directly copy it into the zip file
num_bytes = storage.size() * storage.element_size()
buf = io.BytesIO()
zip_file.write_record(name, storage.data_ptr(), num_bytes)
else:
# Copy to a buffer, then serialize that
buf = io.BytesIO()
storage._write_file(buf, _should_read_directly(buf))
buf_value = buf.getvalue()
zip_file.write_record(name, buf_value, len(buf_value))
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
@ -543,7 +539,7 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
:attr:`errors=...`.
.. warning::
:func:`torch.load()` uses ``pickle`` module implicitly, which is known to be insecure.
:func:`torch.load()` uses ``pickle`` module implicitly, which is known to be insecure.
It is possible to construct malicious pickle data which will execute arbitrary code
during unpickling. Never load data that could have come from an untrusted
source, or that could have been tampered with. **Only load data you trust**.

View File

@ -1,4 +1,5 @@
import io
import warnings
import torch
from ._utils import _type, _cuda
@ -30,6 +31,7 @@ class _StorageBase(object):
return new_storage
def __reduce__(self):
    """Pickle hook: serialize this storage via ``torch.save`` into bytes.

    Emits a ``FutureWarning`` and returns the ``(callable, args)`` pair that
    rebuilds the storage by passing the saved bytes to ``_load_from_bytes``.
    """
    warnings.warn("pickle support for Storage will be removed in 1.5. Use `torch.save` instead", FutureWarning)
    stream = io.BytesIO()
    torch.save(self, stream)
    payload = stream.getvalue()
    return (_load_from_bytes, (payload,))