mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Throw an exception in the constructor of torchscript serialization to avoid double-exception (#44266)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44266 If PyTorchStreamWriter is writing to a file in a non-existing path, it throws an exception. In unwinding the destructor calls writeEndOfFile() and throws again. To avoid this double-exception, a check and throw is added in the constructor. In such case the destructor will not be called and the exception can go through the unwinding. Test Plan: python test/test_jit.py TestSaveLoad.test_save_nonexit_file Reviewed By: dreiss Differential Revision: D23560770 Pulled By: iseeyuan fbshipit-source-id: 51b24403500bdab3578c7fd5e017780467a5d06a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9c1a41b724
commit
b553c06abb
@ -306,6 +306,7 @@ void PyTorchStreamWriter::setup(const string& file_name) {
|
||||
file_name,
|
||||
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
|
||||
valid("opening archive ", file_name.c_str());
|
||||
TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
|
||||
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
|
||||
file_stream_.write(static_cast<const char*>(buf), nbytes);
|
||||
return !file_stream_ ? 0 : nbytes;
|
||||
|
@ -938,3 +938,12 @@ class TestSaveLoad(JitTestCase):
|
||||
|
||||
x = torch.tensor([1., 2., 3., 4.])
|
||||
self.assertTrue(torch.equal(m(x), m2(x)))
|
||||
|
||||
def test_save_nonexit_file(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return 2 * x
|
||||
|
||||
script_module = torch.jit.script(Foo())
|
||||
with self.assertRaises(RuntimeError):
|
||||
script_module.save("NonExist/path/test.pt")
|
||||
|
Reference in New Issue
Block a user