Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[jit] Initial use RRef in TorchScript (#33190)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33190

This enables the initial RRef type to be used inside TorchScript: a user can pass a Python RRef into a TorchScript function and call to_here inside it. Specifically, this PR:

- adds RRef schema type parsing
- adds Python interop for RRef in Python and into JIT
- registers the to_here op in register_distributed_ops

More support for RRef in TorchScript will be added in future PRs.

Test Plan: Imported from OSS

Differential Revision: D19871244

Pulled By: wanchaol

fbshipit-source-id: 7eca6c491a84666b261c70806254b705603bd663
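For illustration, a minimal sketch of the usage this commit enables, adapted from the RpcJitTest cases added below (the worker name and the already-initialized RPC worker group are assumptions of the sketch, not part of this commit):

import torch
import torch.distributed.rpc as rpc

# A scripted function can now take an RRef argument and fetch its value.
@torch.jit.script
def fetch(rref_var):
    # type: (RRef[Tensor]) -> Tensor
    return rref_var.to_here()

# Hypothetical driver, assuming rpc.init_rpc was called on each worker:
# rpc.remote returns an RRef to the result computed on "worker1".
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2, 2), 1))
res = fetch(rref)  # pass the Python RRef straight into TorchScript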
Committed by: Facebook Github Bot
Parent: b2c5896432
Commit: 93179b1c1c
@@ -88,8 +88,9 @@ if __name__ == '__main__':
            line = f.readline()
            if not line:
                break
-           if "torch.classes" in line:
+           if "torch.classes" in line or "RRef" in line:
                # TODO Fix type __torch__.torch.classes.xxx
+               # TODO Delete RRef special case after add the RRef type
                continue
            s = parse_schema(line.strip())
            slist = new_schema_dict.get(s.name, [])
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 from __future__ import absolute_import, division, print_function, unicode_literals

-from torch.testing._internal.distributed.rpc.rpc_test import RpcTest
+from torch.testing._internal.distributed.rpc.rpc_test import RpcTest, RpcJitTest
 from torch.testing._internal.common_distributed import MultiProcessTestCase
 from torch.testing._internal.common_utils import TEST_WITH_ASAN, run_tests
@@ -14,5 +14,12 @@ class RpcTestWithSpawn(MultiProcessTestCase, RpcTest):
         super(RpcTestWithSpawn, self).setUp()
         self._spawn_processes()

+
+@unittest.skipIf(TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues")
+class RpcJitTestWithSpawn(MultiProcessTestCase, RpcJitTest):
+    def setUp(self):
+        super(RpcJitTestWithSpawn, self).setUp()
+        self._spawn_processes()
+
 if __name__ == '__main__':
     run_tests()
@@ -156,6 +156,13 @@ PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
   return PyRRef(std::move(rref));
 }

+c10::IValue PyRRef::toIValue() {
+  // cast to RRefInterface to hold it into IValue
+  auto rrefPtr = c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_);
+  return IValue(rrefPtr);
+}
+
+
 } // namespace rpc
 } // namespace distributed
 } // namespace torch
@@ -22,6 +22,7 @@ class PyRRef {
   std::string str() const;
   py::tuple pickle() const;
   static PyRRef unpickle(const py::tuple& t);
+  c10::IValue toIValue();

  private:
   c10::intrusive_ptr<RRef> rref_;
@@ -21,6 +21,9 @@
 #include <torch/csrc/utils/auto_gil.h>
 #include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/six.h>
+#ifdef USE_DISTRIBUTED
+#include <torch/csrc/distributed/rpc/py_rref.h>
+#endif

 #include <ATen/core/function_schema.h>
 #include <c10/util/Exception.h>
@@ -565,7 +568,13 @@ inline IValue toIValue(
             c10::str("Cannot cast ", py::str(obj), " to ", type->python_str()));
       }
     }
-    case TypeKind::RRefType:
+    case TypeKind::RRefType: {
+#ifdef USE_DISTRIBUTED
+      return obj.cast<torch::distributed::rpc::PyRRef>().toIValue();
+#else
+      AT_ERROR("RRef is only supported with the distributed package");
+#endif
+    }
     case TypeKind::GeneratorType:
     case TypeKind::VarType:
     case TypeKind::FutureType:
@@ -1,12 +1,17 @@
 #include <ATen/ATen.h>
+#include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/custom_operator.h"
 #include <ATen/core/op_registration/op_registration.h>

 #include <torch/csrc/distributed/autograd/context/container.h>
 #include <torch/csrc/distributed/autograd/engine/dist_engine.h>

+#include <torch/csrc/distributed/rpc/rref_impl.h>
+
 using at::Scalar;
 using at::Tensor;
 namespace dist_autograd = torch::distributed::autograd;
+namespace dist_rpc = torch::distributed::rpc;

 namespace torch {
 namespace jit {
@@ -23,6 +28,39 @@ at::Tensor optional_to_tensor(c10::optional<at::Tensor> v) {
   return v.has_value() ? *v : at::Tensor();
 }

+c10::OperatorOptions aliasAnalysisFromSchema() {
+  c10::OperatorOptions result;
+  result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
+  return result;
+}
+
+RegisterOperators reg_rpc_ops({
+    Operator(
+        "aten::to_here(RRef(t) self) -> t",
+        [](Stack& stack) {
+          auto rref = pop(stack).toRRef();
+          IValue res;
+          if (rref->isOwner()) {
+            res = c10::dynamic_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
+                      ->getValue();
+          } else {
+            res = c10::dynamic_intrusive_pointer_cast<dist_rpc::UserRRef>(rref)
+                      ->toHere();
+          }
+          push(stack, std::move(res));
+          return 0;
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "aten::is_owner(RRef(t) self) -> bool",
+        [](Stack& stack) {
+          auto rref = pop(stack).toRRef();
+          push(stack, rref->isOwner());
+          return 0;
+        },
+        aliasAnalysisFromSchema()),
+});
+
 auto reg_distributed_ops =
     torch::RegisterOperators()
         .op("aten::get_gradients(int context_id) -> Dict(Tensor, Tensor)",
@@ -21,6 +21,7 @@ using c10::ListType;
 using c10::NoneType;
 using c10::NumberType;
 using c10::OptionalType;
+using c10::RRefType;
 using c10::StringType;
 using c10::Symbol;
 using c10::QSchemeType;
@@ -201,6 +202,14 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
     auto subalias = std::move(p.second);
     L.expect(')');
     value = FutureType::create(subtype);
+  } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
+    L.next(); // RRef
+    L.expect('(');
+    auto p = parseType();
+    auto subtype = std::move(p.first);
+    auto subalias = std::move(p.second);
+    L.expect(')');
+    value = RRefType::create(subtype);
   } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
     L.next();
     value = TensorType::get();
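As a quick illustration of the new schema syntax (parse_schema is the same binding used by the backward-compatibility check above; the printed reprs are approximate and may vary by version):

from torch._C import parse_schema

# "RRef(t)" now parses like the existing "Future(t)" branch: an RRef type
# wrapping a generic element type t.
s = parse_schema("aten::to_here(RRef(t) self) -> t")
print(s.name)               # aten::to_here
print(s.arguments[0].type)  # the RRef(t) argument type
print(s.returns[0].type)    # the element type t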
@@ -57,6 +57,14 @@ TypePtr ScriptTypeParser::subscriptToType(
     }
     auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
     return FutureType::create(elem_type);
+  } else if (typeName == "RRef") {
+    if (subscript.subscript_exprs().size() != 1) {
+      throw ErrorReport(subscript)
+          << " expected exactly one element type but found "
+          << subscript.subscript_exprs().size();
+    }
+    auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
+    return RRefType::create(elem_type);
   } else if (typeName == "Dict") {
     if (subscript.subscript_exprs().size() != 2) {
       throw ErrorReport(subscript)
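On the annotation side, this branch means RRef[...] is accepted wherever the script type parser runs, with exactly one element type. A sketch (element types other than Tensor are assumed to work the same way, since the parser is generic):

import torch

@torch.jit.script
def first(rref):
    # type: (RRef[List[int]]) -> int
    return rref.to_here()[0]

# A second subscript, e.g. RRef[int, int], would raise the new ErrorReport:
# "expected exactly one element type but found 2"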
@@ -1710,3 +1710,38 @@ class RpcTest(RpcAgentTestFixture):
             AttributeError, "RPC pickler does not serialize"
         ):
             rpc.rpc_sync(callee_worker, foo_add, args=())
+        self.assertTrue(torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler)
+
+
+@unittest.skipIf(
+    sys.version_info < (3, 0),
+    "Pytorch distributed rpc package " "does not support python2",
+)
+class RpcJitTest(RpcAgentTestFixture):
+    @dist_init
+    def test_rref_as_arg(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_var = rpc_return_rref("worker{}".format(dst_rank))
+
+        @torch.jit.script
+        def rref_tensor_to_here(rref_var):
+            # type: (RRef[Tensor]) -> Tensor
+            return rref_var.to_here()
+
+        res = rref_tensor_to_here(rref_var)
+        self.assertEqual(res, torch.ones(2, 2) + 1)
+
+    @dist_init
+    def test_rref_is_owner(self):
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        rref_var = rpc_return_rref("worker{}".format(dst_rank))
+
+        @torch.jit.script
+        def rref_tensor_is_owner(rref_var):
+            # type: (RRef[Tensor]) -> bool
+            return rref_var.is_owner()
+
+        res = rref_tensor_is_owner(rref_var)
+        self.assertEqual(res, False)