mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] make RRef type annotation available in Python (#33526)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33526 Test Plan: Imported from OSS Differential Revision: D19988848 Pulled By: wanchaol fbshipit-source-id: aeebc946d08b38dac0b656617bf395e86bcea558
This commit is contained in:
committed by
Facebook Github Bot
parent
2448c97a53
commit
d494986171
@ -663,6 +663,37 @@ except ImportError:
|
||||
return isinstance(ann, FinalInstance)
|
||||
|
||||
|
||||
try:
|
||||
from typing import TypeVar, Generic
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
class RRef(Generic[T]):
|
||||
__slots__ = ['__args__']
|
||||
|
||||
def __init__(self, types):
|
||||
self.__args__ = types
|
||||
|
||||
def is_rref(ann):
|
||||
return getattr(ann, "__origin__", None) is RRef
|
||||
|
||||
except ImportError:
|
||||
class RRefInstance(object):
|
||||
__slots__ = ['__args__']
|
||||
|
||||
def __init__(self, types):
|
||||
self.__args__ = types
|
||||
|
||||
class RRefCls(object):
|
||||
def __getitem__(self, types):
|
||||
return RRefInstance(types)
|
||||
|
||||
RRef = RRefCls() # noqa: T484
|
||||
|
||||
def is_rref(ann):
|
||||
return isinstance(ann, RRefInstance)
|
||||
|
||||
|
||||
# allows BroadcastingList instance to be subscriptable
|
||||
class BroadcastingListCls(object):
|
||||
def __getitem__(self, types):
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include <torch/csrc/utils/six.h>
|
||||
#ifdef USE_DISTRIBUTED
|
||||
#include <torch/csrc/distributed/rpc/py_rref.h>
|
||||
#include <torch/csrc/distributed/rpc/rref_impl.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/core/function_schema.h>
|
||||
@ -718,6 +719,15 @@ inline py::object toPyObject(IValue ivalue) {
|
||||
py_dict[toPyObject(IValue{pair.key()})] = toPyObject(IValue{pair.value()});
|
||||
}
|
||||
return std::move(py_dict);
|
||||
} else if (ivalue.isRRef()) {
|
||||
#ifdef USE_DISTRIBUTED
|
||||
auto RRefPtr =
|
||||
c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
|
||||
std::move(ivalue).toRRef());
|
||||
return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
|
||||
#else
|
||||
AT_ERROR("RRef is only supported with the distributed package");
|
||||
#endif
|
||||
} else if (ivalue.isObject()) {
|
||||
const auto obj = std::move(ivalue).toObject();
|
||||
if (obj->type()->is_module()) {
|
||||
|
@ -747,6 +747,9 @@ void initPythonIRBindings(PyObject* module_) {
|
||||
.def(py::init([](TypePtr a) { return OptionalType::create(a); }))
|
||||
.def_static("ofTensor", &OptionalType::ofTensor)
|
||||
.def("getElementType", &OptionalType::getElementType);
|
||||
py::class_<RRefType, Type, std::shared_ptr<RRefType>>(m, "RRefType")
|
||||
.def(py::init([](TypePtr a) { return RRefType::create(a); }))
|
||||
.def("getElementType", &RRefType::getElementType);
|
||||
|
||||
py::class_<ClassType, Type, std::shared_ptr<ClassType>>(m, "ClassType")
|
||||
.def(py::init([](const std::string& qualified_name) {
|
||||
|
@ -5,10 +5,10 @@ import re
|
||||
import torch
|
||||
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
|
||||
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
|
||||
is_optional, _qualified_name, Any
|
||||
is_optional, _qualified_name, Any, RRef, is_rref
|
||||
from torch._C import TensorType, TupleType, FloatType, IntType, \
|
||||
ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \
|
||||
DeviceObjType
|
||||
DeviceObjType, RRefType
|
||||
|
||||
from textwrap import dedent
|
||||
from torch._six import builtins, PY2
|
||||
@ -39,6 +39,7 @@ class EvalEnv(object):
|
||||
'List': List,
|
||||
'Dict': Dict,
|
||||
'Optional': Optional,
|
||||
'RRef': RRef,
|
||||
}
|
||||
|
||||
def __init__(self, rcb):
|
||||
@ -271,6 +272,8 @@ def ann_to_type(ann, resolver=None):
|
||||
return OptionalType(ann_to_type(ann.__args__[0]))
|
||||
else:
|
||||
return OptionalType(ann_to_type(ann.__args__[1]))
|
||||
elif is_rref(ann):
|
||||
return RRefType(ann_to_type(ann.__args__[0]))
|
||||
elif ann is float:
|
||||
return FloatType.get()
|
||||
elif ann is int:
|
||||
|
@ -230,3 +230,22 @@ class JitRpcTest(RpcAgentTestFixture):
|
||||
module_with_rrefs = MyScriptModuleWithRRefs("worker{}".format(dst_rank))
|
||||
res = module_with_rrefs()
|
||||
self.assertEqual(res, torch.ones(2, 2) * 9)
|
||||
|
||||
@dist_init
|
||||
def test_rref_python_annotation(self):
|
||||
n = self.rank + 1
|
||||
dst_rank = n % self.world_size
|
||||
rref_var = rpc_return_rref("worker{}".format(dst_rank))
|
||||
|
||||
@torch.jit.ignore
|
||||
def rref_python_annotation(rref_var):
|
||||
# type: (RRef[Tensor]) -> RRef[Tensor]
|
||||
return rref_var
|
||||
|
||||
@torch.jit.script
|
||||
def rref_script_annotation(rref_var):
|
||||
# type: (RRef[Tensor]) -> Tensor
|
||||
return rref_python_annotation(rref_var).to_here()
|
||||
|
||||
res = rref_script_annotation(rref_var)
|
||||
self.assertEqual(res, torch.ones(2, 2) + 1)
|
||||
|
Reference in New Issue
Block a user