mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix torch.futures
docstring examples (#61029)
Summary: Trying to run the doctests for the complete documentation hangs if it reaches the examples of `torch.futures`. It turns out to be only syntax errors, which are normally just reported. My guess is that `doctest` probably doesn't work well for failures within async stuff. Anyway, while debugging this, I fixed the syntax. Pull Request resolved: https://github.com/pytorch/pytorch/pull/61029 Reviewed By: mruberry Differential Revision: D29571923 Pulled By: mrshenli fbshipit-source-id: bb8112be5302c6ec43151590b438b195a8f30a06
This commit is contained in:
committed by
Facebook GitHub Bot
parent
376dc500a9
commit
1262b2c4c6
@ -367,6 +367,7 @@ import sphinx.ext.doctest
|
||||
doctest_test_doctest_blocks = ''
|
||||
doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS
|
||||
doctest_global_setup = '''
|
||||
import torch
|
||||
try:
|
||||
import torchvision
|
||||
except ImportError:
|
||||
|
@ -147,23 +147,18 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
on those futures independently.
|
||||
|
||||
Example::
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> def callback(fut):
|
||||
>>> print(f"RPC return value is {fut.wait()}.")
|
||||
>>>
|
||||
... print(f"RPC return value is {fut.wait()}.")
|
||||
>>> fut = torch.futures.Future()
|
||||
>>> # The inserted callback will print the return value when
|
||||
>>> # receiving the response from "worker1"
|
||||
>>> cb_fut = fut.then(callback)
|
||||
>>> chain_cb_fut = cb_fut.then(
|
||||
>>> lambda x : print(f"Chained cb done. {x.wait()}")
|
||||
>>> )
|
||||
... lambda x : print(f"Chained cb done. {x.wait()}")
|
||||
... )
|
||||
>>> fut.set_result(5)
|
||||
>>>
|
||||
>>> # Outputs are:
|
||||
>>> # RPC return value is 5.
|
||||
>>> # Chained cb done. None
|
||||
RPC return value is 5.
|
||||
Chained cb done. None
|
||||
"""
|
||||
return cast(Future[S], super().then(callback))
|
||||
|
||||
@ -200,19 +195,14 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
for handling completion/waiting on those futures independently.
|
||||
|
||||
Example::
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> def callback(fut):
|
||||
>>> print(f"This will run after the future has finished.")
|
||||
>>> print(fut.wait())
|
||||
>>>
|
||||
... print(f"This will run after the future has finished.")
|
||||
... print(fut.wait())
|
||||
>>> fut = torch.futures.Future()
|
||||
>>> fut.add_done_callback(callback)
|
||||
>>> fut.set_result(5)
|
||||
>>>
|
||||
>>> # Outputs are:
|
||||
>>> This will run after the future has finished.
|
||||
>>> 5
|
||||
This will run after the future has finished.
|
||||
5
|
||||
"""
|
||||
super().add_done_callback(callback)
|
||||
|
||||
@ -239,20 +229,17 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
Example::
|
||||
>>> import threading
|
||||
>>> import time
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> def slow_set_future(fut, value):
|
||||
>>> time.sleep(0.5)
|
||||
>>> fut.set_result(value)
|
||||
>>>
|
||||
... time.sleep(0.5)
|
||||
... fut.set_result(value)
|
||||
>>> fut = torch.futures.Future()
|
||||
>>> t = threading.Thread(
|
||||
>>> target=slow_set_future,
|
||||
>>> args=(fut, torch.ones(2) * 3)
|
||||
>>> )
|
||||
... target=slow_set_future,
|
||||
... args=(fut, torch.ones(2) * 3)
|
||||
... )
|
||||
>>> t.start()
|
||||
>>>
|
||||
>>> print(fut.wait()) # tensor([3., 3.])
|
||||
>>> print(fut.wait())
|
||||
tensor([3., 3.])
|
||||
>>> t.join()
|
||||
"""
|
||||
super().set_result(result)
|
||||
@ -268,15 +255,12 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
result (BaseException): the exception for this ``Future``.
|
||||
|
||||
Example::
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> fut = torch.futures.Future()
|
||||
>>> fut.set_exception(ValueError("foo"))
|
||||
>>> fut.wait()
|
||||
>>>
|
||||
>>> # Output:
|
||||
>>> # This will run after the future has finished.
|
||||
>>> ValueError: foo
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: foo
|
||||
"""
|
||||
assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."
|
||||
|
||||
@ -301,22 +285,16 @@ def collect_all(futures: List[Future]) -> Future[List[Future]]:
|
||||
in Futures.
|
||||
|
||||
Example::
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> fut0 = torch.futures.Future()
|
||||
>>> fut1 = torch.futures.Future()
|
||||
>>>
|
||||
>>> fut = torch.futures.collect_all([fut0, fut1])
|
||||
>>>
|
||||
>>> fut0.set_result(0)
|
||||
>>> fut1.set_result(1)
|
||||
>>>
|
||||
>>> fut_list = fut.wait()
|
||||
>>> print(f"fut0 result = {fut_list[0].wait()}")
|
||||
fut0 result = 0
|
||||
>>> print(f"fut1 result = {fut_list[1].wait()}")
|
||||
>>> # outputs:
|
||||
>>> # fut0 result = 0
|
||||
>>> # fut1 result = 1
|
||||
fut1 result = 1
|
||||
"""
|
||||
return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))
|
||||
|
||||
|
Reference in New Issue
Block a user