mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torch.futures] Fix nullptr deref (#85304)
`torch.jit.wait(None)` and `torch.futures.collect_all((None,))` should not crash. Fixes https://github.com/pytorch/pytorch/issues/85237 Pull Request resolved: https://github.com/pytorch/pytorch/pull/85304 Approved by: https://github.com/kit1980
This commit is contained in:
committed by
PyTorch MergeBot
parent
66907e7262
commit
c05ca0dbf2
@ -327,5 +327,14 @@ class TestFuture(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected error"):
|
||||
torch.futures.wait_all([fut3, fut2])
|
||||
|
||||
def test_wait_none(self):
|
||||
fut1 = Future[int]()
|
||||
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
|
||||
torch.jit.wait(None)
|
||||
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
|
||||
torch.futures.wait_all((None,)) # type: ignore[arg-type]
|
||||
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
|
||||
torch.futures.collect_all((fut1, None,)) # type: ignore[arg-type]
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
@ -2127,6 +2127,7 @@ void initJITBindings(PyObject* module) {
|
||||
});
|
||||
|
||||
m.def("wait", [](const std::shared_ptr<PythonFutureWrapper>& fut) {
|
||||
TORCH_CHECK(fut, "Future can't be None");
|
||||
return fut->wait();
|
||||
});
|
||||
|
||||
@ -2134,12 +2135,14 @@ void initJITBindings(PyObject* module) {
|
||||
"_collect_all",
|
||||
[](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
|
||||
-> std::shared_ptr<jit::PythonFutureWrapper> {
|
||||
auto typePtr =
|
||||
futures.empty() ? AnyType::get() : futures[0]->fut->elementType();
|
||||
auto typePtr = futures.empty() || futures[0] == nullptr
|
||||
? AnyType::get()
|
||||
: futures[0]->fut->elementType();
|
||||
c10::List<c10::intrusive_ptr<c10::ivalue::Future>> asList(
|
||||
c10::FutureType::create(typePtr));
|
||||
asList.reserve(futures.size());
|
||||
for (const auto& f : futures) {
|
||||
TORCH_CHECK(f, "Future can't be None");
|
||||
asList.push_back(f->fut);
|
||||
}
|
||||
return std::make_shared<jit::PythonFutureWrapper>(
|
||||
|
Reference in New Issue
Block a user