mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 23:04:52 +08:00
[rpc] Remove template on RRef and add Type to RRef creation (#30630)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30630 This remove template and all the specializations it have in rpc, we universally use IValue as the inner value since we support making python object to be hold inside IValue. This will also ensure that we have the correct type information when creating the RRef, we use the return type from the schema when creating userRRef and OwnerRRef, it will enable IValue to always have the correct type if the IValue is the RRef object (next PR) Test Plan: Imported from OSS Differential Revision: D19502235 fbshipit-source-id: 0d5decae8a9767e0893f3b8b6456b231653be3c5
This commit is contained in:
committed by
Facebook Github Bot
parent
ef2d4e67d1
commit
b474c351dd
@ -8,6 +8,28 @@ namespace rpc {
|
||||
|
||||
namespace {
|
||||
|
||||
// PythonTypeResolver that inherits from Script::Resolver to
|
||||
// support resolving types together with ScriptTypeParser.
|
||||
struct PythonTypeResolver : public jit::script::Resolver {
|
||||
std::shared_ptr<jit::script::SugaredValue> resolveValue(
|
||||
const std::string& /* unused */,
|
||||
Function& /* unused */,
|
||||
const jit::SourceRange& /* unused */) override {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "RPC Type resolver does not need to resolve value");
|
||||
}
|
||||
|
||||
TypePtr resolveType(
|
||||
const std::string& name,
|
||||
const jit::SourceRange& /* unused */) override {
|
||||
if (name == "PyObject") {
|
||||
return PyObjectType::get();
|
||||
}
|
||||
auto python_cu = torch::jit::get_python_cu();
|
||||
return python_cu->get_type(name);
|
||||
}
|
||||
};
|
||||
|
||||
py::object getFunction(const py::object& module, const char* name) {
|
||||
py::object fn = module.attr(name);
|
||||
TORCH_CHECK(
|
||||
@ -28,6 +50,8 @@ PythonRpcHandler::PythonRpcHandler() {
|
||||
pySerialize_ = getFunction(module, "serialize");
|
||||
pyHandleException_ = getFunction(module, "_handle_exception");
|
||||
jitCompilationUnit_ = torch::jit::get_python_cu();
|
||||
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
|
||||
std::make_shared<PythonTypeResolver>());
|
||||
}
|
||||
|
||||
void PythonRpcHandler::cleanup() {
|
||||
@ -95,6 +119,10 @@ void PythonRpcHandler::handleException(const py::object& obj) {
|
||||
pyHandleException_(obj);
|
||||
}
|
||||
|
||||
TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) {
|
||||
return typeParser_->parseType(type_str);
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user