[torch.futures] Fix nullptr deref (#85304)

`torch.jit.wait(None)` and `torch.futures.collect_all((None,))` should not crash.

Fixes https://github.com/pytorch/pytorch/issues/85237

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85304
Approved by: https://github.com/kit1980
This commit is contained in:
Nikita Shulga
2022-09-20 01:49:04 +00:00
committed by PyTorch MergeBot
parent 66907e7262
commit c05ca0dbf2
2 changed files with 14 additions and 2 deletions

View File

@ -2127,6 +2127,7 @@ void initJITBindings(PyObject* module) {
});
m.def("wait", [](const std::shared_ptr<PythonFutureWrapper>& fut) {
TORCH_CHECK(fut, "Future can't be None");
return fut->wait();
});
@ -2134,12 +2135,14 @@ void initJITBindings(PyObject* module) {
"_collect_all",
[](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
-> std::shared_ptr<jit::PythonFutureWrapper> {
auto typePtr =
futures.empty() ? AnyType::get() : futures[0]->fut->elementType();
auto typePtr = futures.empty() || futures[0] == nullptr
? AnyType::get()
: futures[0]->fut->elementType();
c10::List<c10::intrusive_ptr<c10::ivalue::Future>> asList(
c10::FutureType::create(typePtr));
asList.reserve(futures.size());
for (const auto& f : futures) {
TORCH_CHECK(f, "Future can't be None");
asList.push_back(f->fut);
}
return std::make_shared<jit::PythonFutureWrapper>(