Throw an exception in the constructor of torchscript serialization to avoid double-exception (#44266)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44266

If PyTorchStreamWriter is writing to a file in a non-existing path, it throws an exception. In unwinding the destructor calls writeEndOfFile() and throws again. To avoid this double-exception, a check and throw is added in the constructor. In such case the destructor will not be called and the exception can go through the unwinding.

Test Plan: python test/test_jit.py TestSaveLoad.test_save_nonexit_file

Reviewed By: dreiss

Differential Revision: D23560770

Pulled By: iseeyuan

fbshipit-source-id: 51b24403500bdab3578c7fd5e017780467a5d06a
This commit is contained in:
Martin Yuan
2020-10-28 22:36:23 -07:00
committed by Facebook GitHub Bot
parent 9c1a41b724
commit b553c06abb
2 changed files with 10 additions and 0 deletions

View File

@ -306,6 +306,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
file_name,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
valid("opening archive ", file_name.c_str());
TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
file_stream_.write(static_cast<const char*>(buf), nbytes);
return !file_stream_ ? 0 : nbytes;

View File

@ -938,3 +938,12 @@ class TestSaveLoad(JitTestCase):
x = torch.tensor([1., 2., 3., 4.])
self.assertTrue(torch.equal(m(x), m2(x)))
def test_save_nonexit_file(self):
class Foo(torch.nn.Module):
def forward(self, x):
return 2 * x
script_module = torch.jit.script(Foo())
with self.assertRaises(RuntimeError):
script_module.save("NonExist/path/test.pt")