[rpc] Remove template on RRef and add Type to RRef creation (#30630)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30630

This remove template and all the specializations it have in rpc, we
universally use IValue as the inner value since we support making python
object to be hold inside IValue.

This will also ensure that we have the correct type information when
creating the RRef, we use the return type from the schema when creating
userRRef and OwnerRRef, it will enable IValue to always have the correct
type if the IValue is the RRef object (next PR)

Test Plan: Imported from OSS

Differential Revision: D19502235

fbshipit-source-id: 0d5decae8a9767e0893f3b8b6456b231653be3c5
This commit is contained in:
Yanli Zhao
2020-01-23 21:09:23 -08:00
committed by Facebook Github Bot
parent ef2d4e67d1
commit b474c351dd
9 changed files with 219 additions and 238 deletions

View File

@ -8,6 +8,28 @@ namespace rpc {
namespace {
// PythonTypeResolver that inherits from Script::Resolver to
// support resolving types together with ScriptTypeParser.
struct PythonTypeResolver : public jit::script::Resolver {
std::shared_ptr<jit::script::SugaredValue> resolveValue(
const std::string& /* unused */,
Function& /* unused */,
const jit::SourceRange& /* unused */) override {
TORCH_INTERNAL_ASSERT(
false, "RPC Type resolver does not need to resolve value");
}
TypePtr resolveType(
const std::string& name,
const jit::SourceRange& /* unused */) override {
if (name == "PyObject") {
return PyObjectType::get();
}
auto python_cu = torch::jit::get_python_cu();
return python_cu->get_type(name);
}
};
py::object getFunction(const py::object& module, const char* name) {
py::object fn = module.attr(name);
TORCH_CHECK(
@ -28,6 +50,8 @@ PythonRpcHandler::PythonRpcHandler() {
pySerialize_ = getFunction(module, "serialize");
pyHandleException_ = getFunction(module, "_handle_exception");
jitCompilationUnit_ = torch::jit::get_python_cu();
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
std::make_shared<PythonTypeResolver>());
}
void PythonRpcHandler::cleanup() {
@ -95,6 +119,10 @@ void PythonRpcHandler::handleException(const py::object& obj) {
pyHandleException_(obj);
}
TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) {
return typeParser_->parseType(type_str);
}
} // namespace rpc
} // namespace distributed
} // namespace torch