[jit] Initial use RRef in TorchScript (#33190)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33190

This enables the initial RRef type to be used inside TorchScript: a user
can pass a Python RRef into a TorchScript function and call to_here()
inside it (see the sketch below). Specifically, this PR:

- Add RRef schema type parsing
- Add Python interop for RRef, so a Python RRef can be passed into JIT
- Register the to_here op in register_distributed_ops

More support for RRef in TorchScript will be added in future PRs.
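
For illustration, a minimal sketch of the user-facing pattern this enables
(the RPC setup and worker name are assumptions for the example, not part of
this PR; compare the tests added at the bottom of this diff):

    import torch
    import torch.distributed.rpc as rpc

    # Assumes rpc.init_rpc(...) has already been called on this worker
    # and that a peer named "worker1" exists (the name is illustrative).
    rref = rpc.remote("worker1", torch.add, args=(torch.ones(2, 2), 1))

    @torch.jit.script
    def fetch(rref_var):
        # type: (RRef[Tensor]) -> Tensor
        # to_here() returns the value directly on the owner, or fetches
        # it over RPC from the owner on any other worker.
        return rref_var.to_here()

    res = fetch(rref)  # torch.ones(2, 2) + 1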

Test Plan: Imported from OSS

Differential Revision: D19871244

Pulled By: wanchaol

fbshipit-source-id: 7eca6c491a84666b261c70806254b705603bd663
This commit is contained in:
Wanchao Liang
2020-02-13 20:13:10 -08:00
committed by Facebook Github Bot
parent b2c5896432
commit 93179b1c1c
9 changed files with 118 additions and 3 deletions

View File

@@ -88,8 +88,9 @@ if __name__ == '__main__':
line = f.readline()
if not line:
break
if "torch.classes" in line:
if "torch.classes" in line or "RRef" in line:
# TODO Fix type __torch__.torch.classes.xxx
# TODO Delete the RRef special case once the RRef type is added
continue
s = parse_schema(line.strip())
slist = new_schema_dict.get(s.name, [])

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function, unicode_literals
from torch.testing._internal.distributed.rpc.rpc_test import RpcTest
from torch.testing._internal.distributed.rpc.rpc_test import RpcTest, RpcJitTest
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN, run_tests
@@ -14,5 +14,12 @@ class RpcTestWithSpawn(MultiProcessTestCase, RpcTest):
super(RpcTestWithSpawn, self).setUp()
self._spawn_processes()
@unittest.skipIf(TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues")
class RpcJitTestWithSpawn(MultiProcessTestCase, RpcJitTest):
def setUp(self):
super(RpcJitTestWithSpawn, self).setUp()
self._spawn_processes()
if __name__ == '__main__':
run_tests()

View File

@@ -156,6 +156,13 @@ PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
return PyRRef(std::move(rref));
}
c10::IValue PyRRef::toIValue() {
// Cast to RRefInterface so the pointer can be held in an IValue.
auto rrefPtr = c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_);
return IValue(rrefPtr);
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@@ -22,6 +22,7 @@ class PyRRef {
std::string str() const;
py::tuple pickle() const;
static PyRRef unpickle(const py::tuple& t);
c10::IValue toIValue();
private:
c10::intrusive_ptr<RRef> rref_;

View File

@@ -21,6 +21,9 @@
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/six.h>
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/py_rref.h>
#endif
#include <ATen/core/function_schema.h>
#include <c10/util/Exception.h>
@@ -565,7 +568,13 @@ inline IValue toIValue(
c10::str("Cannot cast ", py::str(obj), " to ", type->python_str()));
}
}
case TypeKind::RRefType:
case TypeKind::RRefType: {
#ifdef USE_DISTRIBUTED
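      // Unwrap the pybind11-bound PyRRef and convert it to an IValue
      // holding an RRefInterface pointer.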
return obj.cast<torch::distributed::rpc::PyRRef>().toIValue();
#else
AT_ERROR("RRef is only supported with the distributed package");
#endif
}
case TypeKind::GeneratorType:
case TypeKind::VarType:
case TypeKind::FutureType:

View File

@@ -1,12 +1,17 @@
#include <ATen/ATen.h>
#include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/custom_operator.h"
#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
using at::Scalar;
using at::Tensor;
namespace dist_autograd = torch::distributed::autograd;
namespace dist_rpc = torch::distributed::rpc;
namespace torch {
namespace jit {
@@ -23,6 +28,39 @@ at::Tensor optional_to_tensor(c10::optional<at::Tensor> v) {
return v.has_value() ? *v : at::Tensor();
}
c10::OperatorOptions aliasAnalysisFromSchema() {
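  // Let alias analysis be driven by the schema's alias annotations,
  // e.g. RRef(t) in the to_here schema below.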
c10::OperatorOptions result;
result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
return result;
}
RegisterOperators reg_rpc_ops({
Operator(
"aten::to_here(RRef(t) self) -> t",
[](Stack& stack) {
auto rref = pop(stack).toRRef();
IValue res;
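          // The owner can read the value locally; a user RRef must
          // fetch it from the owner over RPC.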
if (rref->isOwner()) {
res = c10::dynamic_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
->getValue();
} else {
res = c10::dynamic_intrusive_pointer_cast<dist_rpc::UserRRef>(rref)
->toHere();
}
push(stack, std::move(res));
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"aten::is_owner(RRef(t) self) -> bool",
[](Stack& stack) {
auto rref = pop(stack).toRRef();
push(stack, rref->isOwner());
return 0;
},
aliasAnalysisFromSchema()),
});
auto reg_distributed_ops =
torch::RegisterOperators()
.op("aten::get_gradients(int context_id) -> Dict(Tensor, Tensor)",

View File

@@ -21,6 +21,7 @@ using c10::ListType;
using c10::NoneType;
using c10::NumberType;
using c10::OptionalType;
using c10::RRefType;
using c10::StringType;
using c10::Symbol;
using c10::QSchemeType;
@@ -201,6 +202,14 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
auto subalias = std::move(p.second);
L.expect(')');
value = FutureType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
L.next(); // RRef
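    // RRef(T) parses like Future(T): a single contained type in parentheses.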
L.expect('(');
auto p = parseType();
auto subtype = std::move(p.first);
auto subalias = std::move(p.second);
L.expect(')');
value = RRefType::create(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
L.next();
value = TensorType::get();

View File

@@ -57,6 +57,14 @@ TypePtr ScriptTypeParser::subscriptToType(
}
auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
return FutureType::create(elem_type);
} else if (typeName == "RRef") {
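    // RRef[T] annotations take exactly one element type, mirroring Future[T].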
if (subscript.subscript_exprs().size() != 1) {
throw ErrorReport(subscript)
<< " expected exactly one element type but found "
<< subscript.subscript_exprs().size();
}
auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
return RRefType::create(elem_type);
} else if (typeName == "Dict") {
if (subscript.subscript_exprs().size() != 2) {
throw ErrorReport(subscript)

View File

@@ -1710,3 +1710,38 @@ class RpcTest(RpcAgentTestFixture):
AttributeError, "RPC pickler does not serialize"
):
rpc.rpc_sync(callee_worker, foo_add, args=())
self.assertTrue(torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler)
@unittest.skipIf(
sys.version_info < (3, 0),
"Pytorch distributed rpc package " "does not support python2",
)
class RpcJitTest(RpcAgentTestFixture):
@dist_init
def test_rref_as_arg(self):
n = self.rank + 1
dst_rank = n % self.world_size
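        # Create an RRef whose value lives on the destination worker.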
rref_var = rpc_return_rref("worker{}".format(dst_rank))
@torch.jit.script
def rref_tensor_to_here(rref_var):
# type: (RRef[Tensor]) -> Tensor
return rref_var.to_here()
res = rref_tensor_to_here(rref_var)
self.assertEqual(res, torch.ones(2, 2) + 1)
@dist_init
def test_rref_is_owner(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref_var = rpc_return_rref("worker{}".format(dst_rank))
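        # rref_var is owned by dst_rank, so is_owner() on this worker
        # should return False.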
@torch.jit.script
def rref_tensor_is_owner(rref_var):
# type: (RRef[Tensor]) -> bool
return rref_var.is_owner()
res = rref_tensor_is_owner(rref_var)
self.assertEqual(res, False)