Simplify the data sharing mechanism by using std::shared_ptr instead of home-brewed ownership-tracking code.

Also cleaned up the notebook notes a little.
Yangqing Jia
2015-06-25 21:23:23 -07:00
parent 2ed1077a83
commit 9a19430a39
3 changed files with 60 additions and 74 deletions
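The core of this change is the standard shared_ptr-with-custom-deleter idiom: the deleter is captured at allocation time, and a plain pointer copy replaces the manual own_data_/data_source_ bookkeeping. A minimal standalone sketch of that idiom (RawNew/RawDelete are hypothetical stand-ins for Context::New/Context::Delete, not Caffe2 functions):

#include <cstdlib>
#include <iostream>
#include <memory>

// Hypothetical stand-ins for Context::New / Context::Delete.
void* RawNew(std::size_t nbytes) { return std::malloc(nbytes); }
void RawDelete(void* data) { std::free(data); }

int main() {
  // The deleter travels with the shared_ptr's control block, so whichever
  // owner drops the last reference frees the memory with the right function.
  std::shared_ptr<float> data(
      static_cast<float*>(RawNew(10 * sizeof(float))), RawDelete);

  // Sharing is now a plain copy instead of
  // "own_data_ = false; data_source_ = &src;".
  std::shared_ptr<float> alias = data;
  alias.get()[0] = 42.0f;

  // Releasing the original owner does not invalidate the alias; the buffer
  // is freed only when the last reference goes away.
  data.reset();
  std::cout << alias.get()[0] << std::endl;  // still valid; prints 42
  return 0;
}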

View File

@@ -71,25 +71,23 @@ class Blob {
   DISABLE_COPY_AND_ASSIGN(Blob);
 };
 template <typename dtype, class Context>
 class Tensor {
  public:
-  Tensor() : ndim_(0), size_(0), data_(nullptr),
-             own_data_(true), data_source_(nullptr) {}
+  Tensor() : ndim_(0), size_(0), data_(nullptr) {}
   // Creates a tensor. The actual data allocation is deferred until the first
   // time mutable_data() is called, so there is no overhead in creating
   // multiple tensors just as placeholders (although I haven't got a clear
   // idea where such cases would happen).
   explicit Tensor(const vector<int>& dims)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(dims);
   }
   template <class SrcContext>
   Tensor(const Tensor<dtype, SrcContext>& src, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(src.dims());
     context->template Copy<dtype, Context, SrcContext>(
         mutable_data(), src.data(), src.size());
@@ -98,7 +96,7 @@ class Tensor {
   // Creates a tensor, and fills its contents with the given values. We need to
   // have a context passed in as the copy function is device dependent.
   Tensor(const vector<int>& dims, vector<dtype> values, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(dims);
     CHECK_EQ(values.size(), size_);
     context->template Copy<dtype, Context, CPUContext>(
@@ -107,15 +105,13 @@ class Tensor {
   // Special case of the above: creates a tensor of shape 1 holding the
   // given value.
   Tensor(const dtype& value, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(std::vector<int>(1, 1));
     context->template Copy<dtype, Context, CPUContext>(
         mutable_data(), &value, 1);
   }
-  virtual ~Tensor() {
-    Free();
-  }
+  virtual ~Tensor() {}
   void Reshape(const vector<int>& dims) {
     CHECK_GT(dims.size(), 0);
@ -127,10 +123,10 @@ class Tensor {
CHECK_GT(d, 0);
new_size *= d;
}
// If the size changes, we will call Free(). The next data() call will
// re-allocate the memory.
if (data_ && size_ != new_size) {
Free();
// If the size changes, we will free the data. the next mutable_data() call
// will create the data storage.
if (data_.get() && size_ != new_size) {
data_.reset();
}
size_ = new_size;
}
@@ -142,11 +138,19 @@ class Tensor {
   void ShareData(const Tensor& src) {
     // To share data, the sizes must be equal.
+    // The reason we do not force ShareData() to do an explicit reshape is
+    // that we want to allow tensors to have different shapes but still
+    // maintain the same underlying data storage, as long as the contents
+    // are of the same size.
     CHECK_EQ(src.size_, size_)
         << "Size mismatch - did you call reshape before sharing the data?";
-    if (data_) Free();
-    own_data_ = false;
-    data_source_ = &src;
+    // It is possible that the source tensor hasn't called mutable_data()
+    // yet, in which case ShareData() does not make much sense since we
+    // don't really know what to share yet.
+    CHECK(src.data_.get())
+        << "Source tensor has no content yet.";
+    // Finally, do the sharing.
+    data_ = src.data_;
   }
   inline int ndim() const { return ndim_; }
@@ -159,49 +163,26 @@ class Tensor {
   }
   const dtype* data() const {
-    if (own_data_) {
-      CHECK_NOTNULL(data_);
-      return data_;
-    } else {
-      CHECK_NOTNULL(data_source_);
-      CHECK_EQ(data_source_->size_, size_) << "Source data size has changed.";
-      CHECK_NOTNULL(data_source_->data());
-      return data_source_->data();
-    }
+    CHECK_NOTNULL(data_.get());
+    return data_.get();
   }
   dtype* mutable_data() {
-    CHECK(own_data_) << "Cannot call mutable_data() from a shared tensor.";
     CHECK_GT(size_, 0) << "Cannot call mutable_data on a size 0 tensor.";
-    if (!data_) Allocate();
-    CHECK_NOTNULL(data_);
-    return data_;
+    if (!data_.get()) Allocate();
+    return data_.get();
   }
   void Allocate() {
     CHECK(data_ == nullptr);
     CHECK_GT(size_, 0);
-    data_ = static_cast<dtype*>(Context::New(size_ * sizeof(dtype)));
-  }
-  void Free() {
-    if (own_data_) {
-      if (data_) {
-        Context::Delete(data_);
-      }
-    }
-    own_data_ = true;
-    data_ = nullptr;
+    data_.reset(static_cast<dtype*>(Context::New(size_ * sizeof(dtype))),
+                Context::Delete);
   }
  protected:
   int ndim_;
   vector<int> dims_;
   int size_;
-  dtype* data_;
-  bool own_data_;
-  const Tensor* data_source_;
+  std::shared_ptr<dtype> data_;
   DISABLE_COPY_AND_ASSIGN(Tensor);
 };
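
A usage sketch of the resulting semantics (hypothetical snippet assuming a float CPU tensor, mirroring the tests that follow; not code from this commit):

vector<int> dims(3, 2);           // shape {2, 2, 2}, size 8
Tensor<float, CPUContext> tensor(dims);
Tensor<float, CPUContext> other_tensor(dims);
tensor.mutable_data();            // allocate first: ShareData() CHECKs for content
other_tensor.ShareData(tensor);   // copies the shared_ptr; both alias one buffer
dims[0] = 7;                      // shape {7, 2, 2}, size 28
tensor.Reshape(dims);             // size changed: tensor resets its shared_ptr...
tensor.mutable_data();            // ...and allocates fresh storage, while
                                  // other_tensor keeps the old buffer alive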

View File

@@ -114,8 +114,8 @@ TYPED_TEST(TensorCPUTest, TensorShareData) {
   dims[2] = 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.mutable_data() != nullptr);
+  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
@@ -135,10 +135,10 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
   alternate_dims[0] = 2 * 3 * 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(alternate_dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
   EXPECT_EQ(other_tensor.ndim(), 1);
   EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
+  EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
@@ -149,35 +149,30 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
   }
 }
-TYPED_TEST(TensorCPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) {
-  vector<int> dims(3);
-  dims[0] = 2;
-  dims[1] = 3;
-  dims[2] = 5;
-  Tensor<TypeParam, CPUContext> tensor(dims);
-  Tensor<TypeParam, CPUContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
-  ASSERT_DEATH(other_tensor.mutable_data(), "");
-}
-TYPED_TEST(TensorCPUDeathTest, CannotDoReshapewithAlias) {
+TYPED_TEST(TensorCPUTest, NoLongerSharesAfterReshape) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(dims);
   EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
   EXPECT_EQ(tensor.data(), other_tensor.data());
+  auto* old_pointer = other_tensor.data();
   dims[0] = 7;
   tensor.Reshape(dims);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
-  ASSERT_DEATH(other_tensor.data(), ".*Source data size has changed..*");
+  EXPECT_EQ(old_pointer, other_tensor.data());
+  EXPECT_NE(old_pointer, tensor.mutable_data());
 }
 TYPED_TEST(TensorCPUDeathTest, CannotAccessDataWhenEmpty) {
   Tensor<TypeParam, CPUContext> tensor;
   EXPECT_EQ(tensor.ndim(), 0);
-  ASSERT_DEATH(tensor.data(), ".*Check failed: 'data_' Must be non NULL.*");
+  ASSERT_DEATH(tensor.data(), "");
 }
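
The NoLongerSharesAfterReshape expectations above follow directly from std::shared_ptr copy semantics. A minimal plain-C++ illustration (standalone sketch, not the Caffe2 API):

#include <cassert>
#include <memory>

int main() {
  std::shared_ptr<int> src(new int(7));
  std::shared_ptr<int> shared = src;    // what ShareData() now does
  int* old_pointer = shared.get();
  src.reset(new int(9));                // what Reshape() + mutable_data() do
  assert(shared.get() == old_pointer);  // the sharer keeps the old storage alive
  assert(src.get() != old_pointer);     // the source now owns fresh storage
  return 0;
}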

View File

@@ -63,45 +63,55 @@ TYPED_TEST(TensorGPUTest, TensorShareData) {
   dims[2] = 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
   Tensor<TypeParam, CUDAContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.mutable_data() != nullptr);
+  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
 }
-TYPED_TEST(TensorGPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) {
-  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
+  vector<int> alternate_dims(1);
+  alternate_dims[0] = 2 * 3 * 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
-  Tensor<TypeParam, CUDAContext> other_tensor(dims);
+  Tensor<TypeParam, CUDAContext> other_tensor(alternate_dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
-  ASSERT_DEATH(other_tensor.mutable_data(), "");
+  EXPECT_EQ(other_tensor.ndim(), 1);
+  EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]);
+  EXPECT_TRUE(tensor.data() != nullptr);
+  EXPECT_TRUE(other_tensor.data() != nullptr);
+  EXPECT_EQ(tensor.data(), other_tensor.data());
 }
-TYPED_TEST(TensorGPUDeathTest, CannotDoReshapewithAlias) {
-  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+TYPED_TEST(TensorGPUTest, NoLongerSharesAfterReshape) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
   Tensor<TypeParam, CUDAContext> other_tensor(dims);
   EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
   EXPECT_EQ(tensor.data(), other_tensor.data());
+  auto* old_pointer = other_tensor.data();
   dims[0] = 7;
   tensor.Reshape(dims);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
-  ASSERT_DEATH(other_tensor.data(), "Source data size has changed.");
+  EXPECT_EQ(old_pointer, other_tensor.data());
+  EXPECT_NE(old_pointer, tensor.mutable_data());
 }
 TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
   ::testing::FLAGS_gtest_death_test_style = "threadsafe";
   Tensor<TypeParam, CUDAContext> tensor;
   EXPECT_EQ(tensor.ndim(), 0);
-  ASSERT_DEATH(tensor.data(), "Check failed: 'data_' Must be non NULL");
+  ASSERT_DEATH(tensor.data(), "");
 }
 }  // namespace caffe2