Replace DeviceIndexes with Devices in RRefs (#57442)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57442

We did this for the RPC agents and for ivalue::Future, the last one (I think) is RRef.
ghstack-source-id: 128184664

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D28144368

fbshipit-source-id: eeacab6006f72118cbec542a02322f2e391c67a3
This commit is contained in:
Luca Wehrstedt
2021-05-06 01:10:10 -07:00
committed by Facebook GitHub Bot
parent 8e9bbd3113
commit 7ffadf6e46
8 changed files with 19 additions and 20 deletions

View File

@ -246,7 +246,7 @@ OwnerRRef::OwnerRRef(
worker_id_t ownerId,
const RRefId& rrefId,
TypePtr type,
std::vector<c10::DeviceIndex> devices)
std::vector<c10::Device> devices)
: OwnerRRef(ownerId, rrefId, type, /* value */ {}, std::move(devices)) {}
OwnerRRef::OwnerRRef(
@ -254,15 +254,10 @@ OwnerRRef::OwnerRRef(
const RRefId& rrefId,
TypePtr type,
c10::optional<IValue> value,
std::vector<c10::DeviceIndex> devices)
std::vector<c10::Device> devices)
: RRef(ownerId, rrefId, type) {
std::vector<c10::Device> fullDevices;
fullDevices.reserve(devices.size());
for (const c10::DeviceIndex& idx : devices) {
fullDevices.emplace_back(c10::kCUDA, idx);
}
future_ = std::make_shared<JitFuture>(
at::AnyClassType::get(), std::move(fullDevices));
future_ =
std::make_shared<JitFuture>(at::AnyClassType::get(), std::move(devices));
if (value.has_value()) {
future_->markCompleted(value.value());