mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This series of changes tries to convert C-style casts into C++ alternatives. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750 Approved by: https://github.com/Skylion007
82 lines
2.6 KiB
C++
82 lines
2.6 KiB
C++
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
|
#include <torch/csrc/distributed/rpc/script_remote_call.h>
|
|
|
|
#include <torch/csrc/jit/serialization/pickle.h>
|
|
|
|
namespace torch::distributed::rpc {
|
|
|
|
ScriptRemoteCall::ScriptRemoteCall(
|
|
std::shared_ptr<Operator> op,
|
|
std::vector<at::IValue>&& stack,
|
|
const RRefId& retRRefId,
|
|
const ForkId& retForkId)
|
|
: ScriptCall(std::move(op), std::move(stack)),
|
|
retRRefId_(retRRefId),
|
|
retForkId_(retForkId) {}
|
|
|
|
ScriptRemoteCall::ScriptRemoteCall(
|
|
const c10::QualifiedName& qualifiedName,
|
|
std::vector<at::IValue>&& stack,
|
|
const RRefId& retRRefId,
|
|
const ForkId& retForkId,
|
|
const bool isAsyncExecution)
|
|
: ScriptCall(qualifiedName, std::move(stack), isAsyncExecution),
|
|
retRRefId_(retRRefId),
|
|
retForkId_(retForkId) {}
|
|
|
|
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromIValues(
|
|
std::vector<at::IValue>& ivalues) {
|
|
// remove the last element from values and convert it back to an RRef
|
|
auto retForkId = RRefId::fromIValue(ivalues.back());
|
|
ivalues.pop_back();
|
|
auto retRRefId = ForkId::fromIValue(ivalues.back());
|
|
ivalues.pop_back();
|
|
|
|
auto scriptCallPtr = ScriptCall::fromIValues(ivalues);
|
|
|
|
if (scriptCallPtr->hasOp()) {
|
|
return std::make_unique<ScriptRemoteCall>(
|
|
scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId);
|
|
} else {
|
|
return std::make_unique<ScriptRemoteCall>(
|
|
scriptCallPtr->qualifiedName(),
|
|
std::move(ivalues),
|
|
retRRefId,
|
|
retForkId,
|
|
scriptCallPtr->isAsyncExecution());
|
|
}
|
|
}
|
|
|
|
// Serializes this call into a SCRIPT_REMOTE_CALL message.
//
// The pickled payload is a tuple of the ScriptCall fields followed by
// retRRefId_ and then retForkId_. jit::pickle fills a separate tensor table,
// which is shipped alongside the payload bytes in the resulting Message.
c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {
  std::vector<torch::Tensor> tensors;

  std::vector<IValue> fields;
  ScriptCall::toIValues(fields);
  // Trailing ids: RRef id first, fork id last. fromIValues() relies on this
  // order when it pops them back off.
  fields.emplace_back(retRRefId_.toIValue());
  fields.emplace_back(retForkId_.toIValue());

  auto bytes =
      jit::pickle(c10::ivalue::Tuple::create(std::move(fields)), &tensors);

  return c10::make_intrusive<Message>(
      std::move(bytes), std::move(tensors), MessageType::SCRIPT_REMOTE_CALL);
}
|
|
|
|
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
|
|
const Message& message) {
|
|
auto payload = message.payload().data();
|
|
auto payload_size = message.payload().size();
|
|
|
|
auto value = jit::unpickle(
|
|
payload,
|
|
payload_size,
|
|
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
|
|
message.tensors());
|
|
auto values = value.toTupleRef().elements().vec();
|
|
TORCH_CHECK(!values.empty(), "Malformed message: empty values unpickled");
|
|
return fromIValues(values);
|
|
}
|
|
|
|
} // namespace torch::distributed::rpc
|