Use type-erased union for Buffer. (#54251)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54251
Pull Request resolved: https://github.com/pytorch/tensorpipe/pull/324
In order to merge the channel hierarchies, we need a generic `Buffer` type that can wrap either a `CpuBuffer` or a `CudaBuffer`.
The constraint is that, since this type is used by the channels, it cannot explicitly refer to `CudaBuffer`. We propose a type-erasure-based solution, with a small-buffer optimization to avoid heap-allocating the wrapped concrete buffer.
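For intuition, here is a minimal sketch of the idea (illustrative only, not the actual TensorPipe implementation): a type-erased Buffer that stores any small, trivially-copyable concrete buffer struct in-place, so no heap allocation is needed, and hands the concrete type back through unwrap<T>(). The generic code only ever sees Buffer and DeviceType. The names mirror those used in the diff below, but the layout and size limits here are assumptions of the sketch.

    // Illustrative sketch of a type-erased buffer with small-buffer
    // optimization; not the real tensorpipe implementation.
    #include <cstddef>
    #include <new>
    #include <type_traits>

    enum class DeviceType { kCpu, kCuda };

    struct CpuBuffer {
      static constexpr DeviceType kDeviceType = DeviceType::kCpu;
      void* ptr{nullptr};
      std::size_t length{0};
    };

    class Buffer {
     public:
      Buffer() = default;

      template <typename TBuffer>
      /* implicit */ Buffer(TBuffer b) : deviceType_(TBuffer::kDeviceType) {
        static_assert(
            sizeof(TBuffer) <= kInlineSize,
            "concrete buffer must fit in the inline storage");
        static_assert(
            std::is_trivially_copyable<TBuffer>::value,
            "this sketch only handles trivially-copyable buffer types");
        new (&storage_) TBuffer(b); // constructed in-place, no heap allocation
      }

      DeviceType deviceType() const {
        return deviceType_;
      }

      template <typename TBuffer>
      TBuffer& unwrap() {
        return *reinterpret_cast<TBuffer*>(&storage_);
      }

     private:
      static constexpr std::size_t kInlineSize = 32;
      DeviceType deviceType_{DeviceType::kCpu};
      std::aligned_storage<kInlineSize, alignof(std::max_align_t)>::type storage_;
    };

A CUDA-specific concrete buffer (with pointer, length, and stream fields) can then live in a CUDA-only header and flow through the same Buffer without the generic code ever naming it.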
This is a new version of D27001339 (c618dc13d2), which broke the PyTorch OSS build.
Test Plan: CI
Reviewed By: lw, mrshenli
Differential Revision: D27156053
fbshipit-source-id: 4244302af33a3be91dcd06093c0d6045d081d3cc
Commit a84afb3a7c (parent 8f755b9ed0), committed by Facebook GitHub Bot.
@@ -1,5 +1,6 @@
 #include <gtest/gtest.h>

+#include <tensorpipe/common/cpu_buffer.h>
 #include <tensorpipe/core/message.h>

 #include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
 #include <torch/torch.h>
@@ -42,7 +43,8 @@ TEST(TensorpipeSerialize, Base) {
   recvingTpMessage.tensors.reserve(sendingTpMessage.tensors.size());
   for (auto& tpTensor : sendingTpMessage.tensors) {
     tensorpipe::Message::Tensor t;
-    t.buffer = tensorpipe::CpuBuffer{nullptr, tpTensor.buffer.cpu.length};
+    t.buffer = tensorpipe::CpuBuffer{
+        nullptr, tpTensor.buffer.unwrap<tensorpipe::CpuBuffer>().length};
     t.metadata = tpTensor.metadata;
     recvingTpMessage.tensors.push_back(std::move(t));
   }
@@ -68,9 +70,9 @@ TEST(TensorpipeSerialize, Base) {
     tensorpipe::Message::Tensor& srcTensor = sendingTpMessage.tensors[i];
     tensorpipe::Message::Tensor& dstTensor = recvingTpMessage.tensors[i];
     memcpy(
-        dstTensor.buffer.cpu.ptr,
-        srcTensor.buffer.cpu.ptr,
-        srcTensor.buffer.cpu.length);
+        dstTensor.buffer.unwrap<tensorpipe::CpuBuffer>().ptr,
+        srcTensor.buffer.unwrap<tensorpipe::CpuBuffer>().ptr,
+        srcTensor.buffer.unwrap<tensorpipe::CpuBuffer>().length);
   }

   // Mimic read() callback:
@@ -113,10 +115,17 @@ TEST(TensorpipeSerialize, RecopySparseTensors) {
   EXPECT_TRUE(torch::equal(main, tpBuffers.tensors[0]));
   EXPECT_TRUE(torch::equal(tiny, tpBuffers.tensors[1]));
   // Test cloned storage
-  EXPECT_EQ(main.storage().data(), sendingTpMessage.tensors[0].buffer.cpu.ptr);
-  EXPECT_NE(tiny.storage().data(), sendingTpMessage.tensors[1].buffer.cpu.ptr);
-  EXPECT_EQ(
-      tiny.element_size() * k1K, sendingTpMessage.tensors[1].buffer.cpu.length);
+  EXPECT_EQ(
+      main.storage().data(),
+      sendingTpMessage.tensors[0].buffer.unwrap<tensorpipe::CpuBuffer>().ptr);
+  EXPECT_NE(
+      tiny.storage().data(),
+      sendingTpMessage.tensors[1].buffer.unwrap<tensorpipe::CpuBuffer>().ptr);
+  EXPECT_EQ(
+      tiny.element_size() * k1K,
+      sendingTpMessage.tensors[1]
+          .buffer.unwrap<tensorpipe::CpuBuffer>()
+          .length);
 }

 TEST(TensorpipeSerialize, NoDeleterTensors) {
@@ -141,24 +150,32 @@ TEST(TensorpipeSerialize, NoDeleterTensors) {
   EXPECT_EQ(sendingTpMessage.tensors.size(), 2);
   EXPECT_EQ(
       tpBuffers.copiedTensors[0].size(),
-      sendingTpMessage.tensors[0].buffer.cpu.length);
+      sendingTpMessage.tensors[0]
+          .buffer.unwrap<tensorpipe::CpuBuffer>()
+          .length);
   EXPECT_EQ(
       tpBuffers.copiedTensors[1].size(),
-      sendingTpMessage.tensors[1].buffer.cpu.length);
+      sendingTpMessage.tensors[1]
+          .buffer.unwrap<tensorpipe::CpuBuffer>()
+          .length);
   EXPECT_EQ(
       tpBuffers.copiedTensors[0].data(),
-      sendingTpMessage.tensors[0].buffer.cpu.ptr);
+      sendingTpMessage.tensors[0].buffer.unwrap<tensorpipe::CpuBuffer>().ptr);
   EXPECT_EQ(
       tpBuffers.copiedTensors[1].data(),
-      sendingTpMessage.tensors[1].buffer.cpu.ptr);
+      sendingTpMessage.tensors[1].buffer.unwrap<tensorpipe::CpuBuffer>().ptr);
   EXPECT_TRUE(
       memcmp(
           tpBuffers.copiedTensors[0].data(),
           t1.storage().data(),
-          sendingTpMessage.tensors[0].buffer.cpu.length) == 0);
+          sendingTpMessage.tensors[0]
+              .buffer.unwrap<tensorpipe::CpuBuffer>()
+              .length) == 0);
   EXPECT_TRUE(
       memcmp(
           tpBuffers.copiedTensors[1].data(),
           t2.storage().data(),
-          sendingTpMessage.tensors[1].buffer.cpu.length) == 0);
+          sendingTpMessage.tensors[1]
+              .buffer.unwrap<tensorpipe::CpuBuffer>()
+              .length) == 0);
 }
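The mechanical change throughout the test diff above is from direct member access (buffer.cpu.ptr, buffer.cpu.length) to the type-erased accessors (buffer.unwrap<tensorpipe::CpuBuffer>().ptr and so on). As a standalone illustration of the pattern, written against the sketch types above rather than the real tensorpipe headers (copyCpuPayload is a hypothetical helper):

    // Usage of the deviceType()/unwrap<T>() pattern from the diff above,
    // using the sketch Buffer/CpuBuffer/DeviceType, not the real headers.
    #include <vector>

    void copyCpuPayload(Buffer& src, std::vector<char>& dst) {
      if (src.deviceType() != DeviceType::kCpu) {
        return; // a real implementation would handle CUDA buffers here
      }
      CpuBuffer& cpu = src.unwrap<CpuBuffer>(); // was: src.cpu
      const char* begin = static_cast<const char*>(cpu.ptr);
      dst.assign(begin, begin + cpu.length);
    }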
Submodule third_party/tensorpipe updated: c54fdda499...84937de943
@@ -5,10 +5,9 @@
 #ifdef USE_CUDA_NOT_ROCM
 #include <c10/core/DeviceGuard.h>
 #include <c10/cuda/CUDACachingAllocator.h>
-#include <tensorpipe/tensorpipe.h>
 #endif

-#include <tensorpipe/core/message.h>
+#include <tensorpipe/tensorpipe.h>

 namespace torch {
 namespace distributed {
@@ -173,12 +172,13 @@ TensorpipeReadBuffers tensorpipeAllocate(
   tpMessage.payloads[kTpMessagePickleIdx].data = buffers.pickle.data();

   for (auto& tensor : tpMessage.tensors) {
-    if (tensor.buffer.type == tensorpipe::DeviceType::kCpu) {
-      buffers.tensors.emplace_back(
-          at::getCPUAllocator()->allocate(tensor.buffer.cpu.length));
-      tensor.buffer.cpu.ptr = buffers.tensors.back().get();
+    if (tensor.buffer.deviceType() == tensorpipe::DeviceType::kCpu) {
+      buffers.tensors.emplace_back(at::getCPUAllocator()->allocate(
+          tensor.buffer.unwrap<tensorpipe::CpuBuffer>().length));
+      tensor.buffer.unwrap<tensorpipe::CpuBuffer>().ptr =
+          buffers.tensors.back().get();
 #ifdef USE_CUDA_NOT_ROCM
-    } else if (tensor.buffer.type == tensorpipe::DeviceType::kCuda) {
+    } else if (tensor.buffer.deviceType() == tensorpipe::DeviceType::kCuda) {
       auto deviceIndex = std::stoi(tensor.metadata);
       auto stream = ctx->getStream(deviceIndex);
       // CUDACachingAllocator will call recordStream accordingly on the current
@@ -186,9 +186,10 @@ TensorpipeReadBuffers tensorpipeAllocate(
       at::cuda::CUDAStreamGuard guard(stream);
       buffers.tensors.emplace_back(
           c10::cuda::CUDACachingAllocator::get()->allocate(
-              tensor.buffer.cuda.length));
-      tensor.buffer.cuda.ptr = buffers.tensors.back().get();
-      tensor.buffer.cuda.stream = stream.stream();
+              tensor.buffer.unwrap<tensorpipe::CudaBuffer>().length));
+      tensor.buffer.unwrap<tensorpipe::CudaBuffer>().ptr =
+          buffers.tensors.back().get();
+      tensor.buffer.unwrap<tensorpipe::CudaBuffer>().stream = stream.stream();
 #endif
     } else {
       TORCH_INTERNAL_ASSERT(false, "Unrecognized TensorPipe buffer type.");
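The CUDA branch in tensorpipeAllocate above unwraps a tensorpipe::CudaBuffer, which carries a stream in addition to the pointer and length. In terms of the earlier sketch, such a type can live in a CUDA-only header and flow through the same type-erased Buffer; a hedged sketch follows, reusing the sketch's DeviceType and assuming the CUDA runtime's cudaStream_t:

    // Sketch of a CUDA-side concrete buffer, kept in a CUDA-only header so
    // the generic Buffer code never needs to include it. Illustrative only.
    #include <cstddef>
    #include <cuda_runtime.h>

    struct CudaBuffer {
      static constexpr DeviceType kDeviceType = DeviceType::kCuda;
      void* ptr{nullptr};
      std::size_t length{0};
      cudaStream_t stream{nullptr}; // stream the transfer is ordered on
    };

    // With the sketch Buffer, the dispatch above reduces to:
    //   if (buf.deviceType() == DeviceType::kCuda) {
    //     buf.unwrap<CudaBuffer>().stream = someStream;
    //   }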