Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Simplify the data sharing mechanism, using std::shared_ptr instead of home-brew code.
Also cleaned notebook notes a little bit.
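The whole change boils down to one idiom: replace the raw `data_` pointer plus the `own_data_`/`data_source_` bookkeeping with a single `std::shared_ptr` whose custom deleter travels with the pointer. A minimal standalone sketch of that idiom follows; the `CPUContext` here is a malloc/free stand-in I wrote for this note, assuming only the `Context::New`/`Context::Delete` interface visible in the diff, not the real Caffe2 context.

// Sketch: shared ownership with a custom deleter (hypothetical malloc-based
// stand-in for the Context::New/Context::Delete allocator interface).
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <memory>

struct CPUContext {
  static void* New(std::size_t nbytes) { return std::malloc(nbytes); }
  static void Delete(void* p) { std::free(p); }
};

int main() {
  // After this commit, Allocate() is essentially this construction: the
  // deleter rides along with the pointer, so whichever tensor drops the
  // last reference frees through the right allocator.
  std::shared_ptr<float> data(
      static_cast<float*>(CPUContext::New(10 * sizeof(float))),
      CPUContext::Delete);
  std::shared_ptr<float> alias = data;     // ShareData() reduces to a copy.
  std::cout << data.use_count() << "\n";   // 2: both owners keep it alive.
  data.reset();                            // what Reshape() to a new size does
  std::cout << alias.use_count() << "\n";  // 1: the alias still holds on.
}  // last owner gone; CPUContext::Delete runs exactly once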
@@ -71,25 +71,23 @@ class Blob {
   DISABLE_COPY_AND_ASSIGN(Blob);
 };
 
 
 template <typename dtype, class Context>
 class Tensor {
  public:
-  Tensor() : ndim_(0), size_(0), data_(nullptr),
-      own_data_(true), data_source_(nullptr) {}
+  Tensor() : ndim_(0), size_(0), data_(nullptr) {}
 
   // Creates a tensor. The actual data allocation is going to be carried out
   // till the first time mutable_data() is called, so there is no overhead of
   // creating multiple tensors just as placeholders (although I haven't got a
   // clear idea where such cases would happen).
   explicit Tensor(const vector<int>& dims)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(dims);
   }
 
   template <class SrcContext>
   Tensor(const Tensor<dtype, SrcContext>& src, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(src.dims());
     context->template Copy<dtype, Context, SrcContext>(
         mutable_data(), src.data(), src.size());
@@ -98,7 +96,7 @@ class Tensor {
   // Creates a tensor, and fills its contents with the given values. We need to
   // have a context passed in as the copy function is device dependent.
   Tensor(const vector<int>& dims, vector<dtype> values, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(dims);
     CHECK_EQ(values.size(), size_);
     context->template Copy<dtype, Context, CPUContext>(
@@ -107,15 +105,13 @@ class Tensor {
 
   // Special case of above: create a tensor of shape 1, and the given value.
   Tensor(const dtype& value, Context* context)
-      : data_(nullptr), own_data_(true), data_source_(nullptr) {
+      : data_(nullptr) {
     Reshape(std::vector<int>(1, 1));
     context->template Copy<dtype, Context, CPUContext>(
         mutable_data(), &value, 1);
   }
 
-  virtual ~Tensor() {
-    Free();
-  }
+  virtual ~Tensor() {}
 
   void Reshape(const vector<int>& dims) {
     CHECK_GT(dims.size(), 0);
@@ -127,10 +123,10 @@ class Tensor {
       CHECK_GT(d, 0);
       new_size *= d;
     }
-    // If the size changes, we will call Free(). The next data() call will
-    // re-allocate the memory.
-    if (data_ && size_ != new_size) {
-      Free();
+    // If the size changes, we will free the data. The next mutable_data()
+    // call will create the data storage.
+    if (data_.get() && size_ != new_size) {
+      data_.reset();
     }
     size_ = new_size;
   }
@@ -142,11 +138,19 @@ class Tensor {
 
   void ShareData(const Tensor& src) {
     // To share data, the sizes must be equal.
+    // The reason we do not force ShareData to do an explicit reshape is
+    // that we want to allow tensors to have different shapes but still
+    // maintain the same underlying data storage, as long as the contents
+    // are of the same size.
     CHECK_EQ(src.size_, size_)
         << "Size mismatch - did you call reshape before sharing the data?";
-    if (data_) Free();
-    own_data_ = false;
-    data_source_ = &src;
+    // It is possible that the source tensor hasn't called mutable_data()
+    // yet, in which case ShareData() does not make much sense since we
+    // don't really know what to share yet.
+    CHECK(src.data_.get())
+        << "Source tensor has no content yet.";
+    // Finally, do sharing.
+    data_ = src.data_;
   }
 
   inline int ndim() const { return ndim_; }
@@ -159,49 +163,26 @@ class Tensor {
   }
 
   const dtype* data() const {
-    if (own_data_) {
-      CHECK_NOTNULL(data_);
-      return data_;
-    } else {
-      CHECK_NOTNULL(data_source_);
-      CHECK_EQ(data_source_->size_, size_) << "Source data size has changed.";
-      CHECK_NOTNULL(data_source_->data());
-      return data_source_->data();
-    }
+    CHECK_NOTNULL(data_.get());
+    return data_.get();
   }
 
   dtype* mutable_data() {
-    CHECK(own_data_) << "Cannot call mutable_data() from a shared tensor.";
-    CHECK_GT(size_, 0) << "Cannot call mutable_data on a size 0 tensor.";
-    if (!data_) Allocate();
-    CHECK_NOTNULL(data_);
-    return data_;
+    if (!data_.get()) Allocate();
+    return data_.get();
   }
 
   void Allocate() {
-    CHECK(data_ == nullptr);
     CHECK_GT(size_, 0);
-    data_ = static_cast<dtype*>(Context::New(size_ * sizeof(dtype)));
-  }
-
-  void Free() {
-    if (own_data_) {
-      if (data_) {
-        Context::Delete(data_);
-      }
-    }
-    own_data_ = true;
-    data_ = nullptr;
+    data_.reset(static_cast<dtype*>(Context::New(size_ * sizeof(dtype))),
+                Context::Delete);
   }
 
  protected:
   int ndim_;
   vector<int> dims_;
   int size_;
-  dtype* data_;
-  bool own_data_;
-  const Tensor* data_source_;
+  std::shared_ptr<dtype> data_;
 
   DISABLE_COPY_AND_ASSIGN(Tensor);
 };
 
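Before the test hunks, it is worth spelling out the behavioral change they encode: an alias no longer aborts when its source tensor is reshaped; it simply keeps the old buffer alive while the source reallocates. A minimal CPU-only sketch of that semantics, with a hypothetical MiniTensor of my own standing in for the real class (no dims or Context machinery):

#include <cassert>
#include <memory>

// Hypothetical trimmed-down Tensor: same lazy allocation and sharing
// behavior as the diff above, reduced to a single size field.
class MiniTensor {
 public:
  explicit MiniTensor(int size) : size_(size) {}
  void Reshape(int size) {
    if (data_ && size_ != size) data_.reset();  // drop our reference only
    size_ = size;
  }
  void ShareData(const MiniTensor& src) {
    assert(src.size_ == size_);  // sizes must match to share
    assert(src.data_);           // source must have content to share
    data_ = src.data_;           // the entire sharing mechanism
  }
  float* mutable_data() {
    // Lazy allocation on first mutable access, as in the real class.
    if (!data_) data_.reset(new float[size_], std::default_delete<float[]>());
    return data_.get();
  }
 private:
  int size_ = 0;
  std::shared_ptr<float> data_;
};

int main() {
  MiniTensor a(30), b(30);
  a.mutable_data();                      // allocate before sharing
  b.ShareData(a);
  float* old_ptr = b.mutable_data();
  a.Reshape(70);                         // a drops its reference...
  assert(old_ptr == b.mutable_data());   // ...b still owns the old buffer
  assert(old_ptr != a.mutable_data());   // a reallocates independently
}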
@@ -114,8 +114,8 @@ TYPED_TEST(TensorCPUTest, TensorShareData) {
   dims[2] = 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.mutable_data() != nullptr);
+  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
@@ -135,10 +135,10 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
   alternate_dims[0] = 2 * 3 * 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(alternate_dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
   EXPECT_EQ(other_tensor.ndim(), 1);
   EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
@@ -149,35 +149,30 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
   }
 }
 
-TYPED_TEST(TensorCPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) {
-  vector<int> dims(3);
-  dims[0] = 2;
-  dims[1] = 3;
-  dims[2] = 5;
-  Tensor<TypeParam, CPUContext> tensor(dims);
-  Tensor<TypeParam, CPUContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
-  ASSERT_DEATH(other_tensor.mutable_data(), "");
-}
-
-TYPED_TEST(TensorCPUDeathTest, CannotDoReshapewithAlias) {
+TYPED_TEST(TensorCPUTest, NoLongerSharesAfterReshape) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor<TypeParam, CPUContext> tensor(dims);
   Tensor<TypeParam, CPUContext> other_tensor(dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
+  EXPECT_EQ(tensor.data(), other_tensor.data());
+  auto* old_pointer = other_tensor.data();
 
   dims[0] = 7;
   tensor.Reshape(dims);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
-  ASSERT_DEATH(other_tensor.data(), ".*Source data size has changed..*");
+  EXPECT_EQ(old_pointer, other_tensor.data());
+  EXPECT_NE(old_pointer, tensor.mutable_data());
 }
 
 
 TYPED_TEST(TensorCPUDeathTest, CannotAccessDataWhenEmpty) {
   Tensor<TypeParam, CPUContext> tensor;
   EXPECT_EQ(tensor.ndim(), 0);
-  ASSERT_DEATH(tensor.data(), ".*Check failed: 'data_' Must be non NULL.*");
+  ASSERT_DEATH(tensor.data(), "");
 }
 
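One detail the test hunks share: the ASSERT_DEATH regexes are relaxed to "". With data() now checking data_.get() instead of data_, the old "'data_' Must be non NULL" message no longer matches, and an empty pattern matches any death output. A minimal gtest/glog sketch of that mechanic, where CheckedData is a hypothetical stand-in I wrote to mirror Tensor::data()'s guard:

#include <memory>
#include <glog/logging.h>
#include <gtest/gtest.h>

// Hypothetical guard mirroring the new Tensor::data().
const float* CheckedData(const std::shared_ptr<float>& data) {
  CHECK_NOTNULL(data.get());  // aborts with "'data.get()' Must be non NULL"
  return data.get();
}

TEST(DeathMessageDemo, EmptyPatternMatchesAnyMessage) {
  std::shared_ptr<float> empty;
  // The empty regex matches whatever CHECK_NOTNULL prints, so the test no
  // longer depends on the exact member expression in the failure message.
  ASSERT_DEATH(CheckedData(empty), "");
}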
@@ -63,45 +63,55 @@ TYPED_TEST(TensorGPUTest, TensorShareData) {
   dims[2] = 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
   Tensor<TypeParam, CUDAContext> other_tensor(dims);
-  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.mutable_data() != nullptr);
+  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.data() != nullptr);
   EXPECT_TRUE(other_tensor.data() != nullptr);
   EXPECT_EQ(tensor.data(), other_tensor.data());
 }
 
-TYPED_TEST(TensorGPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) {
-  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
+  vector<int> alternate_dims(1);
+  alternate_dims[0] = 2 * 3 * 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
-  Tensor<TypeParam, CUDAContext> other_tensor(dims);
+  Tensor<TypeParam, CUDAContext> other_tensor(alternate_dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
-  ASSERT_DEATH(other_tensor.mutable_data(), "");
+  EXPECT_EQ(other_tensor.ndim(), 1);
+  EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]);
+  EXPECT_TRUE(tensor.data() != nullptr);
+  EXPECT_TRUE(other_tensor.data() != nullptr);
+  EXPECT_EQ(tensor.data(), other_tensor.data());
 }
 
-TYPED_TEST(TensorGPUDeathTest, CannotDoReshapewithAlias) {
-  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+TYPED_TEST(TensorGPUTest, NoLongerSharesAfterReshape) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor<TypeParam, CUDAContext> tensor(dims);
   Tensor<TypeParam, CUDAContext> other_tensor(dims);
+  EXPECT_TRUE(tensor.mutable_data() != nullptr);
   other_tensor.ShareData(tensor);
+  EXPECT_EQ(tensor.data(), other_tensor.data());
+  auto* old_pointer = other_tensor.data();
 
   dims[0] = 7;
   tensor.Reshape(dims);
-  EXPECT_TRUE(tensor.mutable_data() != nullptr);
-  ASSERT_DEATH(other_tensor.data(), "Source data size has changed.");
+  EXPECT_EQ(old_pointer, other_tensor.data());
+  EXPECT_NE(old_pointer, tensor.mutable_data());
 }
 
 
 TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
   ::testing::FLAGS_gtest_death_test_style = "threadsafe";
   Tensor<TypeParam, CUDAContext> tensor;
   EXPECT_EQ(tensor.ndim(), 0);
-  ASSERT_DEATH(tensor.data(), "Check failed: 'data_' Must be non NULL");
+  ASSERT_DEATH(tensor.data(), "");
 }
 
 }  // namespace caffe2