mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
CUDA RPC Meta device support
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77400 Approved by: https://github.com/jamesr66a
This commit is contained in:
committed by
PyTorch MergeBot
parent
1136965aa1
commit
058be5f162
@ -328,8 +328,13 @@ c10::intrusive_ptr<Message> tensorpipeDeserialize(
|
||||
tensors.emplace_back(std::move(t));
|
||||
}
|
||||
|
||||
for (const auto i : c10::irange(tpDescriptor.tensors.size())) {
|
||||
auto& tensor = tpDescriptor.tensors[i];
|
||||
int metaTensorsCounter = 0;
|
||||
for (const auto i : c10::irange(tensors.size())) {
|
||||
if (tensors[i].is_meta()) {
|
||||
metaTensorsCounter++;
|
||||
continue;
|
||||
}
|
||||
auto& tensor = tpDescriptor.tensors[i - metaTensorsCounter];
|
||||
if (tensor.targetDevice.has_value() &&
|
||||
tensor.targetDevice->type == tensorpipe::kCudaDeviceType) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
|
Reference in New Issue
Block a user