[jit] make RRef type annotation available in Python (#33526)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33526

Test Plan: Imported from OSS

Differential Revision: D19988848

Pulled By: wanchaol

fbshipit-source-id: aeebc946d08b38dac0b656617bf395e86bcea558
This commit is contained in:
Wanchao Liang
2020-02-26 18:40:15 -08:00
committed by Facebook Github Bot
parent 2448c97a53
commit d494986171
5 changed files with 68 additions and 2 deletions

View File

@ -663,6 +663,37 @@ except ImportError:
return isinstance(ann, FinalInstance)
# Expose an `RRef[T]` annotation so TorchScript signatures can refer to
# distributed RRefs. When a full `typing` module is available we build the
# annotation on top of Generic; otherwise we fall back to a minimal
# subscriptable stand-in that just records its type arguments.
try:
    from typing import TypeVar, Generic

    T = TypeVar('T')

    class RRef(Generic[T]):
        __slots__ = ['__args__']

        def __init__(self, types):
            self.__args__ = types

    def is_rref(ann):
        # Subscripting a Generic subclass (e.g. RRef[int]) produces an
        # alias object whose __origin__ points back at RRef itself.
        origin = getattr(ann, "__origin__", None)
        return origin is RRef
except ImportError:
    # Fallback: make RRef[...] return an instance that carries __args__,
    # mirroring what the typing-based version exposes.
    class RRefInstance(object):
        __slots__ = ['__args__']

        def __init__(self, types):
            self.__args__ = types

    class RRefCls(object):
        def __getitem__(self, types):
            return RRefInstance(types)

    RRef = RRefCls()  # noqa: T484

    def is_rref(ann):
        return isinstance(ann, RRefInstance)
# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls(object):
def __getitem__(self, types):

View File

@ -23,6 +23,7 @@
#include <torch/csrc/utils/six.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#endif
#include <ATen/core/function_schema.h>
@ -718,6 +719,15 @@ inline py::object toPyObject(IValue ivalue) {
py_dict[toPyObject(IValue{pair.key()})] = toPyObject(IValue{pair.value()});
}
return std::move(py_dict);
} else if (ivalue.isRRef()) {
#ifdef USE_DISTRIBUTED
auto RRefPtr =
c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
std::move(ivalue).toRRef());
return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
#else
AT_ERROR("RRef is only supported with the distributed package");
#endif
} else if (ivalue.isObject()) {
const auto obj = std::move(ivalue).toObject();
if (obj->type()->is_module()) {

View File

@ -747,6 +747,9 @@ void initPythonIRBindings(PyObject* module_) {
.def(py::init([](TypePtr a) { return OptionalType::create(a); }))
.def_static("ofTensor", &OptionalType::ofTensor)
.def("getElementType", &OptionalType::getElementType);
py::class_<RRefType, Type, std::shared_ptr<RRefType>>(m, "RRefType")
.def(py::init([](TypePtr a) { return RRefType::create(a); }))
.def("getElementType", &RRefType::getElementType);
py::class_<ClassType, Type, std::shared_ptr<ClassType>>(m, "ClassType")
.def(py::init([](const std::string& qualified_name) {

View File

@ -5,10 +5,10 @@ import re
import torch
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
is_optional, _qualified_name, Any
is_optional, _qualified_name, Any, RRef, is_rref
from torch._C import TensorType, TupleType, FloatType, IntType, \
ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \
DeviceObjType
DeviceObjType, RRefType
from textwrap import dedent
from torch._six import builtins, PY2
@ -39,6 +39,7 @@ class EvalEnv(object):
'List': List,
'Dict': Dict,
'Optional': Optional,
'RRef': RRef,
}
def __init__(self, rcb):
@ -271,6 +272,8 @@ def ann_to_type(ann, resolver=None):
return OptionalType(ann_to_type(ann.__args__[0]))
else:
return OptionalType(ann_to_type(ann.__args__[1]))
elif is_rref(ann):
return RRefType(ann_to_type(ann.__args__[0]))
elif ann is float:
return FloatType.get()
elif ann is int:

View File

@ -230,3 +230,22 @@ class JitRpcTest(RpcAgentTestFixture):
module_with_rrefs = MyScriptModuleWithRRefs("worker{}".format(dst_rank))
res = module_with_rrefs()
self.assertEqual(res, torch.ones(2, 2) * 9)
@dist_init
def test_rref_python_annotation(self):
    # Target the next worker in the ring so every rank exercises a
    # remote peer rather than itself.
    dst_rank = (self.rank + 1) % self.world_size
    rref_var = rpc_return_rref("worker{}".format(dst_rank))

    # An ignored function can accept/return RRef[Tensor] from Python;
    # TorchScript only needs its (type-comment) signature.
    @torch.jit.ignore
    def rref_python_annotation(rref_var):
        # type: (RRef[Tensor]) -> RRef[Tensor]
        return rref_var

    # The scripted function passes the RRef through the ignored helper
    # and materializes the value locally via to_here().
    @torch.jit.script
    def rref_script_annotation(rref_var):
        # type: (RRef[Tensor]) -> Tensor
        return rref_python_annotation(rref_var).to_here()

    res = rref_script_annotation(rref_var)
    self.assertEqual(res, torch.ones(2, 2) + 1)