mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Simplify the data sharing mechanism, using std::shared_ptr instead of home-brew code.
Also cleaned notebook notes a little bit.
This commit is contained in:
@ -71,25 +71,23 @@ class Blob {
|
||||
DISABLE_COPY_AND_ASSIGN(Blob);
|
||||
};
|
||||
|
||||
|
||||
template <typename dtype, class Context>
|
||||
class Tensor {
|
||||
public:
|
||||
Tensor() : ndim_(0), size_(0), data_(nullptr),
|
||||
own_data_(true), data_source_(nullptr) {}
|
||||
Tensor() : ndim_(0), size_(0), data_(nullptr) {}
|
||||
|
||||
// Creates a tensor. The actual data allocation is going to be carried out
|
||||
// till the first time mutable_data() is called, so there is no overhead of
|
||||
// creating multiple tensors just as placeholders (although I haven't got a
|
||||
// clear idea where such cases would happen).
|
||||
explicit Tensor(const vector<int>& dims)
|
||||
: data_(nullptr), own_data_(true), data_source_(nullptr) {
|
||||
: data_(nullptr) {
|
||||
Reshape(dims);
|
||||
}
|
||||
|
||||
template <class SrcContext>
|
||||
Tensor(const Tensor<dtype, SrcContext>& src, Context* context)
|
||||
: data_(nullptr), own_data_(true), data_source_(nullptr) {
|
||||
: data_(nullptr) {
|
||||
Reshape(src.dims());
|
||||
context->template Copy<dtype, Context, SrcContext>(
|
||||
mutable_data(), src.data(), src.size());
|
||||
@ -98,7 +96,7 @@ class Tensor {
|
||||
// Creates a tensor, and fills its contents with the given values. We need to
|
||||
// have a context passed in as the copy function is device dependent.
|
||||
Tensor(const vector<int>& dims, vector<dtype> values, Context* context)
|
||||
: data_(nullptr), own_data_(true), data_source_(nullptr) {
|
||||
: data_(nullptr) {
|
||||
Reshape(dims);
|
||||
CHECK_EQ(values.size(), size_);
|
||||
context->template Copy<dtype, Context, CPUContext>(
|
||||
@ -107,15 +105,13 @@ class Tensor {
|
||||
|
||||
// Special case of above: create a tensor of shape 1, and the given value.
|
||||
Tensor(const dtype& value, Context* context)
|
||||
: data_(nullptr), own_data_(true), data_source_(nullptr) {
|
||||
: data_(nullptr) {
|
||||
Reshape(std::vector<int>(1, 1));
|
||||
context->template Copy<dtype, Context, CPUContext>(
|
||||
mutable_data(), &value, 1);
|
||||
mutable_data(), &value, 1);
|
||||
}
|
||||
|
||||
virtual ~Tensor() {
|
||||
Free();
|
||||
}
|
||||
virtual ~Tensor() {}
|
||||
|
||||
void Reshape(const vector<int>& dims) {
|
||||
CHECK_GT(dims.size(), 0);
|
||||
@ -127,10 +123,10 @@ class Tensor {
|
||||
CHECK_GT(d, 0);
|
||||
new_size *= d;
|
||||
}
|
||||
// If the size changes, we will call Free(). The next data() call will
|
||||
// re-allocate the memory.
|
||||
if (data_ && size_ != new_size) {
|
||||
Free();
|
||||
// If the size changes, we will free the data. the next mutable_data() call
|
||||
// will create the data storage.
|
||||
if (data_.get() && size_ != new_size) {
|
||||
data_.reset();
|
||||
}
|
||||
size_ = new_size;
|
||||
}
|
||||
@ -142,11 +138,19 @@ class Tensor {
|
||||
|
||||
void ShareData(const Tensor& src) {
|
||||
// To share data, the sizes must be equal.
|
||||
// The reason we do not force the ShareData to have an explicit reshape is
|
||||
// because we want to allow tensors to have different shape but still
|
||||
// maintain the same underlying data storage, as long as the contents are
|
||||
// of the same size.
|
||||
CHECK_EQ(src.size_, size_)
|
||||
<< "Size mismatch - did you call reshape before sharing the data?";
|
||||
if (data_) Free();
|
||||
own_data_ = false;
|
||||
data_source_ = &src;
|
||||
// It is possible that the source tensor hasn't called mutable_data() yet,
|
||||
// in which case ShareData() does make much sense since we don't really know
|
||||
// what to share yet.
|
||||
CHECK(src.data_.get())
|
||||
<< "Source tensor has no content yet.";
|
||||
// Finally, do sharing.
|
||||
data_ = src.data_;
|
||||
}
|
||||
|
||||
inline int ndim() const { return ndim_; }
|
||||
@ -159,49 +163,26 @@ class Tensor {
|
||||
}
|
||||
|
||||
const dtype* data() const {
|
||||
if (own_data_) {
|
||||
CHECK_NOTNULL(data_);
|
||||
return data_;
|
||||
} else {
|
||||
CHECK_NOTNULL(data_source_);
|
||||
CHECK_EQ(data_source_->size_, size_) << "Source data size has changed.";
|
||||
CHECK_NOTNULL(data_source_->data());
|
||||
return data_source_->data();
|
||||
}
|
||||
CHECK_NOTNULL(data_.get());
|
||||
return data_.get();
|
||||
}
|
||||
|
||||
dtype* mutable_data() {
|
||||
CHECK(own_data_) << "Cannot call mutable_data() from a shared tensor.";
|
||||
CHECK_GT(size_, 0) << "Cannot call mutable_data on a size 0 tensor.";
|
||||
if (!data_) Allocate();
|
||||
CHECK_NOTNULL(data_);
|
||||
return data_;
|
||||
if (!data_.get()) Allocate();
|
||||
return data_.get();
|
||||
}
|
||||
|
||||
void Allocate() {
|
||||
CHECK(data_ == nullptr);
|
||||
CHECK_GT(size_, 0);
|
||||
data_ = static_cast<dtype*>(Context::New(size_ * sizeof(dtype)));
|
||||
}
|
||||
|
||||
void Free() {
|
||||
if (own_data_) {
|
||||
if (data_) {
|
||||
Context::Delete(data_);
|
||||
}
|
||||
}
|
||||
own_data_ = true;
|
||||
data_ = nullptr;
|
||||
data_.reset(static_cast<dtype*>(Context::New(size_ * sizeof(dtype))),
|
||||
Context::Delete);
|
||||
}
|
||||
|
||||
protected:
|
||||
int ndim_;
|
||||
vector<int> dims_;
|
||||
int size_;
|
||||
dtype* data_;
|
||||
bool own_data_;
|
||||
const Tensor* data_source_;
|
||||
|
||||
std::shared_ptr<dtype> data_;
|
||||
DISABLE_COPY_AND_ASSIGN(Tensor);
|
||||
};
|
||||
|
||||
|
@ -114,8 +114,8 @@ TYPED_TEST(TensorCPUTest, TensorShareData) {
|
||||
dims[2] = 5;
|
||||
Tensor<TypeParam, CPUContext> tensor(dims);
|
||||
Tensor<TypeParam, CPUContext> other_tensor(dims);
|
||||
other_tensor.ShareData(tensor);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
other_tensor.ShareData(tensor);
|
||||
EXPECT_TRUE(tensor.data() != nullptr);
|
||||
EXPECT_TRUE(other_tensor.data() != nullptr);
|
||||
EXPECT_EQ(tensor.data(), other_tensor.data());
|
||||
@ -135,10 +135,10 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
|
||||
alternate_dims[0] = 2 * 3 * 5;
|
||||
Tensor<TypeParam, CPUContext> tensor(dims);
|
||||
Tensor<TypeParam, CPUContext> other_tensor(alternate_dims);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
other_tensor.ShareData(tensor);
|
||||
EXPECT_EQ(other_tensor.ndim(), 1);
|
||||
EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
EXPECT_TRUE(tensor.data() != nullptr);
|
||||
EXPECT_TRUE(other_tensor.data() != nullptr);
|
||||
EXPECT_EQ(tensor.data(), other_tensor.data());
|
||||
@ -149,35 +149,30 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorCPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) {
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
dims[2] = 5;
|
||||
Tensor<TypeParam, CPUContext> tensor(dims);
|
||||
Tensor<TypeParam, CPUContext> other_tensor(dims);
|
||||
other_tensor.ShareData(tensor);
|
||||
ASSERT_DEATH(other_tensor.mutable_data(), "");
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorCPUDeathTest, CannotDoReshapewithAlias) {
|
||||
TYPED_TEST(TensorCPUTest, NoLongerSharesAfterReshape) {
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
dims[2] = 5;
|
||||
Tensor<TypeParam, CPUContext> tensor(dims);
|
||||
Tensor<TypeParam, CPUContext> other_tensor(dims);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
other_tensor.ShareData(tensor);
|
||||
EXPECT_EQ(tensor.data(), other_tensor.data());
|
||||
auto* old_pointer = other_tensor.data();
|
||||
|
||||
dims[0] = 7;
|
||||
tensor.Reshape(dims);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
ASSERT_DEATH(other_tensor.data(), ".*Source data size has changed..*");
|
||||
EXPECT_EQ(old_pointer, other_tensor.data());
|
||||
EXPECT_NE(old_pointer, tensor.mutable_data());
|
||||
}
|
||||
|
||||
|
||||
TYPED_TEST(TensorCPUDeathTest, CannotAccessDataWhenEmpty) {
|
||||
Tensor<TypeParam, CPUContext> tensor;
|
||||
EXPECT_EQ(tensor.ndim(), 0);
|
||||
ASSERT_DEATH(tensor.data(), ".*Check failed: 'data_' Must be non NULL.*");
|
||||
ASSERT_DEATH(tensor.data(), "");
|
||||
}
|
||||
|
||||
|
||||
|
@ -63,45 +63,55 @@ TYPED_TEST(TensorGPUTest, TensorShareData) {
|
||||
dims[2] = 5;
|
||||
Tensor<TypeParam, CUDAContext> tensor(dims);
|
||||
Tensor<TypeParam, CUDAContext> other_tensor(dims);
|
||||
other_tensor.ShareData(tensor);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
other_tensor.ShareData(tensor);
|
||||
EXPECT_TRUE(tensor.data() != nullptr);
|
||||
EXPECT_TRUE(other_tensor.data() != nullptr);
|
||||
EXPECT_EQ(tensor.data(), other_tensor.data());
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorGPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) {
|
||||
::testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||
TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
dims[2] = 5;
|
||||
vector<int> alternate_dims(1);
|
||||
alternate_dims[0] = 2 * 3 * 5;
|
||||
Tensor<TypeParam, CUDAContext> tensor(dims);
|
||||
Tensor<TypeParam, CUDAContext> other_tensor(dims);
|
||||
Tensor<TypeParam, CUDAContext> other_tensor(alternate_dims);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
other_tensor.ShareData(tensor);
|
||||
ASSERT_DEATH(other_tensor.mutable_data(), "");
|
||||
EXPECT_EQ(other_tensor.ndim(), 1);
|
||||
EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]);
|
||||
EXPECT_TRUE(tensor.data() != nullptr);
|
||||
EXPECT_TRUE(other_tensor.data() != nullptr);
|
||||
EXPECT_EQ(tensor.data(), other_tensor.data());
|
||||
}
|
||||
|
||||
TYPED_TEST(TensorGPUDeathTest, CannotDoReshapewithAlias) {
|
||||
::testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||
TYPED_TEST(TensorGPUTest, NoLongerSharesAfterReshape) {
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
dims[2] = 5;
|
||||
Tensor<TypeParam, CUDAContext> tensor(dims);
|
||||
Tensor<TypeParam, CUDAContext> other_tensor(dims);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
other_tensor.ShareData(tensor);
|
||||
EXPECT_EQ(tensor.data(), other_tensor.data());
|
||||
auto* old_pointer = other_tensor.data();
|
||||
|
||||
dims[0] = 7;
|
||||
tensor.Reshape(dims);
|
||||
EXPECT_TRUE(tensor.mutable_data() != nullptr);
|
||||
ASSERT_DEATH(other_tensor.data(), "Source data size has changed.");
|
||||
EXPECT_EQ(old_pointer, other_tensor.data());
|
||||
EXPECT_NE(old_pointer, tensor.mutable_data());
|
||||
}
|
||||
|
||||
|
||||
TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) {
|
||||
::testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||
Tensor<TypeParam, CUDAContext> tensor;
|
||||
EXPECT_EQ(tensor.ndim(), 0);
|
||||
ASSERT_DEATH(tensor.data(), "Check failed: 'data_' Must be non NULL");
|
||||
ASSERT_DEATH(tensor.data(), "");
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
Reference in New Issue
Block a user