Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix PyTorchStreamWriter exception handling (#88128)
Avoid a double exception in the destructor when attempting to serialize to a Python object that does not have a `write` method. Use a `Finalizer` class in `PyTorchStreamWriter::writeEndOfFile()` to always set the `finalized_` property even if an exception occurs (as there isn't much one can do at that point). Add an explicit check for the attribute to `_open_zipfile_writer_buffer` and add unit tests. Modernize the code a bit by using the Python 3 `super()` call.
Fixes https://github.com/pytorch/pytorch/issues/87997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88128
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: ea8a5b09a9
Commit: caaf37a111
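
For context, here is a minimal C++ sketch of the failure mode the commit message describes, using a hypothetical Writer class rather than the real PyTorchStreamWriter: if writeEndOfFile() throws before it can mark the file as finalized, the destructor retries the same failing call while the first exception is still unwinding, and the process dies in std::terminate() instead of surfacing a catchable error.

#include <stdexcept>

// Hypothetical stand-in for the real writer; illustrative only.
struct Writer {
  bool finalized_ = false;

  void writeEndOfFile() {
    // Pretend the underlying buffer rejects the write (e.g. it has no usable `write`).
    throw std::runtime_error("write failed");
    finalized_ = true;  // never reached, so the destructor will retry
  }

  ~Writer() {
    if (!finalized_) {
      // Destructors are noexcept by default: letting writeEndOfFile() throw here,
      // while another exception is already unwinding, calls std::terminate().
      writeEndOfFile();
    }
  }
};

int main() {
  try {
    Writer w;
    w.writeEndOfFile();  // throws; ~Writer() then throws again, aborting the process
  } catch (const std::exception&) {
    // Never reached: the program terminates before this handler runs.
  }
  return 0;
}
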
@@ -338,8 +338,7 @@ PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
 }
 
 PyTorchStreamWriter::PyTorchStreamWriter(
-    // NOLINTNEXTLINE(modernize-pass-by-value)
-    const std::function<size_t(const void*, size_t)>& writer_func)
+    const std::function<size_t(const void*, size_t)> writer_func)
     : archive_name_("archive"),
       writer_func_(writer_func) {
   setup(archive_name_);
@@ -416,6 +415,21 @@ void PyTorchStreamWriter::writeRecord(
 }
 
 void PyTorchStreamWriter::writeEndOfFile() {
+  // Ensures that finalized_ is set to true even if an
+  // exception is raised during the method call.
+  // I.e. even a partial call to writeEndOfFile() should mark the
+  // file as finalized, otherwise a double exception raised from the
+  // destructor would result in `std::terminate()`.
+  // See https://github.com/pytorch/pytorch/issues/87997/
+  struct Finalizer {
+    Finalizer(bool& var): var_(var) {}
+    ~Finalizer() {
+      var_ = true;
+    }
+   private:
+    bool& var_;
+  } f(finalized_);
+
   auto allRecords = getAllWrittenRecords();
   // If no ".data/version" or "version" record in the output model, rewrites version info
   if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
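
The Finalizer added above is a small scope guard: its destructor runs whether writeEndOfFile() returns normally or unwinds, so finalized_ is guaranteed to end up true and the writer's destructor will not raise a second exception on top of the first. A stripped-down sketch of the same RAII pattern (names here are illustrative, not the PyTorch implementation):

#include <stdexcept>

// Scope guard that sets a flag when it leaves scope, exception or not.
struct SetFlagOnExit {
  explicit SetFlagOnExit(bool& flag) : flag_(flag) {}
  ~SetFlagOnExit() { flag_ = true; }
 private:
  bool& flag_;
};

bool finalized = false;

void writeEndOfFile() {
  SetFlagOnExit guard(finalized);            // constructed first, destroyed on any exit path
  throw std::runtime_error("flush failed");  // simulate a failing write
  // Unreachable, yet ~SetFlagOnExit() still runs during unwinding and sets `finalized`.
}

int main() {
  try {
    writeEndOfFile();
  } catch (const std::exception&) {
    // The exception stays catchable, and `finalized` is now true, so a destructor
    // that checks the flag will not retry the failing call.
  }
  return finalized ? 0 : 1;
}
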
@@ -130,7 +130,7 @@ class TORCH_API PyTorchStreamWriter final {
  public:
  explicit PyTorchStreamWriter(std::string archive_name);
  explicit PyTorchStreamWriter(
-      const std::function<size_t(const void*, size_t)>& writer_func);
+      const std::function<size_t(const void*, size_t)> writer_func);
 
  void setMinVersion(const uint64_t version);
 
@@ -567,6 +567,25 @@ class SerializationMixin(object):
         b = torch.load(data)
         self.assertTrue(data.was_called('readinto'))
 
+    def test_serialization_filelike_exceptions(self):
+        # Try to serialize to buffers that do not have a write method
+        # or have a malformed one, and make sure it does not cause an abort
+        # See https://github.com/pytorch/pytorch/issues/87997
+        x = torch.rand(10)
+        with self.assertRaises(AttributeError):
+            # Tries to serialize str into tensor
+            torch.save('foo', x)
+        x.write = "bar"
+        x.flush = "baz"
+        with self.assertRaises(TypeError):
+            # Tries to serialize str into tensor with write property
+            torch.save('foo', x)
+        x.write = str.__add__
+        x.flush = str.__mul__
+        with self.assertRaises(TypeError):
+            # Tries to serialize str into tensor with wrong callable write property
+            torch.save('foo', x)
+
 
     def test_serialization_storage_slice(self):
         # Generated using:
@@ -1253,7 +1253,7 @@ void initJITBindings(PyObject* module) {
       .def(py::init<std::string>())
       .def(py::init([](const py::object& buffer) {
         auto writer_func = [=](const void* data, size_t size) {
-          // Writting an empty file is a noop
+          // Writing an empty file is a noop
           if (size == 0) {
             return size;
           }
@@ -248,7 +248,7 @@ class _opener(object):
 
 class _open_file(_opener):
     def __init__(self, name, mode):
-        super(_open_file, self).__init__(open(name, mode))
+        super().__init__(open(name, mode))
 
     def __exit__(self, *args):
         self.file_like.close()
@@ -256,7 +256,7 @@ class _open_file(_opener):
 
 class _open_buffer_reader(_opener):
     def __init__(self, buffer):
-        super(_open_buffer_reader, self).__init__(buffer)
+        super().__init__(buffer)
         _check_seekable(buffer)
 
 
@@ -279,12 +279,12 @@ def _open_file_like(name_or_buffer, mode):
 
 class _open_zipfile_reader(_opener):
     def __init__(self, name_or_buffer) -> None:
-        super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
+        super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
 
 
 class _open_zipfile_writer_file(_opener):
     def __init__(self, name) -> None:
-        super(_open_zipfile_writer_file, self).__init__(torch._C.PyTorchFileWriter(str(name)))
+        super().__init__(torch._C.PyTorchFileWriter(str(name)))
 
     def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()
@@ -292,8 +292,13 @@ class _open_zipfile_writer_file(_opener):
 
 class _open_zipfile_writer_buffer(_opener):
    def __init__(self, buffer) -> None:
+        if not callable(getattr(buffer, "write", None)):
+            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
+            if not hasattr(buffer, "write"):
+                raise AttributeError(msg)
+            raise TypeError(msg)
         self.buffer = buffer
-        super(_open_zipfile_writer_buffer, self).__init__(torch._C.PyTorchFileWriter(buffer))
+        super().__init__(torch._C.PyTorchFileWriter(buffer))
 
     def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()