Fix PyTorchStreamWriter exception handling (#88128)

Avoid double exception in destructor if attempting to serialize to a
python object that does not have a `write` method

Use `Finalizer` class in `PyTorchStreamWriter::writeEndOfFile()` to
always set the `finalized_` property even if an exception occurs (as
there isn't much one can do at this point)

Add an explicit check for the attribute to `_open_zipfile_writer_buffer` and
add unit tests

Modernize code a bit by using Python-3 `super()` method

Fixes https://github.com/pytorch/pytorch/issues/87997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88128
Approved by: https://github.com/albanD
This commit is contained in:
Nikita Shulga
2022-10-31 23:38:03 +00:00
committed by PyTorch MergeBot
parent ea8a5b09a9
commit caaf37a111
5 changed files with 47 additions and 9 deletions

View File

@ -338,8 +338,7 @@ PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
}
PyTorchStreamWriter::PyTorchStreamWriter(
// NOLINTNEXTLINE(modernize-pass-by-value)
const std::function<size_t(const void*, size_t)>& writer_func)
const std::function<size_t(const void*, size_t)> writer_func)
: archive_name_("archive"),
writer_func_(writer_func) {
setup(archive_name_);
@ -416,6 +415,21 @@ void PyTorchStreamWriter::writeRecord(
}
void PyTorchStreamWriter::writeEndOfFile() {
// Ensures that finalized_ is set to true even if an
// exception is raised during the method call.
// I.e. even a partial call to writeEndOfFile() should mark the
// file as finalized, otherwise a double exception raised from the
// destructor would result in `std::terminate()`
// See https://github.com/pytorch/pytorch/issues/87997/
struct Finalizer {
Finalizer(bool& var): var_(var) {}
~Finalizer() {
var_ = true;
}
private:
bool& var_;
} f(finalized_);
auto allRecords = getAllWrittenRecords();
// If no ".data/version" or "version" record in the output model, rewrites version info
if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {

View File

@ -130,7 +130,7 @@ class TORCH_API PyTorchStreamWriter final {
public:
explicit PyTorchStreamWriter(std::string archive_name);
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)>& writer_func);
const std::function<size_t(const void*, size_t)> writer_func);
void setMinVersion(const uint64_t version);

View File

@ -567,6 +567,25 @@ class SerializationMixin(object):
b = torch.load(data)
self.assertTrue(data.was_called('readinto'))
def test_serialization_filelike_exceptions(self):
# Try to serialize to buffers that do not have a write method
# or have a malformed one, and make sure it does not cause an abort
# See https://github.com/pytorch/pytorch/issues/87997
x = torch.rand(10)
with self.assertRaises(AttributeError):
# Tries to serialize str into tensor
torch.save('foo', x)
x.write = "bar"
x.flush = "baz"
with self.assertRaises(TypeError):
# Tries to serialize str into tensor with write property
torch.save('foo', x)
x.write = str.__add__
x.flush = str.__mul__
with self.assertRaises(TypeError):
# Tries to serialize str into tensor with wrong callable write property
torch.save('foo', x)
def test_serialization_storage_slice(self):
# Generated using:

View File

@ -1253,7 +1253,7 @@ void initJITBindings(PyObject* module) {
.def(py::init<std::string>())
.def(py::init([](const py::object& buffer) {
auto writer_func = [=](const void* data, size_t size) {
// Writting an empty file is a noop
// Writing an empty file is a noop
if (size == 0) {
return size;
}

View File

@ -248,7 +248,7 @@ class _opener(object):
class _open_file(_opener):
def __init__(self, name, mode):
super(_open_file, self).__init__(open(name, mode))
super().__init__(open(name, mode))
def __exit__(self, *args):
self.file_like.close()
@ -256,7 +256,7 @@ class _open_file(_opener):
class _open_buffer_reader(_opener):
def __init__(self, buffer):
super(_open_buffer_reader, self).__init__(buffer)
super().__init__(buffer)
_check_seekable(buffer)
@ -279,12 +279,12 @@ def _open_file_like(name_or_buffer, mode):
class _open_zipfile_reader(_opener):
def __init__(self, name_or_buffer) -> None:
super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
class _open_zipfile_writer_file(_opener):
def __init__(self, name) -> None:
super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
super().__init__(torch._C.PyTorchFileWriter(str(name)))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
@ -292,8 +292,13 @@ class _open_zipfile_writer_file(_opener):
class _open_zipfile_writer_buffer(_opener):
def __init__(self, buffer) -> None:
if not callable(getattr(buffer, "write", None)):
msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
if not hasattr(buffer, "write"):
raise AttributeError(msg)
raise TypeError(msg)
self.buffer = buffer
super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
super().__init__(torch._C.PyTorchFileWriter(buffer))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()