fix torch.futures docstring examples (#61029)

Summary:
Trying to run the doctests for the complete documentation hangs if it reaches the examples of `torch.futures`. It turns out to be only syntax errors, which are normally just reported. My guess is that `doctest` probably doesn't work well for failures within async stuff.

Anyway, while debugging this, I fixed the syntax.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61029

Reviewed By: mruberry

Differential Revision: D29571923

Pulled By: mrshenli

fbshipit-source-id: bb8112be5302c6ec43151590b438b195a8f30a06
This commit is contained in:
Philip Meier
2021-07-07 11:46:02 -07:00
committed by Facebook GitHub Bot
parent 376dc500a9
commit 1262b2c4c6
2 changed files with 22 additions and 43 deletions

View File

@ -367,6 +367,7 @@ import sphinx.ext.doctest
doctest_test_doctest_blocks = ''
doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
doctest_global_setup = '''
import torch
try:
import torchvision
except ImportError:

View File

@ -147,23 +147,18 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
on those futures independently.
Example::
>>> import torch
>>>
>>> def callback(fut):
>>> print(f"RPC return value is {fut.wait()}.")
>>>
... print(f"RPC return value is {fut.wait()}.")
>>> fut = torch.futures.Future()
>>> # The inserted callback will print the return value when
>>> # receiving the response from "worker1"
>>> cb_fut = fut.then(callback)
>>> chain_cb_fut = cb_fut.then(
>>> lambda x : print(f"Chained cb done. {x.wait()}")
>>> )
... lambda x : print(f"Chained cb done. {x.wait()}")
... )
>>> fut.set_result(5)
>>>
>>> # Outputs are:
>>> # RPC return value is 5.
>>> # Chained cb done. None
RPC return value is 5.
Chained cb done. None
"""
return cast(Future[S], super().then(callback))
@ -200,19 +195,14 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
for handling completion/waiting on those futures independently.
Example::
>>> import torch
>>>
>>> def callback(fut):
>>> print(f"This will run after the future has finished.")
>>> print(fut.wait())
>>>
... print(f"This will run after the future has finished.")
... print(fut.wait())
>>> fut = torch.futures.Future()
>>> fut.add_done_callback(callback)
>>> fut.set_result(5)
>>>
>>> # Outputs are:
>>> This will run after the future has finished.
>>> 5
This will run after the future has finished.
5
"""
super().add_done_callback(callback)
@ -239,20 +229,17 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
Example::
>>> import threading
>>> import time
>>> import torch
>>>
>>> def slow_set_future(fut, value):
>>> time.sleep(0.5)
>>> fut.set_result(value)
>>>
... time.sleep(0.5)
... fut.set_result(value)
>>> fut = torch.futures.Future()
>>> t = threading.Thread(
>>> target=slow_set_future,
>>> args=(fut, torch.ones(2) * 3)
>>> )
... target=slow_set_future,
... args=(fut, torch.ones(2) * 3)
... )
>>> t.start()
>>>
>>> print(fut.wait()) # tensor([3., 3.])
>>> print(fut.wait())
tensor([3., 3.])
>>> t.join()
"""
super().set_result(result)
@ -268,15 +255,12 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
result (BaseException): the exception for this ``Future``.
Example::
>>> import torch
>>>
>>> fut = torch.futures.Future()
>>> fut.set_exception(ValueError("foo"))
>>> fut.wait()
>>>
>>> # Output:
>>> # This will run after the future has finished.
>>> ValueError: foo
Traceback (most recent call last):
...
ValueError: foo
"""
assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."
@ -301,22 +285,16 @@ def collect_all(futures: List[Future]) -> Future[List[Future]]:
in Futures.
Example::
>>> import torch
>>>
>>> fut0 = torch.futures.Future()
>>> fut1 = torch.futures.Future()
>>>
>>> fut = torch.futures.collect_all([fut0, fut1])
>>>
>>> fut0.set_result(0)
>>> fut1.set_result(1)
>>>
>>> fut_list = fut.wait()
>>> print(f"fut0 result = {fut_list[0].wait()}")
fut0 result = 0
>>> print(f"fut1 result = {fut_list[1].wait()}")
>>> # outputs:
>>> # fut0 result = 0
>>> # fut1 result = 1
fut1 result = 1
"""
return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))