[RPC] Add option to make rref.get_type not block. (#50977)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50977

Adds a `blocking` flag that can be set to False to make this API return a `Future` to the type. This is to make this function non-blocking, mostly for a future change that will allow `rref.rpc_async()` to be completely non-blocking (it currently calls and waits for this function that issues an RPC in-line).
ghstack-source-id: 121021433

Test Plan: Modified UT

Reviewed By: mrshenli

Differential Revision: D25944582

fbshipit-source-id: e3b48a52af2d4578551a30ba6838927b489b1c03
This commit is contained in:
Rohan Varma
2021-02-04 20:15:25 -08:00
committed by Facebook GitHub Bot
parent 716a8c2153
commit c3f2f3294e
5 changed files with 123 additions and 30 deletions

View File

@ -7,6 +7,7 @@ import threading
from typing import Generic, TypeVar, Set, Any
import torch
from torch.futures import Future
from torch._C._distributed_rpc import (
PyRRef,
@ -361,17 +362,31 @@ def _to_worker_info(to):
raise ValueError("Cannot get WorkerInfo from name {}".format(to))
def _rref_typeof_on_owner(rref):
return type(rref.local_value())
def _rref_typeof_on_owner(rref, blocking=True):
rref_type = type(rref.local_value())
if blocking:
return rref_type
else:
# Wrap result into a completed Future. This is so that if blocking=`False`
# is specified, we return a future regardless of if this call is on user
# or owner.
future = Future[type]()
future.set_result(rref_type)
return future
def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT):
return rpc_sync(
def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):
fut = rpc_async(
rref.owner(),
_rref_typeof_on_owner,
args=(rref,),
timeout=timeout
)
if blocking:
return fut.wait()
else:
return fut
T = TypeVar("T")