Create at::tensor (#8475)

This commit is contained in:
Peter Goldsborough
2018-06-20 11:44:21 -07:00
committed by GitHub
parent b4cd9f2fc9
commit 9335885b1b
9 changed files with 142 additions and 13 deletions

View File

@ -4,6 +4,18 @@
#include <ATen/ATen.h>
#include <cmath>
template <typename T>
bool exactly_equal(at::Tensor left, T right) {
return at::Scalar(left).to<T>() == right;
}
template <typename T>
bool almost_equal(at::Tensor left, T right, T tolerance = 1e-4) {
return std::abs(at::Scalar(left).to<T>() - right) < tolerance;
}
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
REQUIRE(tensor.device().type() == at::Device((device_), (index_)).type()); \
REQUIRE(tensor.device().index() == at::Device((device_), (index_)).index()); \
@ -83,3 +95,69 @@ TEST_CASE("Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame") {
auto hopefully_not_copy = tensor.to(at::kFloat);
REQUIRE(hopefully_not_copy.data<float>() == tensor.data<float>());
}
TEST_CASE("Tensor/ContainsCorrectValueForSingleValue") {
auto tensor = at::tensor(123);
REQUIRE(tensor.numel() == 1);
REQUIRE(tensor.dtype() == at::kInt);
REQUIRE(tensor[0].toCInt() == 123);
tensor = at::tensor(123.456f);
REQUIRE(tensor.numel() == 1);
REQUIRE(tensor.dtype() == at::kFloat);
REQUIRE(almost_equal(tensor[0], 123.456f));
tensor = at::tensor(123.456);
REQUIRE(tensor.numel() == 1);
REQUIRE(tensor.dtype() == at::kDouble);
REQUIRE(almost_equal(tensor[0], 123.456));
}
TEST_CASE("Tensor/ContainsCorrectValuesForManyValues") {
auto tensor = at::tensor({1, 2, 3});
REQUIRE(tensor.numel() == 3);
REQUIRE(tensor.dtype() == at::kInt);
REQUIRE(exactly_equal(tensor[0], 1));
REQUIRE(exactly_equal(tensor[1], 2));
REQUIRE(exactly_equal(tensor[2], 3));
tensor = at::tensor({1.5, 2.25, 3.125});
REQUIRE(tensor.numel() == 3);
REQUIRE(tensor.dtype() == at::kDouble);
REQUIRE(almost_equal(tensor[0], 1.5));
REQUIRE(almost_equal(tensor[1], 2.25));
REQUIRE(almost_equal(tensor[2], 3.125));
}
TEST_CASE("Tensor/ContainsCorrectValuesWhenConstructedFromVector") {
std::vector<int> v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
auto tensor = at::tensor(v);
REQUIRE(tensor.numel() == v.size());
REQUIRE(tensor.dtype() == at::kInt);
for (size_t i = 0; i < v.size(); ++i) {
REQUIRE(exactly_equal(tensor[i], v.at(i)));
}
std::vector<float> w = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0};
tensor = at::tensor(w);
REQUIRE(tensor.numel() == w.size());
REQUIRE(tensor.dtype() == at::kFloat);
for (size_t i = 0; i < w.size(); ++i) {
REQUIRE(almost_equal(tensor[i], w.at(i)));
}
}
TEST_CASE("Tensor/UsesOptionsThatAreSupplied") {
auto tensor = at::tensor(123, dtype(at::kFloat)) + 0.5;
REQUIRE(tensor.numel() == 1);
REQUIRE(tensor.dtype() == at::kFloat);
REQUIRE(almost_equal(tensor[0], 123.5));
tensor = at::tensor({1.1, 2.2, 3.3}, dtype(at::kInt));
REQUIRE(tensor.numel() == 3);
REQUIRE(tensor.dtype() == at::kInt);
REQUIRE(tensor.layout() == at::kStrided);
REQUIRE(exactly_equal(tensor[0], 1));
REQUIRE(exactly_equal(tensor[1], 2));
REQUIRE(exactly_equal(tensor[2], 3));
}