[rpc] Remove template on RRef and add Type to RRef creation (#30630)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30630

This remove template and all the specializations it have in rpc, we
universally use IValue as the inner value since we support making python
object to be hold inside IValue.

This will also ensure that we have the correct type information when
creating the RRef, we use the return type from the schema when creating
userRRef and OwnerRRef, it will enable IValue to always have the correct
type if the IValue is the RRef object (next PR)

Test Plan: Imported from OSS

Differential Revision: D19502235

fbshipit-source-id: 0d5decae8a9767e0893f3b8b6456b231653be3c5
This commit is contained in:
Yanli Zhao
2020-01-23 21:09:23 -08:00
committed by Facebook Github Bot
parent ef2d4e67d1
commit b474c351dd
9 changed files with 219 additions and 238 deletions

View File

@ -1,5 +1,6 @@
#pragma once
#include <ATen/core/jit_type.h>
#include <c10/util/Optional.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
@ -15,7 +16,6 @@ namespace rpc {
class RRef;
class RRefContext;
template <typename T>
class UserRRef;
// Represents fork of an RRef to be sent over the wire.
@ -27,24 +27,21 @@ struct RRefForkData {
const RRefId rrefId_;
const ForkId forkId_;
const worker_id_t parent_;
const std::string type_str_;
private:
friend class RRef;
friend class RRefContext;
template <typename T>
friend class UserRRef;
RRefForkData(
worker_id_t ownerId,
const RRefId& rrefId_,
const ForkId& forkId_,
worker_id_t parent);
worker_id_t parent,
std::string type_str);
};
static_assert(
C10_IS_TRIVIALLY_COPYABLE(RRefForkData),
"RRefForkData must be trivially copyable");
// Note [RRef Protocol]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
//
@ -207,25 +204,32 @@ class RRef : public RRefInterface {
return rrefId_;
}
// returns true if this RRef holds an py::object, false if IValue
virtual bool isPyObj() = 0;
inline bool isPyObj() {
return type_ == PyObjectType::get();
}
inline const TypePtr type() {
return type_;
}
protected:
friend class RRefContext;
RRef(worker_id_t ownerId, const RRefId& rrefId);
RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type);
RRefForkData fork() const;
const worker_id_t ownerId_;
const RRefId rrefId_;
// type field to denote the type of the element that the RRef is holding
// it could be any TypePtr that JIT support, including PyObjectType
const TypePtr type_;
};
// ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user
// also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
// never owns the real value, the only way to get the value of the ``RRef`` is
// to call ``to_here()`` and get a copy..
template <typename T>
class UserRRef final : public RRef {
public:
UserRRef(const UserRRef& other) = delete;
@ -237,16 +241,12 @@ class UserRRef final : public RRef {
return false;
}
inline bool isPyObj() override {
return std::is_same<T, py::object>::value;
}
// Returns the globally unique ForkId of this RRef
const ForkId& forkId() const;
// Get of copy of the value from the ``OwnerRRef``. If the value is not ready
// yet, this call will block.
T toHere();
IValue toHere();
// Upon destruction, this ``UserRRef`` will tell the owner to deref.
~UserRRef() override;
@ -254,14 +254,17 @@ class UserRRef final : public RRef {
private:
friend class RRefContext;
UserRRef(worker_id_t ownerId, const RRefId& rrefId, const ForkId& forkId);
UserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId,
TypePtr type);
const ForkId forkId_;
};
// Keep the template only on the derived class because ``RRefContext`` needs to
// erase the type on ``RRef`` and keep them in one map.
template <typename T>
class OwnerRRef final : public RRef {
public:
OwnerRRef(const OwnerRRef& other) = delete;
@ -273,18 +276,14 @@ class OwnerRRef final : public RRef {
return true;
}
inline bool isPyObj() override {
return std::is_same<T, py::object>::value;
}
// Get a constant reference of the real value. This method will block if the
// value is not ready. This method does not need GIL as it does not create
// any new py::object.
const T& getValue() const;
const IValue& getValue() const;
// Set the value of this ``OwnerRRef``. This method does not need GIL as it
// does not create any new py::object.
void setValue(T&& value);
void setValue(IValue&& value);
// Has a value been set?
bool hasValue() const;
@ -294,15 +293,19 @@ class OwnerRRef final : public RRef {
private:
friend class RRefContext;
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId)
: OwnerRRef(ownerId, rrefId, {}) {}
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
: OwnerRRef(ownerId, rrefId, type, {}) {}
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, c10::optional<T> value)
: RRef(ownerId, rrefId) {
OwnerRRef(
worker_id_t ownerId,
const RRefId& rrefId,
TypePtr type,
c10::optional<IValue> value)
: RRef(ownerId, rrefId, std::move(type)) {
value_ = std::move(value);
}
c10::optional<T> value_;
c10::optional<IValue> value_;
mutable std::mutex mutex_;
mutable std::condition_variable valueCV_;
std::shared_ptr<FutureMessage> future_;