Files
pytorch/test/cpp/api/tensor.cpp
Paweł Piskorski 2d9b2bcfba Extend TensorImpl with BackendMeta (#97429)
BackendMeta offers a binary interface for the backend to attach arbitrary data to TensorImpl. TensorImpl has exactly one "slot" for backend metadata; however, the backend is free to compose any structure that is opaque to the framework, beyond inheriting the standard BackendMeta base.

Change-Id: I670fcdd16dd1c2b00f7eaa1cbc5b5dfea59a6221

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97429
Approved by: https://github.com/ezyang
2023-04-04 23:47:03 +00:00

1261 lines
43 KiB
C++

#include <gtest/gtest.h>
#include <test/cpp/api/support.h>
#include <c10/util/irange.h>
#include <torch/torch.h>
#include <cmath>
#include <cstddef>
#include <vector>
#include <test/cpp/common/support.h>
using namespace torch::test;
/// Reads the single value held by `left` as a T (via item<T>()) and reports
/// whether it compares equal to `right` with operator==.
template <typename T>
bool exactly_equal(at::Tensor left, T right) {
  const T actual = left.item<T>();
  return actual == right;
}
/// Reads the single value held by `left` as a T (via item<T>()) and reports
/// whether std::abs of its difference from `right` is strictly below
/// `tolerance`. The callers in this file use it for real and complex T alike.
template <typename T>
bool almost_equal(at::Tensor left, T right, double tolerance = 1e-4) {
  const auto difference = left.item<T>() - right;
  return std::abs(difference) < tolerance;
}
// Asserts the options of a local variable named `tensor`, which must be in
// scope wherever this macro is expanded (it is not a macro parameter).
// Checks, in order: device type and device index (both taken from
// at::Device(device_, index_), so e.g. (at::kCPU, -1) is compared component
// by component), dtype, and layout. Each check is a fatal ASSERT_*, so the
// enclosing function returns on the first mismatch.
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
ASSERT_TRUE( \
tensor.device().type() == at::Device((device_), (index_)).type()); \
ASSERT_TRUE( \
tensor.device().index() == at::Device((device_), (index_)).index()); \
ASSERT_EQ(tensor.dtype(), (type_)); \
ASSERT_TRUE(tensor.layout() == (layout_))
TEST(TensorTest, ToDtype) {
  // A CPU float tensor is converted through Int -> Char -> Double twice:
  // first passing a bare ScalarType, then wrapping it in TensorOptions.
  // Device and layout must stay CPU / strided throughout.
  auto tensor = at::empty({3, 4});
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

  for (const auto target : {at::kInt, at::kChar, at::kDouble}) {
    tensor = tensor.to(target);
    REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, target, at::kStrided);
  }

  for (const auto target : {at::kInt, at::kChar, at::kDouble}) {
    tensor = tensor.to(at::TensorOptions(target));
    REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, target, at::kStrided);
  }
}
TEST(TensorTest, ToTensorAndTensorAttributes) {
  // A CPU float tensor is converted step by step, using reference tensors
  // (or individual attributes of them) as the conversion target.
  auto tensor = at::empty({3, 4});
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

  // Target: another tensor as a whole.
  const auto int_reference = at::empty({3, 4}, at::kInt);
  tensor = tensor.to(int_reference);
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);

  // Target: only a dtype, then only a device (which stays CPU).
  const auto double_reference = at::empty({3, 4}, at::kDouble);
  tensor = tensor.to(double_reference.dtype());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);
  tensor = tensor.to(double_reference.device());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);

  // Target: an explicit (device, dtype) pair.
  const auto long_reference = at::empty({3, 4}, at::kLong);
  tensor = tensor.to(long_reference.device(), long_reference.dtype());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kLong, at::kStrided);

  // Target: the full TensorOptions of a reference tensor.
  const auto options_reference = at::empty({3, 4}, at::kInt);
  tensor = tensor.to(options_reference.options());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);
}
// Not currently supported.
// TEST(TensorTest, ToLayout) {
// auto tensor = at::empty({3, 4});
// REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
//
// tensor = tensor.to(at::kSparse);
// REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kSparse);
//
// tensor = tensor.to(at::kStrided);
// REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
// }
TEST(TensorTest, ToOptionsWithRequiresGrad) {
  // The same sequence of checks runs for a tensor created with
  // requires_grad and for one created without it.
  for (const bool wants_grad : {true, false}) {
    auto tensor = wants_grad ? torch::empty({3, 4}, at::requires_grad())
                             : torch::empty({3, 4});
    ASSERT_EQ(tensor.requires_grad(), wants_grad);

    // A dtype-only conversion preserves the requires_grad flag.
    tensor = tensor.to(at::kDouble);
    ASSERT_EQ(tensor.requires_grad(), wants_grad);

    // Asking for requires_grad(true) through TensorOptions is rejected...
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    ASSERT_THROW(
        tensor.to(at::TensorOptions().requires_grad(true)), c10::Error);

    // ...but leaving it unset, or explicitly false, is accepted.
    tensor.to(at::TensorOptions());
    tensor.to(at::TensorOptions().requires_grad(false));
  }
}
TEST(TensorTest, ToDoesNotCopyWhenOptionsAreAllTheSame) {
  // Each case converts a fresh float tensor to options equal to its own and
  // checks that `to` returned a tensor sharing the very same data buffer.
  const auto fresh = [] { return at::empty({3, 4}, at::kFloat); };

  {  // same ScalarType
    const auto original = fresh();
    const auto result = original.to(at::kFloat);
    ASSERT_EQ(result.data_ptr<float>(), original.data_ptr<float>());
  }
  {  // same full TensorOptions
    const auto original = fresh();
    const auto result = original.to(original.options());
    ASSERT_EQ(result.data_ptr<float>(), original.data_ptr<float>());
  }
  {  // same dtype
    const auto original = fresh();
    const auto result = original.to(original.dtype());
    ASSERT_EQ(result.data_ptr<float>(), original.data_ptr<float>());
  }
  {  // same device
    const auto original = fresh();
    const auto result = original.to(original.device());
    ASSERT_EQ(result.data_ptr<float>(), original.data_ptr<float>());
  }
  {  // the tensor itself as the target
    const auto original = fresh();
    const auto result = original.to(original);
    ASSERT_EQ(result.data_ptr<float>(), original.data_ptr<float>());
  }
}
TEST(TensorTest, AtTensorCtorScalar) {
  // at::tensor(scalar) builds a one-element tensor. Its dtype follows the C++
  // type of the scalar unless a dtype is passed explicitly. Adding 0.5
  // afterwards must not change that dtype.

  // int -> Int
  auto t = at::tensor(123);
  ASSERT_EQ(t.numel(), 1);
  ASSERT_EQ(t.dtype(), at::kInt);
  ASSERT_EQ(t[0].item<int32_t>(), 123);

  // float -> Float
  t = at::tensor(123.456f);
  ASSERT_EQ(t.numel(), 1);
  ASSERT_EQ(t.dtype(), at::kFloat);
  ASSERT_TRUE(almost_equal(t[0], 123.456f));

  // double -> Double
  t = at::tensor(123.456);
  ASSERT_EQ(t.numel(), 1);
  ASSERT_EQ(t.dtype(), at::kDouble);
  ASSERT_TRUE(almost_equal(t[0], 123.456));

  // int with an explicit Float dtype
  t = at::tensor(123, at::dtype(at::kFloat)) + 0.5;
  ASSERT_EQ(t.numel(), 1);
  ASSERT_EQ(t.dtype(), at::kFloat);
  ASSERT_TRUE(almost_equal(t[0], 123.5));

  // complex<float> -> ComplexFloat, implicitly and explicitly
  t = at::tensor(c10::complex<float>(1.0, 2.0)) + 0.5;
  ASSERT_EQ(t.numel(), 1);
  ASSERT_EQ(t.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(t[0], c10::complex<float>(1.5, 2.0)));

  t = at::tensor(c10::complex<float>(1.0, 2.0), at::dtype(at::kComplexFloat)) +
      0.5;
  ASSERT_EQ(t.numel(), 1);
  ASSERT_EQ(t.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(t[0], c10::complex<float>(1.5, 2.0)));

  // complex<double> -> ComplexDouble; complex<float> upcast via dtype
  t = at::tensor(c10::complex<double>(1.0, 2.0)) + 0.5;
  ASSERT_EQ(t.numel(), 1);
  ASSERT_EQ(t.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(t[0], c10::complex<double>(1.5, 2.0)));

  t = at::tensor(c10::complex<float>(1.0, 2.0), at::dtype(at::kComplexDouble)) +
      0.5;
  ASSERT_EQ(t.numel(), 1);
  ASSERT_EQ(t.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(t[0], c10::complex<double>(1.5, 2.0)));
}
TEST(TensorTest, AtTensorCtorSingleDim) {
  // at::tensor from 1-D data (braced lists and std::vector) infers the dtype
  // from the element type unless one is given, and stores values in order.

  // Integers: braced list and std::vector<int> -> Int.
  auto tensor = at::tensor({1, 2, 3});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = at::tensor(std::vector<int>({1, 2, 3}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  // Doubles: braced list -> Double.
  tensor = at::tensor({1.5, 2.25, 3.125});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  // Complex braced lists -> ComplexFloat / ComplexDouble.
  tensor = at::tensor(
      {c10::complex<float>(1.5, 0.15),
       c10::complex<float>(1.5, 0.15),
       c10::complex<float>(3.125, 0.3125)});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<float>(3.125, 0.3125)));

  tensor = at::tensor(
      {c10::complex<double>(1.5, 0.15),
       c10::complex<double>(1.5, 0.15),
       c10::complex<double>(3.125, 0.3125)});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.125, 0.3125)));

  // An explicit Int dtype truncates the double inputs.
  tensor = at::tensor({1.1, 2.2, 3.3}, at::dtype(at::kInt));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_EQ(tensor.layout(), at::kStrided);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  // std::vector inputs of double and complex element types.
  tensor = at::tensor(std::vector<double>({1.5, 2.25, 3.125}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = at::tensor(std::vector<c10::complex<float>>(
      {c10::complex<float>(1.5, 0.15),
       c10::complex<float>(1.5, 0.15),
       c10::complex<float>(3.125, 0.3125)}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<float>(3.125, 0.3125)));

  tensor = at::tensor(std::vector<c10::complex<double>>(
      {c10::complex<double>(1.5, 0.15),
       c10::complex<double>(1.5, 0.15),
       c10::complex<double>(3.125, 0.3125)}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.125, 0.3125)));

  // Longer vectors, checked element by element. numel() is int64_t while
  // size() is size_t, so cast to compare two signed values and avoid a
  // signed/unsigned mismatch inside gtest's comparison.
  std::vector<int> v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  tensor = at::tensor(v);
  ASSERT_EQ(tensor.numel(), static_cast<int64_t>(v.size()));
  ASSERT_EQ(tensor.dtype(), at::kInt);
  for (const auto i : c10::irange(v.size())) {
    ASSERT_TRUE(exactly_equal(tensor[i], v.at(i)));
  }

  std::vector<double> w = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0};
  tensor = at::tensor(w);
  ASSERT_EQ(tensor.numel(), static_cast<int64_t>(w.size()));
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  for (const auto i : c10::irange(w.size())) {
    ASSERT_TRUE(almost_equal(tensor[i], w.at(i)));
  }

  std::vector<c10::complex<double>> x = {
      {1.1, -1.1},
      {2.2, -2.2},
      {3.3, -3.3},
      {4.4, -4.4},
      {5.5, -5.5},
      {6.6, -6.6},
      {7.7, -7.7},
      {8.8, -8.8},
      {9.9, -9.9},
      {10.0, -10.0}};
  tensor = at::tensor(x);
  ASSERT_EQ(tensor.numel(), static_cast<int64_t>(x.size()));
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  for (const auto i : c10::irange(x.size())) {
    ASSERT_TRUE(almost_equal(tensor[i], x.at(i)));
  }
}
TEST(TensorTest, AtTensorCastRealToComplex) {
  // Real doubles requested as ComplexDouble, first from a std::vector and
  // then from a braced list. Each value must come back as the real part.
  for (const bool from_vector : {true, false}) {
    const auto tensor = from_vector
        ? at::tensor(std::vector<double>({1.5, 2.5, 3.5}), at::kComplexDouble)
        : at::tensor({1.5, 2.5, 3.5}, at::kComplexDouble);
    ASSERT_EQ(tensor.numel(), 3);
    ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
    ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5)));
    ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(2.5)));
    ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.5)));
  }

  // A single real scalar is cast the same way.
  const auto single = at::tensor(1.5, at::kComplexDouble);
  ASSERT_EQ(single.numel(), 1);
  ASSERT_EQ(single.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(single[0], c10::complex<double>(1.5)));
}
TEST(TensorTest, AtTensorCastComplexToRealErrorChecks) {
  // Complex input with a real (Float) dtype is rejected with the same
  // message whether it arrives as a scalar, a braced list or a std::vector.

  // scalar
  ASSERT_THROWS_WITH(
      at::tensor(c10::complex<float>(0.1, 0.2), at::kFloat),
      "\"tensor_cpu\" not implemented for 'Float'");

  // braced list
  ASSERT_THROWS_WITH(
      at::tensor({c10::complex<float>(0.1, 0.2)}, at::kFloat),
      "\"tensor_cpu\" not implemented for 'Float'");

  // std::vector
  ASSERT_THROWS_WITH(
      at::tensor(
          std::vector<c10::complex<float>>{c10::complex<float>(0.1, 0.2)},
          at::kFloat),
      "\"tensor_cpu\" not implemented for 'Float'");
}
TEST(TensorTest, TorchTensorCtorScalarIntegralType) {
  // An integer literal passed to torch::tensor yields a 0-dim Long tensor.
  const auto scalar = torch::tensor(123);
  const std::vector<int64_t> zero_dim_shape{};
  ASSERT_EQ(scalar.numel(), 1);
  ASSERT_EQ(scalar.sizes(), zero_dim_shape);
  ASSERT_EQ(scalar.dtype(), at::kLong);
  ASSERT_EQ(scalar.item<int64_t>(), 123);
}
/// With `default_dtype` installed as the default, float and double literals
/// must produce 0-dim tensors and a braced {double} a shape-{1} tensor, all
/// of dtype `default_dtype`.
void test_TorchTensorCtorScalarFloatingType_expected_dtype(
    c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);
  const std::vector<int64_t> zero_dim_shape{};
  const std::vector<int64_t> single_element_shape{1};

  const auto from_float = torch::tensor(123.456f);
  ASSERT_EQ(from_float.numel(), 1);
  ASSERT_EQ(from_float.sizes(), zero_dim_shape);
  ASSERT_EQ(from_float.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(from_float, 123.456f));

  const auto from_double = torch::tensor(123.456);
  ASSERT_EQ(from_double.numel(), 1);
  ASSERT_EQ(from_double.sizes(), zero_dim_shape);
  ASSERT_EQ(from_double.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(from_double, 123.456));

  const auto from_list = torch::tensor({123.456});
  ASSERT_EQ(from_list.numel(), 1);
  ASSERT_EQ(from_list.sizes(), single_element_shape);
  ASSERT_EQ(from_list.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(from_list[0], 123.456));
}
TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
  // Exercise the scalar floating-point checks under both default dtypes.
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorScalarFloatingType_expected_dtype(default_dtype);
  }
}
TEST(TensorTest, TorchTensorCtorScalarBoolType) {
  // A bare bool gives a 0-dim Bool tensor; a braced {bool} gives shape {1}.
  const auto bare = torch::tensor(true);
  ASSERT_EQ(bare.numel(), 1);
  ASSERT_EQ(bare.sizes(), std::vector<int64_t>({}));
  ASSERT_EQ(bare.dtype(), at::kBool);
  ASSERT_TRUE(exactly_equal(bare, true));

  const auto braced = torch::tensor({true});
  ASSERT_EQ(braced.numel(), 1);
  ASSERT_EQ(braced.sizes(), std::vector<int64_t>({1}));
  ASSERT_EQ(braced.dtype(), at::kBool);
  ASSERT_TRUE(exactly_equal(braced[0], true));
}
TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) {
  // Every spelling of the integers {1, 2, 3} must produce the same 1-D Long
  // tensor: a braced list, and ArrayRef / std::vector of int and of int64_t.
  constexpr int kNumSpellings = 5;
  for (int spelling = 0; spelling < kNumSpellings; ++spelling) {
    at::Tensor tensor;
    switch (spelling) {
      case 0:
        tensor = torch::tensor({1, 2, 3});
        break;
      case 1:
        tensor = torch::tensor(at::ArrayRef<int>({1, 2, 3}));
        break;
      case 2:
        tensor = torch::tensor(std::vector<int>({1, 2, 3}));
        break;
      case 3:
        tensor = torch::tensor(at::ArrayRef<int64_t>({1, 2, 3}));
        break;
      default:
        tensor = torch::tensor(std::vector<int64_t>({1, 2, 3}));
        break;
    }
    ASSERT_EQ(tensor.numel(), 3);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
    ASSERT_EQ(tensor.dtype(), at::kLong);
    ASSERT_TRUE(exactly_equal(tensor[0], 1));
    ASSERT_TRUE(exactly_equal(tensor[1], 2));
    ASSERT_TRUE(exactly_equal(tensor[2], 3));
  }
}
/// With `default_dtype` installed as the default, 1-D floating-point input
/// given as a braced list, at::ArrayRef or std::vector, of float or double
/// elements, must produce a shape-{3} tensor of dtype `default_dtype` (not of
/// the C++ element type) holding 1.5, 2.25 and 3.125.
void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(
    c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  // Braced list of doubles.
  auto tensor = torch::tensor({1.5, 2.25, 3.125});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  // Braced list of floats.
  tensor = torch::tensor({1.5f, 2.25f, 3.125f});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5f));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25f));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125f));

  // ArrayRef<float>. The shape is checked here as for every other spelling.
  tensor = torch::tensor(at::ArrayRef<float>({1.5f, 2.25f, 3.125f}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  // std::vector<float>.
  tensor = torch::tensor(std::vector<float>({1.5f, 2.25f, 3.125f}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  // ArrayRef<double>. The shape is checked here as for every other spelling.
  tensor = torch::tensor(at::ArrayRef<double>({1.5, 2.25, 3.125}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  // std::vector<double>.
  tensor = torch::tensor(std::vector<double>({1.5, 2.25, 3.125}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));
}
TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
  // Exercise the 1-D floating-point checks under both default dtypes.
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorSingleDimFloatingType_expected_dtype(default_dtype);
  }
}
TEST(TensorTest, TorchTensorCtorSingleDimBoolType) {
  // A braced list of bools and an ArrayRef<bool> must both give the same
  // 1-D Bool tensor {true, false, true}.
  for (const bool via_array_ref : {false, true}) {
    const auto tensor = via_array_ref
        ? torch::tensor(at::ArrayRef<bool>({true, false, true}))
        : torch::tensor({true, false, true});
    ASSERT_EQ(tensor.numel(), 3);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
    ASSERT_EQ(tensor.dtype(), at::kBool);
    ASSERT_TRUE(exactly_equal(tensor[0], true));
    ASSERT_TRUE(exactly_equal(tensor[1], false));
    ASSERT_TRUE(exactly_equal(tensor[2], true));
  }
}
TEST(TensorTest, TorchTensorCtorMultiDimIntegralType) {
  // Nested braced lists of integers give Long tensors whose shape mirrors the
  // nesting. Values are compared with 1..end-1 reshaped to that shape.
  const auto counting_up_to = [](int64_t end, const torch::Tensor& like) {
    return torch::arange(1, end, torch::kLong).view(like.sizes());
  };

  const auto one_by_two = torch::tensor({{1, 2}});
  ASSERT_EQ(one_by_two.dtype(), torch::kLong);
  ASSERT_EQ(one_by_two.sizes(), std::vector<int64_t>({1, 2}));
  ASSERT_TRUE(torch::allclose(one_by_two, counting_up_to(3, one_by_two)));
  ASSERT_FALSE(one_by_two.requires_grad());

  const auto two_by_one = torch::tensor({{1}, {2}});
  ASSERT_EQ(two_by_one.dtype(), torch::kLong);
  ASSERT_EQ(two_by_one.sizes(), std::vector<int64_t>({2, 1}));
  ASSERT_TRUE(torch::allclose(two_by_one, counting_up_to(3, two_by_one)));
  ASSERT_FALSE(two_by_one.requires_grad());

  const auto shape_1_1_2 = torch::tensor({{{1, 2}}});
  ASSERT_EQ(shape_1_1_2.dtype(), torch::kLong);
  ASSERT_EQ(shape_1_1_2.sizes(), std::vector<int64_t>({1, 1, 2}));
  ASSERT_TRUE(torch::allclose(shape_1_1_2, counting_up_to(3, shape_1_1_2)));
  ASSERT_FALSE(shape_1_1_2.requires_grad());

  const auto shape_1_2_1 = torch::tensor({{{1}, {2}}});
  ASSERT_EQ(shape_1_2_1.dtype(), torch::kLong);
  ASSERT_EQ(shape_1_2_1.sizes(), std::vector<int64_t>({1, 2, 1}));
  ASSERT_TRUE(torch::allclose(shape_1_2_1, counting_up_to(3, shape_1_2_1)));
  ASSERT_FALSE(shape_1_2_1.requires_grad());

  const auto two_by_two = torch::tensor({{1, 2}, {3, 4}});
  ASSERT_EQ(two_by_two.dtype(), torch::kLong);
  ASSERT_EQ(two_by_two.sizes(), std::vector<int64_t>({2, 2}));
  ASSERT_TRUE(torch::allclose(two_by_two, counting_up_to(5, two_by_two)));
  ASSERT_FALSE(two_by_two.requires_grad());

  // Ten levels of nesting around a single element.
  const auto deep_single = torch::tensor({{{{{{{{{{1}}}}}}}}}});
  ASSERT_EQ(deep_single.dtype(), torch::kLong);
  ASSERT_EQ(
      deep_single.sizes(),
      std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
  ASSERT_TRUE(torch::allclose(
      deep_single,
      torch::full({1}, 1, torch::kLong).view(deep_single.sizes())));
  ASSERT_FALSE(deep_single.requires_grad());

  // Ten levels of nesting around a pair.
  const auto deep_pair = torch::tensor({{{{{{{{{{1, 2}}}}}}}}}});
  ASSERT_EQ(deep_pair.dtype(), torch::kLong);
  ASSERT_EQ(
      deep_pair.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 2}));
  ASSERT_TRUE(torch::allclose(deep_pair, counting_up_to(3, deep_pair)));
  ASSERT_FALSE(deep_pair.requires_grad());
}
/// With `default_dtype` installed as the default, nested lists of doubles
/// (2-D and 8-D) must take dtype `default_dtype`, match 1..N reshaped to
/// their shape, and not require grad.
void test_TorchTensorCtorMultiDimFloatingType_expected_dtype(
    c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  const auto matrix = torch::tensor({{1.0, 2.0}});
  ASSERT_EQ(matrix.dtype(), default_dtype);
  ASSERT_EQ(matrix.sizes(), std::vector<int64_t>({1, 2}));
  ASSERT_TRUE(torch::allclose(
      matrix, torch::arange(1, 3, default_dtype).view(matrix.sizes())));
  ASSERT_FALSE(matrix.requires_grad());

  const auto nested = torch::tensor(
      {{{{{{{{1.0, 2.0, 3.0}}}}},
         {{{{{4.0, 5.0, 6.0}}}}},
         {{{{{7.0, 8.0, 9.0}}}}}}}});
  ASSERT_EQ(nested.dtype(), default_dtype);
  ASSERT_EQ(nested.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
  ASSERT_TRUE(torch::allclose(
      nested, torch::arange(1, 10, default_dtype).view(nested.sizes())));
  ASSERT_FALSE(nested.requires_grad());
}
TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) {
  // Exercise the multi-dim floating-point checks under both default dtypes.
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorMultiDimFloatingType_expected_dtype(default_dtype);
  }
}
TEST(TensorTest, TorchTensorCtorMultiDimBoolType) {
  // Nested bool lists give Bool tensors; each is compared against a tensor
  // filled element by element.
  const auto row = torch::tensor({{true, false}});
  ASSERT_EQ(row.dtype(), torch::kBool);
  ASSERT_EQ(row.sizes(), std::vector<int64_t>({1, 2}));
  auto expected_row = torch::empty(row.sizes(), torch::kBool);
  expected_row[0][0] = true;
  expected_row[0][1] = false;
  ASSERT_TRUE(torch::equal(row, expected_row));
  ASSERT_FALSE(row.requires_grad());

  const auto column = torch::tensor({{true}, {false}});
  ASSERT_EQ(column.dtype(), torch::kBool);
  ASSERT_EQ(column.sizes(), std::vector<int64_t>({2, 1}));
  auto expected_column = torch::empty(column.sizes(), torch::kBool);
  expected_column[0][0] = true;
  expected_column[1][0] = false;
  ASSERT_TRUE(torch::equal(column, expected_column));
  ASSERT_FALSE(column.requires_grad());
}
TEST(TensorTest, TorchTensorCtorMultiDimWithOptions) {
  // An explicit dtype replaces the dtype inferred from integer data.
  const auto as_int = torch::tensor({{1, 2}}, torch::dtype(torch::kInt));
  ASSERT_EQ(as_int.dtype(), torch::kInt);
  ASSERT_EQ(as_int.sizes(), std::vector<int64_t>({1, 2}));
  ASSERT_TRUE(torch::allclose(
      as_int, torch::arange(1, 3, torch::kInt).view(as_int.sizes())));
  ASSERT_FALSE(as_int.requires_grad());

  // dtype and requires_grad can be combined in one options object.
  const auto as_float_with_grad = torch::tensor(
      {{1, 2}, {3, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_EQ(as_float_with_grad.dtype(), torch::kFloat);
  ASSERT_EQ(as_float_with_grad.sizes(), std::vector<int64_t>({2, 2}));
  ASSERT_TRUE(torch::allclose(
      as_float_with_grad,
      torch::arange(1, 5, torch::kFloat).view(as_float_with_grad.sizes())));
  ASSERT_TRUE(as_float_with_grad.requires_grad());
}
TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
  // Ragged nesting: sibling sub-lists must share the same size.
  ASSERT_THROWS_WITH(
      torch::tensor({{{2, 3, 4}, {{5, 6}, {7}}}}),
      "Expected all sub-lists to have sizes: 2 (e.g. {5, 6}), but got sub-list {7} with sizes: 1");

  // Mixed element types: every element must share the first one's type.
  ASSERT_THROWS_WITH(
      torch::tensor({{{1, 2.0}, {1, 2.0}}}),
      "Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double");
  ASSERT_THROWS_WITH(
      torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}),
      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double");
  ASSERT_THROWS_WITH(
      torch::tensor({{{true}, {2}}}),
      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int");
  ASSERT_THROWS_WITH(
      torch::tensor({{{true, 2}}}),
      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int");
}
TEST(TensorTest, TorchTensorCastRealToComplex) {
  // Real doubles requested as ComplexDouble, first from a std::vector and
  // then from a braced list. Each value must come back as the real part.
  for (const bool from_vector : {true, false}) {
    const auto tensor = from_vector
        ? torch::tensor(
              std::vector<double>({1.5, 2.5, 3.5}), torch::kComplexDouble)
        : torch::tensor({1.5, 2.5, 3.5}, torch::kComplexDouble);
    ASSERT_EQ(tensor.numel(), 3);
    ASSERT_EQ(tensor.dtype(), torch::kComplexDouble);
    ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5)));
    ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(2.5)));
    ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.5)));
  }

  // A single real scalar is cast the same way.
  const auto single = torch::tensor(1.5, torch::kComplexDouble);
  ASSERT_EQ(single.numel(), 1);
  ASSERT_EQ(single.dtype(), torch::kComplexDouble);
  ASSERT_TRUE(almost_equal(single, c10::complex<double>(1.5)));
}
TEST(TensorTest, TorchTensorCastComplexToRealErrorChecks) {
  // Complex input with a real (Float) dtype must be rejected. The scalar and
  // braced-list paths share one message; the std::vector path has its own.

  // scalar
  ASSERT_THROWS_WITH(
      torch::tensor(c10::complex<float>(0.1, 0.2), torch::kFloat),
      "value cannot be converted to type float without overflow");

  // braced list
  ASSERT_THROWS_WITH(
      torch::tensor(
          {c10::complex<float>(0.1, 0.2), c10::complex<float>(0.3, 0.4)},
          torch::kFloat),
      "value cannot be converted to type float without overflow");

  // std::vector
  ASSERT_THROWS_WITH(
      torch::tensor(
          std::vector<c10::complex<float>>{
              c10::complex<float>(0.1, 0.2), c10::complex<float>(0.3, 0.4)},
          torch::kFloat),
      "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information");
}
/// Builds an 8-D tensor directly on CUDA with dtype `default_dtype`, then
/// checks its device, dtype, shape, values (1..9 reshaped) and that it does
/// not require grad.
void test_TorchTensorCtorMultiDim_CUDA_expected_dtype(
    c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);
  const auto cuda_options = torch::dtype(default_dtype).device(torch::kCUDA);
  auto tensor = torch::tensor(
      {{{{{{{{1.0, 2.0, 3.0}}}}},
         {{{{{4.0, 5.0, 6.0}}}}},
         {{{{{7.0, 8.0, 9.0}}}}}}}},
      cuda_options);
  ASSERT_TRUE(tensor.device().is_cuda());
  ASSERT_EQ(tensor.dtype(), default_dtype);
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));

  // The reference values are built on CPU and moved to CUDA for comparison.
  const auto expected = torch::arange(1, 10, default_dtype)
                            .view(tensor.sizes())
                            .to(torch::kCUDA);
  ASSERT_TRUE(torch::allclose(tensor, expected));
  ASSERT_FALSE(tensor.requires_grad());
}
TEST(TensorTest, TorchTensorCtorMultiDim_CUDA) {
  // Exercise the CUDA multi-dim constructor checks under both default dtypes.
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorMultiDim_CUDA_expected_dtype(default_dtype);
  }
}
/// With `default_dtype` installed as the default, nested lists whose
/// innermost level is empty must produce zero-element tensors whose shape
/// records every nesting level (innermost size 0), with dtype `default_dtype`
/// and no grad requirement.
void test_TorchTensorCtorZeroSizedDim_expected_dtype(
    c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);
  constexpr int kNumCases = 7;
  for (int which = 0; which < kNumCases; ++which) {
    torch::Tensor tensor;
    std::vector<int64_t> expected_sizes;
    switch (which) {
      case 0:
        tensor = torch::tensor({});
        expected_sizes = {0};
        break;
      case 1:
        tensor = torch::tensor({{}, {}});
        expected_sizes = {2, 0};
        break;
      case 2:
        tensor = torch::tensor({{{}, {}}});
        expected_sizes = {1, 2, 0};
        break;
      case 3:
        tensor = torch::tensor({{{}}});
        expected_sizes = {1, 1, 0};
        break;
      case 4:
        tensor = torch::tensor({{{{{{{{}}}}}}}});
        expected_sizes = {1, 1, 1, 1, 1, 1, 1, 0};
        break;
      case 5:
        tensor = torch::tensor({{{{{{{{}}}}, {{{{}}}}}}}});
        expected_sizes = {1, 1, 1, 2, 1, 1, 1, 0};
        break;
      default:
        tensor = torch::tensor({{{{{{{{{{}}}}}}}}}});
        expected_sizes = {1, 1, 1, 1, 1, 1, 1, 1, 1, 0};
        break;
    }
    ASSERT_EQ(tensor.numel(), 0);
    ASSERT_EQ(tensor.sizes(), expected_sizes);
    ASSERT_EQ(tensor.dtype(), default_dtype);
    ASSERT_FALSE(tensor.requires_grad());
  }
}
TEST(TensorTest, TorchTensorCtorZeroSizedDim) {
  // Exercise the zero-sized-dimension checks under both default dtypes.
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorZeroSizedDim_expected_dtype(default_dtype);
  }
}
/// With `default_dtype` installed as the default, floating-point data given
/// without a dtype (or with options that leave dtype unset) must come out as
/// `default_dtype`, for both 1-D and 2-D input.
void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(
    c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);
  const torch::TensorOptions unset_dtype{};

  ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype);
  ASSERT_EQ(torch::tensor({1., 2., 3.}, unset_dtype).dtype(), default_dtype);
  ASSERT_EQ(
      torch::tensor({{1., 2., 3.}}, unset_dtype).dtype(), default_dtype);
}
TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) {
  // Integer data becomes Long, with or without dtype-less options.
  const torch::TensorOptions unset_dtype{};
  ASSERT_EQ(torch::tensor({1, 2, 3}).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor({{1, 2, 3}}).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor({1, 2, 3}, unset_dtype).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor({{1, 2, 3}}, unset_dtype).dtype(), torch::kLong);

  // Floating-point data follows the default dtype; check both settings.
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(default_dtype);
  }
}
/// With `default_dtype` installed as the default and options that do not
/// set a dtype: integer input (braced, ArrayRef<int>, std::vector<int>) must
/// give Long, and float/double input in the same three forms must give
/// `default_dtype`.
void test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(
    c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);
  const torch::TensorOptions no_dtype{};

  // Integer element types.
  ASSERT_EQ(torch::tensor({1, 2, 3}, no_dtype).dtype(), torch::kLong);
  ASSERT_EQ(
      torch::tensor(at::ArrayRef<int>({1, 2, 3}), no_dtype).dtype(),
      torch::kLong);
  ASSERT_EQ(
      torch::tensor(std::vector<int>({1, 2, 3}), no_dtype).dtype(),
      torch::kLong);

  // double element type.
  ASSERT_EQ(torch::tensor({1., 2., 3.}, no_dtype).dtype(), default_dtype);
  ASSERT_EQ(
      torch::tensor(at::ArrayRef<double>({1., 2., 3.}), no_dtype).dtype(),
      default_dtype);
  ASSERT_EQ(
      torch::tensor(std::vector<double>({1., 2., 3.}), no_dtype).dtype(),
      default_dtype);

  // float element type.
  ASSERT_EQ(torch::tensor({1.f, 2.f, 3.f}, no_dtype).dtype(), default_dtype);
  ASSERT_EQ(
      torch::tensor(at::ArrayRef<float>({1.f, 2.f, 3.f}), no_dtype).dtype(),
      default_dtype);
  ASSERT_EQ(
      torch::tensor(std::vector<float>({1.f, 2.f, 3.f}), no_dtype).dtype(),
      default_dtype);
}
TEST(TensorTest, TorchTensorCtorWithNonDtypeOptions) {
  // Exercise the dtype-less-options checks under both default dtypes.
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(default_dtype);
  }
}
/// With `default_dtype` installed as the default, arange with a
/// floating-point start must produce a tensor of dtype `default_dtype`.
void test_Arange_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);
  const auto floating_range = torch::arange(0., 5);
  ASSERT_EQ(floating_range.dtype(), default_dtype);
}
TEST(TensorTest, Arange) {
  // Integer bounds give a Long range...
  const auto integral_range = torch::arange(0, 5);
  ASSERT_EQ(integral_range.dtype(), torch::kLong);

  // ...while a floating start follows the default dtype; check both settings.
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_Arange_expected_dtype(default_dtype);
  }
}
TEST(TensorTest, PrettyPrintTensorDataContainer) {
  // c10::str renders a scalar container as the bare value and a list
  // container as brace-delimited, comma-separated entries, at any depth.
  using Container = torch::detail::TensorDataContainer;

  // Scalars and shallow lists.
  ASSERT_EQ(c10::str(Container(1.1)), "1.1");
  ASSERT_EQ(c10::str(Container({1.1, 2.2})), "{1.1, 2.2}");
  ASSERT_EQ(c10::str(Container({{1, 2}, {3, 4}})), "{{1, 2}, {3, 4}}");

  // Deeply nested lists, including an empty innermost list.
  ASSERT_EQ(
      c10::str(Container(
          {{{{{{{{1.1, 2.2, 3.3}}}}},
             {{{{{4.4, 5.5, 6.6}}}}},
             {{{{{7.7, 8.8, 9.9}}}}}}}})),
      "{{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}}");
  ASSERT_EQ(
      c10::str(Container({{{{{{{{{{1}}}}}}}}}})), "{{{{{{{{{{1}}}}}}}}}}");
  ASSERT_EQ(c10::str(Container({{{{{{{{{{}}}}}}}}}})), "{{{{{{{{{{}}}}}}}}}}");
  ASSERT_EQ(
      c10::str(Container({{{{{{{{{{1, 2}}}}}}}}}})),
      "{{{{{{{{{{1, 2}}}}}}}}}}");

  // ArrayRef and std::vector sources print like a braced list.
  ASSERT_EQ(
      c10::str(Container(at::ArrayRef<double>({1.1, 2.2}))), "{1.1, 2.2}");
  ASSERT_EQ(
      c10::str(Container(std::vector<double>({1.1, 2.2}))), "{1.1, 2.2}");
}
TEST(TensorTest, TensorDataContainerCallingAccessorOfWrongType) {
  // Each accessor must refuse to run on a container of a different kind.

  // Scalar container: init_list() and tensor() are invalid.
  ASSERT_THROWS_WITH(
      torch::detail::TensorDataContainer(1.1).init_list(),
      "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
  ASSERT_THROWS_WITH(
      torch::detail::TensorDataContainer(1.1).tensor(),
      "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");

  // Braced-list container: scalar() and tensor() are invalid.
  ASSERT_THROWS_WITH(
      torch::detail::TensorDataContainer({1.1, 2.2}).scalar(),
      "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
  ASSERT_THROWS_WITH(
      torch::detail::TensorDataContainer({1.1, 2.2}).tensor(),
      "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");

  // ArrayRef-built container: scalar() and init_list() are invalid.
  ASSERT_THROWS_WITH(
      torch::detail::TensorDataContainer(at::ArrayRef<double>({1.1, 2.2}))
          .scalar(),
      "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
  ASSERT_THROWS_WITH(
      torch::detail::TensorDataContainer(at::ArrayRef<double>({1.1, 2.2}))
          .init_list(),
      "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
}
TEST(TensorTest, FromBlob) {
  // Wrap an existing buffer of doubles; the passed options (dtype and
  // requires_grad) must be honoured and the values read in place.
  std::vector<double> values = {1.0, 2.0, 3.0};
  auto tensor = torch::from_blob(
      values.data(),
      values.size(),
      torch::dtype(torch::kFloat64).requires_grad(true));
  ASSERT_TRUE(tensor.requires_grad());
  ASSERT_EQ(tensor.dtype(), torch::kFloat64);
  ASSERT_EQ(tensor.numel(), 3);
  for (int64_t index = 0; index < 3; ++index) {
    ASSERT_EQ(tensor[index].item<double>(), index + 1);
  }
  // No deleter was supplied, so the storage's deleter context is null.
  ASSERT_EQ(tensor.storage().data_ptr().get_context(), nullptr);
}
TEST(TensorTest, FromBlobUsesDeleter) {
  // from_blob with a deleter must invoke it when the tensor is destroyed, and
  // must hand it the same pointer that was wrapped.
  bool called = false;
  void* deleted_ptr = nullptr;
  void* wrapped_ptr = nullptr;
  {
    std::vector<int32_t> v = {1, 2, 3};
    wrapped_ptr = v.data();
    // `tensor` is declared after `v`, so it is destroyed first and the
    // deleter runs while `v` is still alive.
    auto tensor = torch::from_blob(
        v.data(),
        v.size(),
        /*deleter=*/
        [&called, &deleted_ptr](void* data) {
          called = true;
          deleted_ptr = data;
        },
        torch::kInt32);
  }
  ASSERT_TRUE(called);
  ASSERT_EQ(deleted_ptr, wrapped_ptr);
}
TEST(TensorTest, FromBlobWithStrides) {
  // Wraps a row-major 3x3 buffer with swapped strides {1, 3}. The resulting
  // view reads the buffer in column-major order (the transpose).
  // clang-format off
  std::vector<int32_t> v = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9
  };
  // clang-format on
  auto tensor = torch::from_blob(
      v.data(),
      /*sizes=*/{3, 3},
      /*strides=*/{1, 3},
      torch::kInt32);
  ASSERT_EQ(tensor.dtype(), torch::kInt32);
  ASSERT_EQ(tensor.numel(), 9);
  const std::vector<int64_t> expected_strides = {1, 3};
  ASSERT_EQ(tensor.strides(), expected_strides);
  for (const auto i : c10::irange(tensor.size(0))) {
    for (const auto j : c10::irange(tensor.size(1))) {
      // from_blob does not copy, so element (i, j) must come from flat offset
      // i * stride(0) + j * stride(1) of `v`. Deriving the offset from the
      // strides (not from size(1)) keeps the check valid for non-square shapes.
      const int64_t offset = i * tensor.stride(0) + j * tensor.stride(1);
      EXPECT_EQ(tensor[i][j].item<int32_t>(), v[static_cast<size_t>(offset)]);
    }
  }
}
TEST(TensorTest, Item) {
  // A floating-point scalar tensor converts back through item().
  {
    const torch::Scalar value = torch::tensor(3.14).item();
    ASSERT_NEAR(value.to<float>(), 3.14, 1e-5);
  }
  // An integral scalar tensor converts back through item().
  {
    const torch::Scalar value = torch::tensor(123).item();
    ASSERT_EQ(value.to<int>(), 123);
  }
}
TEST(TensorTest, Item_CUDA) {
  // Same as the CPU Item test, but the scalar tensors live on a CUDA device.
  {
    const torch::Scalar value = torch::tensor(3.14, torch::kCUDA).item();
    ASSERT_NEAR(value.to<float>(), 3.14, 1e-5);
  }
  {
    const torch::Scalar value = torch::tensor(123, torch::kCUDA).item();
    ASSERT_EQ(value.to<int>(), 123);
  }
}
TEST(TensorTest, DataPtr) {
  // Converting a tensor to its own options must not copy: the result aliases
  // the same memory through both the typed and untyped data_ptr accessors.
  auto source = at::empty({3, 4}, at::kFloat);
  const auto same_options = source.options();
  auto converted = source.to(same_options);
  ASSERT_EQ(converted.data_ptr<float>(), source.data_ptr<float>());
  ASSERT_EQ(converted.data_ptr(), source.data_ptr());
}
TEST(TensorTest, Data) {
  // data() must expose the same values as the tensor it came from.
  const auto source = torch::rand({3, 3});
  const auto data_tensor = source.data();
  ASSERT_TRUE(torch::equal(source, data_tensor));
}
TEST(TensorTest, BackwardAndGrad) {
  // d(x * x)/dx = 2 * x, which is 10 at x = 5.
  const auto grad_options = torch::dtype(torch::kFloat).requires_grad(true);
  auto input = torch::tensor({5}, grad_options);
  auto squared = input * input;
  squared.backward();
  ASSERT_EQ(input.grad().item<float>(), 10.0);
}
TEST(TensorTest, BackwardCreatesOnesGrad) {
  // Calling backward() directly on a single-element leaf, with no explicit
  // gradient, leaves a grad equal to ones_like(leaf).
  const auto grad_options = torch::dtype(torch::kFloat).requires_grad(true);
  const auto leaf = torch::tensor({5}, grad_options);
  leaf.backward();
  const auto expected = torch::ones_like(leaf);
  ASSERT_TRUE(torch::equal(leaf.grad(), expected));
}
TEST(TensorTest, BackwardNonScalarOutputs) {
  // Implicit gradient creation is restricted to scalar outputs, so calling
  // backward() on a 5x5 result must throw.
  auto input = torch::randn({5, 5}, torch::requires_grad());
  auto product = input * input;
  ASSERT_THROWS_WITH(
      product.backward(),
      "grad can be implicitly created only for scalar outputs");
}
TEST(TensorTest, BackwardComplexScalarOutput) {
  // Scaling by a complex scalar and summing gives a scalar that is not real.
  // Implicit backward() must reject it.
  auto input = torch::randn({5, 5}, torch::requires_grad());
  const c10::Scalar half_i{c10::complex<float>(0, 0.5)};
  auto complex_sum = (input * half_i).sum();
  ASSERT_THROWS_WITH(
      complex_sum.backward(),
      "grad can be computed only for real scalar outputs");
}
TEST(TensorTest, IsLeaf) {
  // A directly constructed tensor is a leaf; the output of an op on it is not.
  const auto grad_options = torch::dtype(torch::kFloat).requires_grad(true);
  auto leaf = torch::tensor({5}, grad_options);
  auto derived = leaf * leaf;
  ASSERT_TRUE(leaf.is_leaf());
  ASSERT_FALSE(derived.is_leaf());
}
TEST(TensorTest, OutputNr) {
  // Both the leaf and the result of a single-output op report output_nr 0.
  const auto grad_options = torch::dtype(torch::kFloat).requires_grad(true);
  auto leaf = torch::tensor({5}, grad_options);
  auto derived = leaf * leaf;
  ASSERT_EQ(leaf.output_nr(), 0);
  ASSERT_EQ(derived.output_nr(), 0);
}
TEST(TensorTest, Version) {
  // Every in-place op bumps the version counter by exactly one.
  auto values = torch::ones(3);
  int64_t expected_version = 0;
  ASSERT_EQ(values._version(), expected_version);

  values.mul_(2);
  expected_version += 1;
  ASSERT_EQ(values._version(), expected_version);

  values.add_(1);
  expected_version += 1;
  ASSERT_EQ(values._version(), expected_version);
}
TEST(TensorTest, Detach) {
  // detach() returns a leaf that does not require grad and leaves the source
  // (a non-leaf) unchanged.
  const auto grad_options = torch::dtype(torch::kFloat).requires_grad(true);
  auto leaf = torch::tensor({5}, grad_options);
  auto product = leaf * leaf;
  const auto detached = product.detach();
  ASSERT_FALSE(product.is_leaf());
  ASSERT_TRUE(detached.is_leaf());
  ASSERT_FALSE(detached.requires_grad());
}
TEST(TensorTest, DetachInplace) {
  // detach_() turns the tensor itself into a leaf that no longer requires
  // grad, and the handle it returns reports the same state.
  const auto grad_options = torch::dtype(torch::kFloat).requires_grad(true);
  auto leaf = torch::tensor({5}, grad_options);
  auto product = leaf * leaf;
  auto returned = product.detach_();
  ASSERT_TRUE(product.is_leaf());
  ASSERT_FALSE(product.requires_grad());
  ASSERT_TRUE(returned.is_leaf());
  ASSERT_FALSE(returned.requires_grad());
}
TEST(TensorTest, SetData) {
  // Two independent random tensors start out distinct in value and memory.
  // After set_data() the destination shares the source's data.
  auto destination = torch::randn({5});
  auto source = torch::randn({5});
  ASSERT_FALSE(torch::equal(destination, source));
  ASSERT_NE(destination.data_ptr<float>(), source.data_ptr<float>());

  destination.set_data(source);
  ASSERT_TRUE(torch::equal(destination, source));
  ASSERT_EQ(destination.data_ptr<float>(), source.data_ptr<float>());
}
TEST(TensorTest, RequiresGradInplace) {
  // requires_grad_ can be toggled on a floating-point leaf...
  auto leaf = torch::tensor({5.0});
  leaf.requires_grad_(true);
  ASSERT_TRUE(leaf.requires_grad());

  // ...but not on a non-leaf produced by an op.
  auto product = leaf * leaf;
  ASSERT_THROWS_WITH(
      product.requires_grad_(false),
      "you can only change requires_grad flags of leaf variables.");

  leaf.requires_grad_(false);
  ASSERT_FALSE(leaf.requires_grad());

  // Integer tensors are rejected outright.
  const auto int_options = at::TensorOptions().dtype(torch::kInt);
  const auto integral = torch::tensor({5}, int_options);
  ASSERT_THROWS_WITH(
      integral.requires_grad_(true),
      "Only Tensors of floating point and complex dtype can require gradients");
}
TEST(TensorTest, StdDimension) {
  // Test that std(0) doesn't select the std(unbiased=False) overload
  // (gh-40287). A bare `0` must be taken as the reduction dim, so every call
  // below reduces the 4x3 input over dim 0 and yields 3 elements.
  // (No local may be named `std` here: it would shadow the namespace name
  // used by the std::get calls below and confuse readers.)
  auto x = torch::randn({4, 3});
  ASSERT_EQ(x.var(0).numel(), 3);
  ASSERT_EQ(x.std(0).numel(), 3);
  ASSERT_EQ(x.var(0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(x.std(0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(torch::var(x, 0).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::var_mean(x, 0)).numel(), 3);
  ASSERT_EQ(torch::std(x, 0).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::std_mean(x, 0)).numel(), 3);
  ASSERT_EQ(torch::var(x, 0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::var_mean(x, 0, /*unbiased=*/true)).numel(), 3);
  ASSERT_EQ(torch::std(x, 0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::std_mean(x, 0, /*unbiased=*/true)).numel(), 3);
}
TEST(TensorTest, ReshapeAlias) {
  // The private _reshape_alias operator must agree with as_strided and view
  // for the same sizes/strides, in forward values and in gradients.
  auto base = torch::randn({3, 3});
  const auto aliased_2x2 = torch::_reshape_alias(base, {2, 2}, {1, 2});
  const auto strided_2x2 = torch::as_strided(base, {2, 2}, {1, 2});
  ASSERT_TRUE(torch::equal(aliased_2x2, strided_2x2));
  const auto aliased_flat = torch::_reshape_alias(base, {9}, {1});
  ASSERT_TRUE(torch::equal(aliased_flat, base.view({-1})));

  // Backward: identical inputs flattened via view vs. _reshape_alias must
  // produce identical gradients.
  auto via_view = torch::randn({3, 3}, torch::requires_grad(true));
  auto via_alias = torch::clone(via_view).detach().requires_grad_(true);
  (via_view * via_view).view({-1}).mean().backward();
  torch::_reshape_alias(via_alias * via_alias, {9}, {1}).mean().backward();
  ASSERT_TRUE(torch::equal(via_view.grad(), via_alias.grad()));
}
TEST(TensorTest, BackendMetadata) {
  // Tests the ability to assign custom backend metadata to a tensor, and that
  // shallow_copy_from() hands it to the destination through clone().

  // Metadata type that records whether clone() was invoked.
  struct CustomBackendMetadata : public c10::BackendMeta {
    mutable bool cloned_{false}; // for testing this field will mutate when
                                 // clone() is called by shallow_copy_from.
    c10::intrusive_ptr<c10::BackendMeta> clone(
        const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
      cloned_ = true;
      return c10::BackendMeta::clone(ptr);
    }
  };

  // `y` must own its own TensorImpl. A default-constructed at::Tensor is
  // undefined and refers to the process-wide UndefinedTensorImpl singleton, so
  // shallow_copy_from() on it would write x's metadata into that shared object
  // and leak state into every other undefined tensor in the process.
  at::Tensor y = torch::empty({0});
  c10::intrusive_ptr<c10::BackendMeta> tmeta{};
  // Non-owning view of the metadata; it stays alive because `tmeta` holds a
  // reference to the same object.
  CustomBackendMetadata* custom_tmeta{nullptr};

  {
    auto x = torch::ones({3, 3});
    auto impl{x.unsafeGetTensorImpl()};
    ASSERT_TRUE(impl != nullptr);

    // A freshly created tensor carries no backend metadata.
    tmeta = impl->get_backend_meta_intrusive_ptr();
    ASSERT_TRUE(tmeta == nullptr);

    c10::intrusive_ptr<c10::BackendMeta> new_tmeta{
        std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata())};
    impl->set_backend_meta(new_tmeta);
    tmeta = impl->get_backend_meta_intrusive_ptr();
    ASSERT_TRUE(tmeta == new_tmeta);
    custom_tmeta = dynamic_cast<CustomBackendMetadata*>(tmeta.get());
    ASSERT_TRUE(custom_tmeta != nullptr);
    ASSERT_FALSE(custom_tmeta->cloned_);

    // Copy x's metadata into y; x is destroyed at the end of this scope.
    y.unsafeGetTensorImpl()->shallow_copy_from(x.getIntrusivePtr());
  }

  // y now exposes the very same metadata object, obtained via clone().
  ASSERT_TRUE(
      tmeta == y.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr());
  ASSERT_TRUE(tmeta.get() == y.unsafeGetTensorImpl()->get_backend_meta());
  ASSERT_TRUE(custom_tmeta->cloned_);
}