C++ API parity: at::Tensor::set_data

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26647

Test Plan: Imported from OSS

Differential Revision: D17542604

Pulled By: pbelevich

fbshipit-source-id: 37d5d67ebdb9348b5561d983f9bd26d310210983
This commit is contained in:
Pavel Belevich
2019-09-24 04:49:40 -07:00
committed by Facebook Github Bot
parent 2cf1183ec1
commit 450504cd95

View File

@ -489,3 +489,18 @@ TEST(TensorTest, DetachInplace) {
ASSERT_THROWS_WITH(x.detach_(), message);
ASSERT_THROWS_WITH(y.detach_(), message);
}
TEST(TensorTest, SetData) {
auto x = torch::randn({5});
auto y = torch::randn({5});
ASSERT_FALSE(torch::equal(x, y));
ASSERT_NE(x.data_ptr<float>(), y.data_ptr<float>());
x.set_data(y);
ASSERT_TRUE(torch::equal(x, y));
ASSERT_EQ(x.data_ptr<float>(), y.data_ptr<float>());
x = at::tensor({5});
y = at::tensor({5});
ASSERT_THROWS_WITH(x.set_data(y), "set_data is not implemented for Tensor");
}