support torch script call over rpc (#32197)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32197

This is to reland https://github.com/pytorch/pytorch/pull/30063, the main change is to match a general exception and grep "pickle" error word in "test_script_functions_not_supported" unit test, as Python 3.5 and Python 3.6 throw different types of errors with different error message for the rpc call in the unit test.
[test all]This diff makes following changes:
1. Providing a new set of python rpc privated APIs, they can accept an annotated TorchScript call and this call can be serialized, deserialized and executed in C++ without GIL. These privated APIs will be binded to JIT in the future, and they are different from public APIs as future JIT binded private APIs will be able to accept qualified_name, not callables. These private APIs are subject to be deprecated once JIT supports torch script function to be a JIT type.

Also, these APIs require torch script function to be defined and annotated by users in python land, it can not be script class/module constructor or class/module methods.

2. This diff also allows public rpc APIs to accept an annotated TorchScript call and execute code path that above private APIs ran on. Therefore if users invoke an annotated TorchScript call over RPC, this call can be serialized, deserialized and executed in C++ without GIL as well.

3. The above private APIs call a newly defined C++ function to make rpc torch script call to be serialized, deserialized and executed in C++ land. This C++ function returns an ivalue::Future. so that in follow up diff this C++ function can be called when these privated APIs are binded to JIT.

4. script_call.cpp/.h and request_callback_impl.cpp files are refactored accordingly so that torch script call and builtin call can share same message type and codes.

5. refactored deserializeResponse() and added a new utility to deserizalize response to IValue

ghstack-source-id: 96879167
ghstack-source-id: 96879167

Test Plan: unit test

Differential Revision: D19402374

fbshipit-source-id: 04efcc7c167d08a6503f29efe55e76f2be4b2c5e
This commit is contained in:
Yanli Zhao
2020-01-18 09:22:30 -08:00
committed by Facebook Github Bot
parent 1ecad2bb2b
commit 58234c0254
21 changed files with 562 additions and 63 deletions

View File

@ -22,27 +22,6 @@ constexpr int PARENT_IDX = 5; // index of parent in the tuple
// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 6; // number of RRefForkData fields in py::tuple
template <typename T>
T& unwrapAutogradMessage(
const Message& message,
std::unique_ptr<RpcCommandBase>& response) {
if (message.type() == MessageType::FORWARD_AUTOGRAD_RESP) {
auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(*response);
// Attach 'recv' autograd function.
addRecvRpcBackward(
rpcWithAutograd.autogradMetadata(),
rpcWithAutograd.tensors(),
rpcWithAutograd.fromWorkerId());
auto& wrappedRpc = rpcWithAutograd.wrappedRpc();
return static_cast<T&>(wrappedRpc);
} else {
return static_cast<T&>(*response);
}
}
} // namespace
std::atomic<local_id_t> RRefContext::nextLocalId_{0};
@ -142,8 +121,13 @@ IValue UserRRef<IValue>::toHere() {
true /* forceGradRecording */);
const Message& message = futureResponse->wait();
auto response = deserializeResponse(message);
auto& rfr = unwrapAutogradMessage<ScriptRRefFetchRet>(message, response);
MessageType msgType = message.type();
auto response = deserializeResponse(message, msgType);
TORCH_INTERNAL_ASSERT(
msgType == MessageType::SCRIPT_RREF_FETCH_RET,
"Message type should be SCRIPT_RREF_FETCH_RET.");
RpcCommandBase& rpc = *response;
auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
return rfr.values().front();
}
@ -161,8 +145,13 @@ py::object UserRRef<py::object>::toHere() {
true /* forceGradRecording */);
const Message& message = futureResponse->wait();
auto response = deserializeResponse(message);
auto& rfr = unwrapAutogradMessage<PythonRRefFetchRet>(message, response);
MessageType msgType = message.type();
auto response = deserializeResponse(message, msgType);
TORCH_INTERNAL_ASSERT(
msgType == MessageType::PYTHON_RREF_FETCH_RET,
"Message type should be PYTHON_RREF_FETCH_RET.");
RpcCommandBase& rpc = *response;
auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
return PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(rfr.values()));
}