[torch.futures] Fix nullptr deref (#85304)

`torch.jit.wait(None)` and `torch.futures.collect_all((None,))` should not crash.

Fixes https://github.com/pytorch/pytorch/issues/85237

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85304
Approved by: https://github.com/kit1980
This commit is contained in:
Nikita Shulga
2022-09-20 01:49:04 +00:00
committed by PyTorch MergeBot
parent 66907e7262
commit c05ca0dbf2
2 changed files with 14 additions and 2 deletions

View File

@ -327,5 +327,14 @@ class TestFuture(TestCase):
with self.assertRaisesRegex(RuntimeError, "Expected error"):
torch.futures.wait_all([fut3, fut2])
def test_wait_none(self):
fut1 = Future[int]()
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
torch.jit.wait(None)
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
torch.futures.wait_all((None,)) # type: ignore[arg-type]
with self.assertRaisesRegex(RuntimeError, "Future can't be None"):
torch.futures.collect_all((fut1, None,)) # type: ignore[arg-type]
if __name__ == '__main__':
run_tests()