mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 15:44:58 +08:00
Add Python RRef as args and return value (#25499)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25499 See #23110 for model parallel design details, and #26759 for the RRef protocol. This commit add support for using RRef as Python UDF arguments and return value. RRefs can now be shared from owner to user, from user to owner, or from user to user. Limitations: 1. No implicit type conversion yet. (#27099) 2. No failure handling and retry. (#26116) 3. UDF is not yet blocked until all RRefs are confirmed. (#27098) 4. Internal RRef control messages are not idempotent yet. (#26116) 5. Cannot delete RRefs correctly when there are circular dependencies. (#27096) Main changes: 1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations. 2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages. 3. New message request handling code is added to `functions.cpp`, and message format is added in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`. 4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork related control messages will be sent during RRef pickling/unpickling procedure. 5. Update `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs. 6. RRef context (reference count, etc.) are tracked in `rref_context.h` and `rref_context.cpp`. Test Plan: Imported from OSS buck test mode/dev-nosan //caffe2/test:rpc_fork Differential Revision: D17184146 Pulled By: mrshenli fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265
This commit is contained in:
committed by
Facebook Github Bot
parent
8fe5dcf699
commit
2486b0ba82
@ -10,6 +10,7 @@ PythonRpcHandler::PythonRpcHandler() {
|
||||
py::module::import("torch.distributed.internal_rpc_utils");
|
||||
runUDFFunction_ = module.attr("run_python_udf_internal");
|
||||
loadResultFunction_ = module.attr("load_python_udf_result_internal");
|
||||
serializeFunction_ = module.attr("serialize");
|
||||
}
|
||||
|
||||
PythonRpcHandler& PythonRpcHandler::getInstance() {
|
||||
@ -24,7 +25,8 @@ std::vector<char> PythonRpcHandler::generatePythonUDFResult(
|
||||
AutoGIL ag;
|
||||
auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
|
||||
TORCH_CHECK(runUDFFunction_ != nullptr, "runUDFFunction_ is nullptr");
|
||||
py::tuple pres = runUDFFunction_(pargs, requestTensorTable);
|
||||
py::tuple pres =
|
||||
serializeFunction_(runUDFFunction_(pargs, requestTensorTable));
|
||||
const auto& presStr = pres[0].cast<std::string>();
|
||||
responseTensorTable = pres[1].cast<std::vector<torch::Tensor>>();
|
||||
std::vector<char> payload(presStr.begin(), presStr.end());
|
||||
@ -40,6 +42,26 @@ py::object PythonRpcHandler::loadPythonUDFResult(
|
||||
return loadResultFunction_(pargs, tensorTable);
|
||||
}
|
||||
|
||||
py::object PythonRpcHandler::runPythonUDF(
|
||||
const SerializedPyObj& serializedObj) {
|
||||
AutoGIL ag;
|
||||
return runUDFFunction_(
|
||||
py::bytes(serializedObj.payload_), serializedObj.tensors_);
|
||||
}
|
||||
|
||||
SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
|
||||
AutoGIL ag;
|
||||
py::tuple t = serializeFunction_(obj);
|
||||
return SerializedPyObj(
|
||||
t[0].cast<std::string>(), t[1].cast<std::vector<torch::Tensor>>());
|
||||
}
|
||||
|
||||
py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
|
||||
AutoGIL ag;
|
||||
return loadResultFunction_(
|
||||
py::bytes(serializedObj.payload_), serializedObj.tensors_);
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user