Simplify the data sharing mechanism by using std::shared_ptr instead of home-brewed ownership-tracking code.

Also cleaned up the notebook notes a bit.
Yangqing Jia
2015-06-25 21:23:23 -07:00
parent 2ed1077a83
commit 9a19430a39
3 changed files with 60 additions and 74 deletions
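
Before the per-file diffs, here is a minimal sketch of the state change this commit makes. It is not Caffe2 code: the struct names are illustrative, and the fields are simplified from the header diff below.

#include <memory>

// Before: manual ownership tracking. A tensor either owns a raw buffer
// (own_data_ == true) or aliases another tensor through data_source_.
template <typename dtype>
struct TensorDataBefore {
  dtype* data_;
  bool own_data_;
  const TensorDataBefore* data_source_;
};

// After: a single shared_ptr carries both the pointer and its deleter; the
// buffer is freed automatically when the last sharing tensor releases it.
template <typename dtype>
struct TensorDataAfter {
  std::shared_ptr<dtype> data_;
};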

View File

@@ -71,25 +71,23 @@ class Blob {
   DISABLE_COPY_AND_ASSIGN(Blob);
 };
 
 template <typename dtype, class Context>
 class Tensor {
  public:
-  Tensor() : ndim_(0), size_(0), data_(nullptr),
-      own_data_(true), data_source_(nullptr) {}
+  Tensor() : ndim_(0), size_(0), data_(nullptr) {}
 
   // Creates a tensor. The actual data allocation is going to be carried out
   // till the first time mutable_data() is called, so there is no overhead of
   // creating multiple tensors just as placeholders (although I haven't got a
   // clear idea where such cases would happen).
   explicit Tensor(const vector<int>& dims)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(dims);
   }
 
   template <class SrcContext>
   Tensor(const Tensor<dtype, SrcContext>& src, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(src.dims());
     context->template Copy<dtype, Context, SrcContext>(
         mutable_data(), src.data(), src.size());
@@ -98,7 +96,7 @@ class Tensor {
   // Creates a tensor, and fills its contents with the given values. We need to
   // have a context passed in as the copy function is device dependent.
   Tensor(const vector<int>& dims, vector<dtype> values, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(dims);
     CHECK_EQ(values.size(), size_);
     context->template Copy<dtype, Context, CPUContext>(
@@ -107,15 +105,13 @@ class Tensor {
   // Special case of above: create a tensor of shape 1, and the given value.
   Tensor(const dtype& value, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(std::vector<int>(1, 1));
     context->template Copy<dtype, Context, CPUContext>(
         mutable_data(), &value, 1);
   }
 
-  virtual ~Tensor() {
-    Free();
-  }
+  virtual ~Tensor() {}
 
   void Reshape(const vector<int>& dims) {
     CHECK_GT(dims.size(), 0);
@@ -127,10 +123,10 @@ class Tensor {
       CHECK_GT(d, 0);
       new_size *= d;
     }
-    // If the size changes, we will call Free(). The next data() call will
-    // re-allocate the memory.
-    if (data_ && size_ != new_size) {
-      Free();
+    // If the size changes, we will free the data. The next mutable_data()
+    // call will create the data storage.
+    if (data_.get() && size_ != new_size) {
+      data_.reset();
     }
     size_ = new_size;
   }
@@ -142,11 +138,19 @@ class Tensor {
   void ShareData(const Tensor& src) {
     // To share data, the sizes must be equal.
+    // The reason we do not force the ShareData to have an explicit reshape is
+    // because we want to allow tensors to have different shapes but still
+    // maintain the same underlying data storage, as long as the contents are
+    // of the same size.
     CHECK_EQ(src.size_, size_)
         << "Size mismatch - did you call reshape before sharing the data?";
-    if (data_) Free();
-    own_data_ = false;
-    data_source_ = &src;
+    // It is possible that the source tensor hasn't called mutable_data() yet,
+    // in which case ShareData() does not make much sense since we don't
+    // really know what to share yet.
+    CHECK(src.data_.get())
+        << "Source tensor has no content yet.";
+    // Finally, do sharing.
+    data_ = src.data_;
   }
 
   inline int ndim() const { return ndim_; }
@@ -159,49 +163,26 @@ class Tensor {
   }
 
   const dtype* data() const {
-    if (own_data_) {
-      CHECK_NOTNULL(data_);
-      return data_;
-    } else {
-      CHECK_NOTNULL(data_source_);
-      CHECK_EQ(data_source_->size_, size_) << "Source data size has changed.";
-      CHECK_NOTNULL(data_source_->data());
-      return data_source_->data();
-    }
+    CHECK_NOTNULL(data_.get());
+    return data_.get();
   }
 
   dtype* mutable_data() {
-    CHECK(own_data_) << "Cannot call mutable_data() from a shared tensor.";
-    CHECK_GT(size_, 0) << "Cannot call mutable_data on a size 0 tensor.";
-    if (!data_) Allocate();
-    CHECK_NOTNULL(data_);
-    return data_;
+    if (!data_.get()) Allocate();
+    return data_.get();
   }
 
   void Allocate() {
-    CHECK(data_ == nullptr);
     CHECK_GT(size_, 0);
-    data_ = static_cast<dtype*>(Context::New(size_ * sizeof(dtype)));
-  }
-
-  void Free() {
-    if (own_data_) {
-      if (data_) {
-        Context::Delete(data_);
-      }
-    }
-    own_data_ = true;
-    data_ = nullptr;
+    data_.reset(static_cast<dtype*>(Context::New(size_ * sizeof(dtype))),
+                Context::Delete);
   }
 
  protected:
   int ndim_;
   vector<int> dims_;
   int size_;
-  dtype* data_;
-  bool own_data_;
-  const Tensor* data_source_;
+  std::shared_ptr<dtype> data_;
   DISABLE_COPY_AND_ASSIGN(Tensor);
 };
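
The key move above is Allocate() handing the raw Context::New() buffer to a std::shared_ptr together with Context::Delete as a custom deleter, so whichever holder releases the buffer last frees it through the right function. A standalone sketch of that pattern, with RawNew/RawDelete as stand-ins for the context allocator (not Caffe2 code):

#include <cstdlib>
#include <iostream>
#include <memory>

// Stand-ins for Context::New / Context::Delete from the diff above.
static void* RawNew(std::size_t nbytes) { return std::malloc(nbytes); }
static void RawDelete(void* p) { std::free(p); }

int main() {
  // Allocate(): wrap the raw buffer; the deleter runs only when the last
  // owner releases the buffer.
  std::shared_ptr<float> data(
      static_cast<float*>(RawNew(10 * sizeof(float))), RawDelete);

  // ShareData(): sharing becomes a plain shared_ptr copy; no own_data_ flag
  // or data_source_ back-pointer is needed.
  std::shared_ptr<float> shared = data;
  std::cout << "use_count after sharing: " << data.use_count() << "\n";  // 2

  // Reshape() to a different size calls data_.reset(), which detaches this
  // holder only; the other holder keeps the old buffer alive.
  data.reset();
  std::cout << "other holder still valid: " << (shared.get() != nullptr)
            << "\n";  // 1
  return 0;
}

This is also why Free() and the reshape death tests disappear below: after a reshape the two tensors simply stop sharing, instead of the alias dying on a size check.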

View File

@@ -114,8 +114,8 @@ TYPED_TEST(TensorCPUTest, TensorShareData) {
   dims[2] = 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.mutable_data() != nullptr);
+  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
@@ -135,10 +135,10 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
   alternate_dims[0] = 2 * 3 * 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(alternate_dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
   EXPECT_EQ(other_tensor.ndim(), 1);
   EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
@@ -149,35 +149,30 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
   }
 }
 
-TYPED_TEST(TensorCPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) {
-  vector<int> dims(3);
-  dims[0] = 2;
-  dims[1] = 3;
-  dims[2] = 5;
-  Tensor<TypeParam, CPUContext> tensor(dims);
-  Tensor<TypeParam, CPUContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
-  ASSERT_DEATH(other_tensor.mutable_data(), "");
-}
-
-TYPED_TEST(TensorCPUDeathTest, CannotDoReshapewithAlias) {
+TYPED_TEST(TensorCPUTest, NoLongerSharesAfterReshape) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
+  EXPECT_EQ(tensor.data(), other_tensor.data());
+  auto* old_pointer = other_tensor.data();
   dims[0] = 7;
   tensor.Reshape(dims);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
-  ASSERT_DEATH(other_tensor.data(), ".*Source data size has changed..*");
+  EXPECT_EQ(old_pointer, other_tensor.data());
+  EXPECT_NE(old_pointer, tensor.mutable_data());
 }
 
 TYPED_TEST(TensorCPUDeathTest, CannotAccessDataWhenEmpty) {
   Tensor<TypeParam, CPUContext> tensor;
   EXPECT_EQ(tensor.ndim(), 0);
-  ASSERT_DEATH(tensor.data(), ".*Check failed: 'data_' Must be non NULL.*");
+  ASSERT_DEATH(tensor.data(), "");
 }
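
The reordered mutable_data() calls in these tests are not cosmetic: allocation is lazy, and the new ShareData() CHECKs that the source tensor already has storage. A hedged sketch of the ordering constraint, with a bare shared_ptr standing in for the tensor's data_ member (not Caffe2 code):

#include <cassert>
#include <memory>

int main() {
  std::shared_ptr<int> src;  // a fresh tensor: no storage allocated yet

  // Sharing here would trip CHECK(src.data_.get()) in the header diff,
  // which is why the tests now call mutable_data() before ShareData().
  src.reset(new int[2 * 3 * 5], std::default_delete<int[]>());

  std::shared_ptr<int> dst = src;  // ShareData(): copy the shared_ptr
  assert(dst.get() == src.get());  // both tensors see the same buffer
  return 0;
}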

View File

@@ -63,45 +63,55 @@ TYPED_TEST(TensorGPUTest, TensorShareData) {
   dims[2] = 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
   Tensor<TypeParam, CUDAContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.mutable_data() != nullptr);
+  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
 }
 
-TYPED_TEST(TensorGPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) {
-  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
+  vector<int> alternate_dims(1);
+  alternate_dims[0] = 2 * 3 * 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
-  Tensor<TypeParam, CUDAContext> other_tensor(dims);
+  Tensor<TypeParam, CUDAContext> other_tensor(alternate_dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
-  ASSERT_DEATH(other_tensor.mutable_data(), "");
+  EXPECT_EQ(other_tensor.ndim(), 1);
+  EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]);
+  EXPECT_TRUE(tensor.data() != nullptr);
+  EXPECT_TRUE(other_tensor.data() != nullptr);
+  EXPECT_EQ(tensor.data(), other_tensor.data());
 }
 
-TYPED_TEST(TensorGPUDeathTest, CannotDoReshapewithAlias) {
-  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+TYPED_TEST(TensorGPUTest, NoLongerSharesAfterReshape) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
   Tensor<TypeParam, CUDAContext> other_tensor(dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
+  EXPECT_EQ(tensor.data(), other_tensor.data());
+  auto* old_pointer = other_tensor.data();
   dims[0] = 7;
   tensor.Reshape(dims);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
-  ASSERT_DEATH(other_tensor.data(), "Source data size has changed.");
+  EXPECT_EQ(old_pointer, other_tensor.data());
+  EXPECT_NE(old_pointer, tensor.mutable_data());
 }
 
 TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
   ::testing::FLAGS_gtest_death_test_style = "threadsafe";
   Tensor<TypeParam, CUDAContext> tensor;
   EXPECT_EQ(tensor.ndim(), 0);
-  ASSERT_DEATH(tensor.data(), "Check failed: 'data_' Must be non NULL");
+  ASSERT_DEATH(tensor.data(), "");
 }
 
 } // namespace caffe2
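
One property of the shared_ptr design matters especially for this GPU file: the deleter is captured at allocation time, so whichever tensor releases the buffer last frees it through the matching Context::Delete (which for CUDAContext would be a device free, not a plain delete). A sketch with fake contexts; the names are illustrative, real code would call cudaMalloc/cudaFree on the GPU side (not Caffe2 code):

#include <cstdio>
#include <cstdlib>
#include <memory>

// Illustrative stand-ins for CPUContext / CUDAContext.
struct FakeCPUContext {
  static void* New(std::size_t n) { return std::malloc(n); }
  static void Delete(void* p) { std::puts("cpu free"); std::free(p); }
};
struct FakeCUDAContext {
  static void* New(std::size_t n) { return std::malloc(n); }
  static void Delete(void* p) { std::puts("gpu free"); std::free(p); }
};

// As in the header diff: the deleter travels with the pointer.
template <class Context>
std::shared_ptr<float> Allocate(std::size_t n) {
  return std::shared_ptr<float>(
      static_cast<float*>(Context::New(n * sizeof(float))), Context::Delete);
}

int main() {
  auto gpu_buf = Allocate<FakeCUDAContext>(10);
  std::shared_ptr<float> alias = gpu_buf;  // ShareData()
  gpu_buf.reset();  // first holder detaches; nothing is freed yet
  alias.reset();    // last holder detaches; prints "gpu free"
  return 0;
}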