pytorch/torch/csrc/distributed/rpc/message.cpp
Peter Bell 40d1f77384 Codegen: python_torch_functions only include relevant operators (#68693)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68693

Generation of Python bindings for native functions is split over 8
different files: one for each namespace, with the torch namespace
split into 3 shards, and methods in their own file as well. This
change ensures that editing any single (non-method) operator only
causes one of these files to be rebuilt.
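
As a minimal sketch of the idea, assuming shard assignment is by a stable
hash of the operator name (shardForOperator and the hash choice are
illustrative, not the actual torchgen logic):

#include <cstdint>
#include <string>

// Deterministically map a (non-method) operator name to one of the three
// python_torch_functions shards; editing one operator then dirties only the
// shard that contains it. (Hypothetical helper, not the real codegen API.)
uint32_t shardForOperator(const std::string& opName, uint32_t numShards = 3) {
  uint32_t h = 2166136261u; // FNV-1a offset basis
  for (char c : opName) {
    h = (h ^ static_cast<uint8_t>(c)) * 16777619u; // FNV-1a prime
  }
  return h % numShards;
}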

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D32596270

Pulled By: albanD

fbshipit-source-id: 0570ec69e7476b8f1bc21138ba18fe8f95ebbe3f
(cherry picked from commit ba0fc71a3a6835e49b332a8be52bf798fa2726b3)
2022-01-21 15:37:06 +00:00


#include <torch/csrc/distributed/rpc/message.h>
#include <torch/custom_class.h>

namespace torch {
namespace distributed {
namespace rpc {

Message::Message() = default;

Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type)
    : payload_(std::move(payload)), tensors_(std::move(tensors)), type_(type) {}

Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type,
    int64_t id)
    : payload_(std::move(payload)),
      tensors_(std::move(tensors)),
      type_(type),
      id_(id) {}
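
// The &&-qualified move accessors below can only be called on an rvalue
// Message (e.g. std::move(message).movePayload()), which makes it explicit at
// the call site that the contents are being moved out rather than copied.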
std::vector<char>&& Message::movePayload() && {
  return std::move(payload_);
}

std::vector<char>& Message::payload() {
  return payload_;
}

const std::vector<char>& Message::payload() const {
  return payload_;
}

std::vector<torch::Tensor>&& Message::moveTensors() && {
  return std::move(tensors_);
}

std::vector<torch::Tensor>& Message::tensors() {
  return tensors_;
}

const std::vector<torch::Tensor>& Message::tensors() const {
  return tensors_;
}

MessageType Message::type() const {
  return type_;
}
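
// MessageTypeFlags is a bitmask: each concrete MessageType carries either the
// REQUEST_TYPE or the RESPONSE_TYPE bit, so a bitwise AND is enough to
// classify a message.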
bool Message::isRequest() const {
  return MessageTypeFlags::REQUEST_TYPE & type_;
}

bool Message::isResponse() const {
  return MessageTypeFlags::RESPONSE_TYPE & type_;
}

int64_t Message::id() const {
  return id_;
}

void Message::setId(int64_t id) {
  id_ = id;
}
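
// Collect weak references to the StorageImpls backing this message's tensors.
// Weak (rather than strong) pointers let callers observe when a storage is
// freed without extending its lifetime; presumably this is used by the RPC
// agent to track buffers that are still in flight.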
std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> Message::getStorages()
    const {
  // Sparse tensors do not have storage. Instead, a sparse tensor contains
  // two dense tensors, indices and values, and both of those have storage.
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages;
  storages.reserve(2 * tensors_.size());
  for (const auto& tensor : tensors_) {
    if (tensor.is_sparse()) {
      storages.emplace_back(tensor._indices().storage().getWeakStorageImpl());
      storages.emplace_back(tensor._values().storage().getWeakStorageImpl());
    } else {
      storages.emplace_back(tensor.storage().getWeakStorageImpl());
    }
  }
  return storages;
}

c10::intrusive_ptr<Message> createExceptionResponse(
    const std::exception& e,
    int64_t id) {
  return createExceptionResponse(e.what(), id);
}

c10::intrusive_ptr<Message> createExceptionResponse(
    const std::string& exceptionStr,
    int64_t id) {
  std::vector<char> payload(exceptionStr.begin(), exceptionStr.end());
  return c10::make_intrusive<Message>(
      std::move(payload),
      std::vector<torch::Tensor>(),
      MessageType::EXCEPTION,
      id);
}

namespace {

// NB: need to call torch::class_ to register Message in the map returned by
// c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within
// an IValue.
// NB: add this line here instead of in rpc/init.cpp because 1) we have C++
// only tests that won't run rpc/init.cpp; 2) Message is not meant to be
// visible from Python.
static const auto message = torch::class_<Message>("rpc", "_Message");

} // namespace

} // namespace rpc
} // namespace distributed
} // namespace torch
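
For reference, a minimal sketch of how an RPC agent might answer a request
using these helpers. The handler and the MessageType::SCRIPT_RET choice are
assumptions for illustration; only Message, createExceptionResponse, and the
move accessors come from this file.

#include <torch/csrc/distributed/rpc/message.h>

#include <exception>
#include <utility>

using namespace torch::distributed::rpc;

// Hypothetical handler: echoes the request's buffers back as a response, or
// returns an EXCEPTION message if processing throws.
c10::intrusive_ptr<Message> handleEcho(Message& request) {
  try {
    // std::move(request) enables the &&-qualified accessors, moving the
    // buffers out of the request instead of copying them.
    auto payload = std::move(request).movePayload();
    auto tensors = std::move(request).moveTensors();
    return c10::make_intrusive<Message>(
        std::move(payload),
        std::move(tensors),
        MessageType::SCRIPT_RET, // assumed response type for illustration
        request.id());
  } catch (const std::exception& e) {
    return createExceptionResponse(e, request.id());
  }
}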