CUDA RPC Meta device support

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77400

Approved by: https://github.com/jamesr66a
This commit is contained in:
Pavel Belevich
2022-05-12 22:45:09 -04:00
committed by PyTorch MergeBot
parent 1136965aa1
commit 058be5f162
4 changed files with 114 additions and 72 deletions

View File

@ -328,8 +328,13 @@ c10::intrusive_ptr<Message> tensorpipeDeserialize(
tensors.emplace_back(std::move(t));
}
for (const auto i : c10::irange(tpDescriptor.tensors.size())) {
auto& tensor = tpDescriptor.tensors[i];
int metaTensorsCounter = 0;
for (const auto i : c10::irange(tensors.size())) {
if (tensors[i].is_meta()) {
metaTensorsCounter++;
continue;
}
auto& tensor = tpDescriptor.tensors[i - metaTensorsCounter];
if (tensor.targetDevice.has_value() &&
tensor.targetDevice->type == tensorpipe::kCudaDeviceType) {
TORCH_INTERNAL_ASSERT(