mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* getType() is deprecated, replace it with new/proper static method. These are backwards compatible with old pybind11 versions we support. So break this off before we upgrade to pybind11 3.0 where these methods are dropped in #154115 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154798 Approved by: https://github.com/jansel, https://github.com/cyyever
335 lines
12 KiB
C++
335 lines
12 KiB
C++
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
|
|
|
|
#include <torch/csrc/autograd/profiler.h>
|
|
#include <torch/csrc/distributed/autograd/context/container.h>
|
|
#include <torch/csrc/distributed/autograd/context/context.h>
|
|
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
|
|
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
|
|
#include <torch/csrc/distributed/autograd/utils.h>
|
|
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
|
|
#include <torch/csrc/distributed/rpc/py_rref.h>
|
|
#include <torch/csrc/distributed/rpc/python_call.h>
|
|
#include <torch/csrc/distributed/rpc/python_remote_call.h>
|
|
#include <torch/csrc/distributed/rpc/python_resp.h>
|
|
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
|
#include <torch/csrc/distributed/rpc/rref_context.h>
|
|
#include <torch/csrc/distributed/rpc/rref_impl.h>
|
|
#include <torch/csrc/distributed/rpc/rref_proto.h>
|
|
#include <torch/csrc/distributed/rpc/script_call.h>
|
|
#include <torch/csrc/distributed/rpc/script_remote_call.h>
|
|
#include <torch/csrc/distributed/rpc/script_resp.h>
|
|
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
|
|
#include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
|
|
#include <torch/csrc/distributed/rpc/utils.h>
|
|
#include <torch/csrc/jit/python/python_ivalue.h>
|
|
|
|
#include <utility>
|
|
|
|
namespace torch::distributed::rpc {
|
|
|
|
using namespace torch::distributed::autograd;
|
|
|
|
namespace {
|
|
|
|
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommandReference(
|
|
RpcCommandBase& rpc,
|
|
const MessageType& messageType) {
|
|
switch (messageType) {
|
|
case MessageType::PYTHON_CALL: {
|
|
auto& pc = static_cast<PythonCall&>(rpc);
|
|
return std::make_unique<UnpickledPythonCall>(
|
|
pc.serializedPyObj(), pc.isAsyncExecution());
|
|
}
|
|
case MessageType::PYTHON_REMOTE_CALL: {
|
|
auto& prc = static_cast<PythonRemoteCall&>(rpc);
|
|
return std::make_unique<UnpickledPythonRemoteCall>(
|
|
prc.serializedPyObj(),
|
|
prc.retRRefId(),
|
|
prc.retForkId(),
|
|
prc.isAsyncExecution());
|
|
}
|
|
case MessageType::FORWARD_AUTOGRAD_REQ: {
|
|
// Deserialize the wrapped RPC if it contains Python UDF
|
|
auto& rwa = static_cast<RpcWithAutograd&>(rpc);
|
|
auto& wrappedRpc = rwa.wrappedRpc();
|
|
auto pythonRpc = deserializePythonRpcCommandReference(
|
|
wrappedRpc, rwa.wrappedMessageType());
|
|
if (pythonRpc) {
|
|
rwa.setWrappedRpc(std::move(pythonRpc));
|
|
}
|
|
return nullptr;
|
|
}
|
|
case MessageType::RUN_WITH_PROFILING_REQ: {
|
|
// Deserialize wrapped RPC if it contains python call
|
|
auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
|
|
auto& wrappedRpc = rpcWithProfilingReq.wrappedRpc();
|
|
auto pythonRpc = deserializePythonRpcCommandReference(
|
|
wrappedRpc, rpcWithProfilingReq.wrappedMessageType());
|
|
if (pythonRpc) {
|
|
rpcWithProfilingReq.setWrappedRpc(std::move(pythonRpc));
|
|
}
|
|
return nullptr;
|
|
}
|
|
default: {
|
|
return nullptr;
|
|
}
|
|
}
|
|
}
|
|
|
|
SerializedPyObj serializePyObject(IValue value) {
|
|
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
|
|
// Need this GIL to guard jit::toPyObj and destruct its returned
|
|
// py::object
|
|
py::gil_scoped_acquire acquire;
|
|
try {
|
|
return pythonRpcHandler.serialize(jit::toPyObject(std::move(value)));
|
|
} catch (py::error_already_set& e) {
|
|
// py::error_already_set requires GIL to destruct, take special care.
|
|
std::string err_msg = e.what();
|
|
e.restore();
|
|
PyErr_Clear();
|
|
throw std::runtime_error(err_msg);
|
|
}
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runPythonFunction(
|
|
const py::object& function,
|
|
const std::vector<c10::Stream>& streams,
|
|
bool isAsyncExecution) const {
|
|
c10::MultiStreamGuard guard(streams);
|
|
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
|
|
py::gil_scoped_acquire acquire;
|
|
|
|
py::object result;
|
|
try {
|
|
result = pythonRpcHandler.runPythonUdf(function);
|
|
} catch (py::error_already_set& e) {
|
|
// py::error_already_set requires GIL to destruct, take special care.
|
|
auto future =
|
|
asFuture(std::make_exception_ptr(std::runtime_error(e.what())));
|
|
e.restore();
|
|
PyErr_Clear();
|
|
return future;
|
|
} catch (std::exception& e) {
|
|
return asFuture(std::current_exception());
|
|
}
|
|
|
|
// After sync execution or failed async execution return the value as-is.
|
|
if (pythonRpcHandler.isRemoteException(result) || !isAsyncExecution) {
|
|
return asFuture(
|
|
c10::ivalue::ConcretePyObjectHolder::create(result),
|
|
at::PyObjectType::get());
|
|
}
|
|
|
|
try {
|
|
return result.cast<jit::PythonFutureWrapper&>().fut;
|
|
} catch (const py::cast_error& e) {
|
|
auto type = py::type::handle_of(result);
|
|
auto errMsg = c10::str(
|
|
e.what(),
|
|
". Functions decorated with @rpc.async_function must return a "
|
|
"torch.futures.Future object, but got ",
|
|
type.attr("__module__").cast<std::string>(),
|
|
".",
|
|
type.attr("__qualname__").cast<std::string>());
|
|
return asFuture(std::make_exception_ptr(std::runtime_error(errMsg)));
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<RpcCommandBase> RequestCallbackImpl::
|
|
deserializePythonRpcCommand(
|
|
std::unique_ptr<RpcCommandBase> rpc,
|
|
const MessageType& messageType) const {
|
|
auto pythonRpc = deserializePythonRpcCommandReference(*rpc, messageType);
|
|
return pythonRpc ? std::move(pythonRpc) : std::move(rpc);
|
|
}
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptCall(
|
|
RpcCommandBase& rpc,
|
|
const std::vector<c10::Stream>& streams) const {
|
|
auto& scriptCall = static_cast<ScriptCall&>(rpc);
|
|
|
|
c10::intrusive_ptr<JitFuture> future;
|
|
if (scriptCall.hasOp()) {
|
|
future = runJitOperator(*scriptCall.op(), scriptCall.stackRef(), streams);
|
|
} else {
|
|
future = runJitFunction(
|
|
scriptCall.qualifiedName(),
|
|
scriptCall.stackRef(),
|
|
streams,
|
|
scriptCall.isAsyncExecution());
|
|
}
|
|
|
|
return future->then(
|
|
[](JitFuture& jitFuture) {
|
|
return withStorages(ScriptResp(jitFuture.value()).toMessage());
|
|
},
|
|
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
|
|
}
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonCall(
|
|
RpcCommandBase& rpc,
|
|
const std::vector<c10::Stream>& streams) const {
|
|
auto& upc = static_cast<UnpickledPythonCall&>(rpc);
|
|
auto future =
|
|
runPythonFunction(upc.pythonUdf(), streams, upc.isAsyncExecution());
|
|
|
|
return future->then(
|
|
[](JitFuture& future) {
|
|
return withStorages(
|
|
PythonResp(serializePyObject(future.value())).toMessage());
|
|
},
|
|
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
|
|
}
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptRemoteCall(
|
|
RpcCommandBase& rpc,
|
|
const std::vector<c10::Stream>& streams) const {
|
|
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
|
|
|
|
c10::intrusive_ptr<JitFuture> future;
|
|
if (scriptRemoteCall.hasOp()) {
|
|
future = runJitOperator(
|
|
*scriptRemoteCall.op(), scriptRemoteCall.stackRef(), streams);
|
|
} else {
|
|
future = runJitFunction(
|
|
scriptRemoteCall.qualifiedName(),
|
|
scriptRemoteCall.stackRef(),
|
|
streams,
|
|
scriptRemoteCall.isAsyncExecution());
|
|
}
|
|
|
|
return assignOwnerRRef(
|
|
scriptRemoteCall.retRRefId(), scriptRemoteCall.retForkId(), future);
|
|
}
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRemoteCall(
|
|
RpcCommandBase& rpc,
|
|
const std::vector<c10::Stream>& streams) const {
|
|
auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc);
|
|
auto future =
|
|
runPythonFunction(uprc.pythonUdf(), streams, uprc.isAsyncExecution());
|
|
|
|
return assignOwnerRRef(uprc.rrefId(), uprc.forkId(), future);
|
|
}
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRRefFetchCall(
|
|
RpcCommandBase& rpc) const {
|
|
auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
|
|
|
|
auto future = retrieveOwnerRRef(prf.rrefId());
|
|
|
|
return future->then(
|
|
[](JitFuture& future) {
|
|
SerializedPyObj result = serializePyObject(future.value());
|
|
return withStorages(
|
|
PythonRRefFetchRet(std::move(result).toIValues()).toMessage());
|
|
},
|
|
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
|
|
}
|
|
|
|
void RequestCallbackImpl::handleRRefDelete(
|
|
c10::intrusive_ptr<RRef>& rref) const {
|
|
if (rref && rref->isPyObj()) {
|
|
py::gil_scoped_acquire acquire;
|
|
rref.reset();
|
|
}
|
|
}
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRpcWithErrors(
|
|
RpcCommandBase& rpc,
|
|
const MessageType& messageType,
|
|
const std::vector<c10::Stream>& streams) const {
|
|
try {
|
|
return processRpc(rpc, messageType, streams);
|
|
} catch (py::error_already_set& e) {
|
|
// Pass a dummy message ID since it will be overwritten anyways.
|
|
auto future = asFuture(handleError(e, messageType, -1));
|
|
// There are request callback impls in Python, where Python
|
|
// exceptions could be thrown. For releasing Python exception
|
|
// py::objects, GIL must be held.
|
|
py::gil_scoped_acquire acquire;
|
|
e.restore(); // Release ownership on py::objects and also restore
|
|
// Python Error Indicator.
|
|
PyErr_Clear(); // Clear the Python Error Indicator as we has
|
|
// recorded the exception in the response message.
|
|
return future;
|
|
} catch (std::exception& e) {
|
|
// Pass a dummy message ID since it will be overwritten anyways.
|
|
return asFuture(handleError(e, messageType, -1));
|
|
}
|
|
}
|
|
|
|
bool RequestCallbackImpl::cudaAvailable() const {
|
|
#ifdef USE_CUDA
|
|
return true;
|
|
#else
|
|
return false;
|
|
#endif
|
|
}
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRRefBackward(
|
|
RpcCommandBase& rpc) const {
|
|
auto& rrefBackwardReq = static_cast<RRefBackwardReq&>(rpc);
|
|
|
|
auto future = retrieveOwnerRRef(rrefBackwardReq.getRRefId());
|
|
|
|
return future->then(
|
|
[autogradContextId = rrefBackwardReq.getAutogradContextId(),
|
|
retainGraph = rrefBackwardReq.retainGraph()](JitFuture& future) {
|
|
// Run backward (TODO: make this async?).
|
|
PyRRef::backwardOwnerRRef(
|
|
autogradContextId, retainGraph, future.value());
|
|
|
|
return withStorages(RRefBackwardResp().toMessage());
|
|
},
|
|
c10::getCustomClassType<c10::intrusive_ptr<Message>>());
|
|
}
|
|
|
|
c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runJitFunction(
|
|
const c10::QualifiedName& name,
|
|
std::vector<at::IValue>& stack,
|
|
const std::vector<c10::Stream>& streams,
|
|
bool isAsyncExecution) const {
|
|
c10::MultiStreamGuard guard(streams);
|
|
c10::intrusive_ptr<JitFuture> future;
|
|
try {
|
|
// runAsync() starts in the calling thread, but may return an uncompleted
|
|
// future (though for non-async code, it will typically be completed).
|
|
// If it was async, our callback will typically be invoked by the
|
|
// continuation on an at::launch() thread.
|
|
future = PythonRpcHandler::getInstance()
|
|
.jitCompilationUnit()
|
|
->get_function(name)
|
|
.runAsync(stack);
|
|
} catch (const std::exception&) {
|
|
return asFuture(std::current_exception());
|
|
}
|
|
|
|
if (isAsyncExecution) {
|
|
at::TypePtr type = future->elementType();
|
|
if (type->kind() != at::FutureType::Kind) {
|
|
return asFuture(std::make_exception_ptr(std::runtime_error(c10::str(
|
|
"Async functions must return an IValue of Future type, but got ",
|
|
type->str()))));
|
|
}
|
|
future = future->thenAsync(
|
|
[](JitFuture& future) { return future.value().toFuture(); },
|
|
type->cast<at::FutureType>()->getElementType());
|
|
}
|
|
|
|
return future;
|
|
}
|
|
|
|
} // namespace torch::distributed::rpc
|