mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[RPC] Add option to make rref.get_type not block. (#50977)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50977 Adds a `blocking` flag that can be set to False to make this API return a `Future` to the type. This is to make this function non-blocking, mostly for a future change that will allow `rref.rpc_async()` to be completely non-blocking (it currently calls and waits for this function that issues an RPC in-line). ghstack-source-id: 121021433 Test Plan: Modified UT Reviewed By: mrshenli Differential Revision: D25944582 fbshipit-source-id: e3b48a52af2d4578551a30ba6838927b489b1c03
This commit is contained in:
committed by
Facebook GitHub Bot
parent
716a8c2153
commit
c3f2f3294e
@ -7,6 +7,7 @@ import threading
|
||||
from typing import Generic, TypeVar, Set, Any
|
||||
|
||||
import torch
|
||||
from torch.futures import Future
|
||||
|
||||
from torch._C._distributed_rpc import (
|
||||
PyRRef,
|
||||
@ -361,17 +362,31 @@ def _to_worker_info(to):
|
||||
raise ValueError("Cannot get WorkerInfo from name {}".format(to))
|
||||
|
||||
|
||||
def _rref_typeof_on_owner(rref):
|
||||
return type(rref.local_value())
|
||||
def _rref_typeof_on_owner(rref, blocking=True):
|
||||
rref_type = type(rref.local_value())
|
||||
if blocking:
|
||||
return rref_type
|
||||
else:
|
||||
# Wrap result into a completed Future. This is so that if blocking=`False`
|
||||
# is specified, we return a future regardless of if this call is on user
|
||||
# or owner.
|
||||
future = Future[type]()
|
||||
future.set_result(rref_type)
|
||||
return future
|
||||
|
||||
|
||||
def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT):
|
||||
return rpc_sync(
|
||||
def _rref_typeof_on_user(rref, timeout=UNSET_RPC_TIMEOUT, blocking=True):
|
||||
fut = rpc_async(
|
||||
rref.owner(),
|
||||
_rref_typeof_on_owner,
|
||||
args=(rref,),
|
||||
timeout=timeout
|
||||
)
|
||||
if blocking:
|
||||
return fut.wait()
|
||||
else:
|
||||
return fut
|
||||
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
Reference in New Issue
Block a user