Files
pytorch/torch/csrc/distributed/rpc/script_remote_call.cpp
Yuanyuan Chen e1e8491b31 [1/N] Change C-style casts to static_cast or reinterpret_cast (#165750)
This series of changes tries to convert C-style casts into their C++ alternatives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165750
Approved by: https://github.com/Skylion007
2025-10-20 04:36:19 +00:00

82 lines
2.6 KiB
C++

#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/jit/serialization/pickle.h>
namespace torch::distributed::rpc {
ScriptRemoteCall::ScriptRemoteCall(
std::shared_ptr<Operator> op,
std::vector<at::IValue>&& stack,
const RRefId& retRRefId,
const ForkId& retForkId)
: ScriptCall(std::move(op), std::move(stack)),
retRRefId_(retRRefId),
retForkId_(retForkId) {}
ScriptRemoteCall::ScriptRemoteCall(
const c10::QualifiedName& qualifiedName,
std::vector<at::IValue>&& stack,
const RRefId& retRRefId,
const ForkId& retForkId,
const bool isAsyncExecution)
: ScriptCall(qualifiedName, std::move(stack), isAsyncExecution),
retRRefId_(retRRefId),
retForkId_(retForkId) {}
// Rebuilds a ScriptRemoteCall from the flattened values produced by
// toMessageImpl().
//
// toMessageImpl() appends retRRefId and then retForkId after the ScriptCall
// fields. The two ids are therefore popped off the back in reverse order, and
// the remaining values are decoded by ScriptCall::fromIValues().
//
// @param ivalues Deserialized values. The vector is modified in place: the two
//                ids are popped, and the remainder is moved into the result.
// @return A ScriptRemoteCall built with either the decoded operator or the
//         decoded qualified name, depending on the decoded call's hasOp().
// @throws c10::Error (via TORCH_CHECK) if fewer than two values are present.
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromIValues(
    std::vector<at::IValue>& ivalues) {
  // back() on an empty vector is undefined behavior, so check that both ids
  // are present before popping them.
  TORCH_CHECK(
      ivalues.size() >= 2,
      "Malformed ScriptRemoteCall: expected at least 2 values (RRef id and "
      "fork id), but got ",
      ivalues.size());

  // The last element is the fork id and the one before it is the RRef id.
  // Decode each with its own type alias so this reads as the exact inverse of
  // toMessageImpl().
  auto retForkId = ForkId::fromIValue(ivalues.back());
  ivalues.pop_back();
  auto retRRefId = RRefId::fromIValue(ivalues.back());
  ivalues.pop_back();

  // NOTE(review): `ivalues` is moved into the result below after being passed
  // to ScriptCall::fromIValues(). This relies on that call not leaving the
  // vector moved-from; confirm against ScriptCall's constructor.
  auto scriptCallPtr = ScriptCall::fromIValues(ivalues);
  if (scriptCallPtr->hasOp()) {
    return std::make_unique<ScriptRemoteCall>(
        scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId);
  }
  return std::make_unique<ScriptRemoteCall>(
      scriptCallPtr->qualifiedName(),
      std::move(ivalues),
      retRRefId,
      retForkId,
      scriptCallPtr->isAsyncExecution());
}
// Serializes this call into a SCRIPT_REMOTE_CALL Message.
//
// The ScriptCall fields come first, followed by the return RRef id and then
// the fork id. fromIValues() relies on this order when it pops the ids back
// off the end. The whole list is pickled as one tuple; tensors are collected
// into a side table that travels with the Message.
c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {
  std::vector<IValue> elements;
  ScriptCall::toIValues(elements);
  elements.emplace_back(retRRefId_.toIValue());
  elements.emplace_back(retForkId_.toIValue());

  std::vector<torch::Tensor> tensorTable;
  auto tuple = c10::ivalue::Tuple::create(std::move(elements));
  auto pickled = jit::pickle(std::move(tuple), &tensorTable);

  return c10::make_intrusive<Message>(
      std::move(pickled),
      std::move(tensorTable),
      MessageType::SCRIPT_REMOTE_CALL);
}
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
const Message& message) {
auto payload = message.payload().data();
auto payload_size = message.payload().size();
auto value = jit::unpickle(
payload,
payload_size,
*RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
auto values = value.toTupleRef().elements().vec();
TORCH_CHECK(!values.empty(), "Malformed message: empty values unpickled");
return fromIValues(values);
}
} // namespace torch::distributed::rpc