mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] make RRef type annotation available in Python (#33526)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33526 Test Plan: Imported from OSS Differential Revision: D19988848 Pulled By: wanchaol fbshipit-source-id: aeebc946d08b38dac0b656617bf395e86bcea558
This commit is contained in:
committed by
Facebook Github Bot
parent
2448c97a53
commit
d494986171
@ -663,6 +663,37 @@ except ImportError:
|
|||||||
return isinstance(ann, FinalInstance)
|
return isinstance(ann, FinalInstance)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import TypeVar, Generic
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
class RRef(Generic[T]):
|
||||||
|
__slots__ = ['__args__']
|
||||||
|
|
||||||
|
def __init__(self, types):
|
||||||
|
self.__args__ = types
|
||||||
|
|
||||||
|
def is_rref(ann):
|
||||||
|
return getattr(ann, "__origin__", None) is RRef
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
class RRefInstance(object):
|
||||||
|
__slots__ = ['__args__']
|
||||||
|
|
||||||
|
def __init__(self, types):
|
||||||
|
self.__args__ = types
|
||||||
|
|
||||||
|
class RRefCls(object):
|
||||||
|
def __getitem__(self, types):
|
||||||
|
return RRefInstance(types)
|
||||||
|
|
||||||
|
RRef = RRefCls() # noqa: T484
|
||||||
|
|
||||||
|
def is_rref(ann):
|
||||||
|
return isinstance(ann, RRefInstance)
|
||||||
|
|
||||||
|
|
||||||
# allows BroadcastingList instance to be subscriptable
|
# allows BroadcastingList instance to be subscriptable
|
||||||
class BroadcastingListCls(object):
|
class BroadcastingListCls(object):
|
||||||
def __getitem__(self, types):
|
def __getitem__(self, types):
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
#include <torch/csrc/utils/six.h>
|
#include <torch/csrc/utils/six.h>
|
||||||
#ifdef USE_DISTRIBUTED
|
#ifdef USE_DISTRIBUTED
|
||||||
#include <torch/csrc/distributed/rpc/py_rref.h>
|
#include <torch/csrc/distributed/rpc/py_rref.h>
|
||||||
|
#include <torch/csrc/distributed/rpc/rref_impl.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <ATen/core/function_schema.h>
|
#include <ATen/core/function_schema.h>
|
||||||
@ -718,6 +719,15 @@ inline py::object toPyObject(IValue ivalue) {
|
|||||||
py_dict[toPyObject(IValue{pair.key()})] = toPyObject(IValue{pair.value()});
|
py_dict[toPyObject(IValue{pair.key()})] = toPyObject(IValue{pair.value()});
|
||||||
}
|
}
|
||||||
return std::move(py_dict);
|
return std::move(py_dict);
|
||||||
|
} else if (ivalue.isRRef()) {
|
||||||
|
#ifdef USE_DISTRIBUTED
|
||||||
|
auto RRefPtr =
|
||||||
|
c10::dynamic_intrusive_pointer_cast<torch::distributed::rpc::RRef>(
|
||||||
|
std::move(ivalue).toRRef());
|
||||||
|
return py::cast(torch::distributed::rpc::PyRRef(RRefPtr));
|
||||||
|
#else
|
||||||
|
AT_ERROR("RRef is only supported with the distributed package");
|
||||||
|
#endif
|
||||||
} else if (ivalue.isObject()) {
|
} else if (ivalue.isObject()) {
|
||||||
const auto obj = std::move(ivalue).toObject();
|
const auto obj = std::move(ivalue).toObject();
|
||||||
if (obj->type()->is_module()) {
|
if (obj->type()->is_module()) {
|
||||||
|
@ -747,6 +747,9 @@ void initPythonIRBindings(PyObject* module_) {
|
|||||||
.def(py::init([](TypePtr a) { return OptionalType::create(a); }))
|
.def(py::init([](TypePtr a) { return OptionalType::create(a); }))
|
||||||
.def_static("ofTensor", &OptionalType::ofTensor)
|
.def_static("ofTensor", &OptionalType::ofTensor)
|
||||||
.def("getElementType", &OptionalType::getElementType);
|
.def("getElementType", &OptionalType::getElementType);
|
||||||
|
py::class_<RRefType, Type, std::shared_ptr<RRefType>>(m, "RRefType")
|
||||||
|
.def(py::init([](TypePtr a) { return RRefType::create(a); }))
|
||||||
|
.def("getElementType", &RRefType::getElementType);
|
||||||
|
|
||||||
py::class_<ClassType, Type, std::shared_ptr<ClassType>>(m, "ClassType")
|
py::class_<ClassType, Type, std::shared_ptr<ClassType>>(m, "ClassType")
|
||||||
.def(py::init([](const std::string& qualified_name) {
|
.def(py::init([](const std::string& qualified_name) {
|
||||||
|
@ -5,10 +5,10 @@ import re
|
|||||||
import torch
|
import torch
|
||||||
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
|
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
|
||||||
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
|
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
|
||||||
is_optional, _qualified_name, Any
|
is_optional, _qualified_name, Any, RRef, is_rref
|
||||||
from torch._C import TensorType, TupleType, FloatType, IntType, \
|
from torch._C import TensorType, TupleType, FloatType, IntType, \
|
||||||
ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \
|
ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType, \
|
||||||
DeviceObjType
|
DeviceObjType, RRefType
|
||||||
|
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from torch._six import builtins, PY2
|
from torch._six import builtins, PY2
|
||||||
@ -39,6 +39,7 @@ class EvalEnv(object):
|
|||||||
'List': List,
|
'List': List,
|
||||||
'Dict': Dict,
|
'Dict': Dict,
|
||||||
'Optional': Optional,
|
'Optional': Optional,
|
||||||
|
'RRef': RRef,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, rcb):
|
def __init__(self, rcb):
|
||||||
@ -271,6 +272,8 @@ def ann_to_type(ann, resolver=None):
|
|||||||
return OptionalType(ann_to_type(ann.__args__[0]))
|
return OptionalType(ann_to_type(ann.__args__[0]))
|
||||||
else:
|
else:
|
||||||
return OptionalType(ann_to_type(ann.__args__[1]))
|
return OptionalType(ann_to_type(ann.__args__[1]))
|
||||||
|
elif is_rref(ann):
|
||||||
|
return RRefType(ann_to_type(ann.__args__[0]))
|
||||||
elif ann is float:
|
elif ann is float:
|
||||||
return FloatType.get()
|
return FloatType.get()
|
||||||
elif ann is int:
|
elif ann is int:
|
||||||
|
@ -230,3 +230,22 @@ class JitRpcTest(RpcAgentTestFixture):
|
|||||||
module_with_rrefs = MyScriptModuleWithRRefs("worker{}".format(dst_rank))
|
module_with_rrefs = MyScriptModuleWithRRefs("worker{}".format(dst_rank))
|
||||||
res = module_with_rrefs()
|
res = module_with_rrefs()
|
||||||
self.assertEqual(res, torch.ones(2, 2) * 9)
|
self.assertEqual(res, torch.ones(2, 2) * 9)
|
||||||
|
|
||||||
|
@dist_init
|
||||||
|
def test_rref_python_annotation(self):
|
||||||
|
n = self.rank + 1
|
||||||
|
dst_rank = n % self.world_size
|
||||||
|
rref_var = rpc_return_rref("worker{}".format(dst_rank))
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def rref_python_annotation(rref_var):
|
||||||
|
# type: (RRef[Tensor]) -> RRef[Tensor]
|
||||||
|
return rref_var
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def rref_script_annotation(rref_var):
|
||||||
|
# type: (RRef[Tensor]) -> Tensor
|
||||||
|
return rref_python_annotation(rref_var).to_here()
|
||||||
|
|
||||||
|
res = rref_script_annotation(rref_var)
|
||||||
|
self.assertEqual(res, torch.ones(2, 2) + 1)
|
||||||
|
Reference in New Issue
Block a user