mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support multidimensional inputs to torch::tensor (#26210)
Summary: This PR adds support for multidimensional inputs to `torch::tensor`, to match the Python `torch.tensor` API. Closes https://github.com/pytorch/pytorch/issues/16099. Pull Request resolved: https://github.com/pytorch/pytorch/pull/26210 Differential Revision: D17456761 Pulled By: yf225 fbshipit-source-id: a53ce74c535c13c5dcb833f19e9b6b79d12376b5
This commit is contained in:
committed by
Facebook Github Bot
parent
436c60a854
commit
aad0263a6b
@ -9,6 +9,8 @@
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
#include <test/cpp/common/support.h>
|
||||
|
||||
template <typename T>
|
||||
bool exactly_equal(at::Tensor left, T right) {
|
||||
return left.item<T>() == right;
|
||||
@ -195,6 +197,80 @@ TEST(TensorTest, ContainsCorrectValuesForManyValuesVariable) {
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
}
|
||||
|
||||
TEST(TensorTest, MultidimTensorCtor) {
|
||||
{
|
||||
auto tensor = torch::tensor({{1, 2}, {3, 4}});
|
||||
ASSERT_EQ(tensor.dtype(), torch::kInt);
|
||||
ASSERT_EQ(tensor.sizes(), torch::IntArrayRef({2, 2}));
|
||||
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kInt).view(tensor.sizes())));
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{1, 2}, {3, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
|
||||
ASSERT_EQ(tensor.dtype(), torch::kFloat);
|
||||
ASSERT_EQ(tensor.sizes(), torch::IntArrayRef({2, 2}));
|
||||
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kFloat).view(tensor.sizes())));
|
||||
ASSERT_TRUE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}});
|
||||
ASSERT_EQ(tensor.dtype(), torch::kDouble);
|
||||
ASSERT_EQ(tensor.sizes(), torch::IntArrayRef({1, 1, 3, 1, 1, 1, 1, 3}));
|
||||
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, torch::kDouble).view(tensor.sizes())));
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{2, 3, 4}, {{5, 6}, {7}}}}),
|
||||
"Expected all sub-lists to have sizes: 2 (e.g. {5, 6}), but got sub-list {7} with sizes: 1");
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{1, 2.0}, {1, 2.0}}}),
|
||||
"Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double");
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}),
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorTest, MultidimTensorCtor_CUDA) {
|
||||
{
|
||||
auto tensor = torch::tensor(
|
||||
{{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}},
|
||||
torch::dtype(torch::kDouble).device(torch::kCUDA));
|
||||
ASSERT_TRUE(tensor.device().is_cuda());
|
||||
ASSERT_EQ(tensor.dtype(), torch::kDouble);
|
||||
ASSERT_EQ(tensor.sizes(), torch::IntArrayRef({1, 1, 3, 1, 1, 1, 1, 3}));
|
||||
ASSERT_TRUE(torch::allclose(
|
||||
tensor,
|
||||
torch::arange(1, 10, torch::kDouble).view(tensor.sizes()).to(torch::kCUDA)));
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorTest, PrettyPrintListInitTensor) {
|
||||
{
|
||||
ASSERT_EQ(
|
||||
c10::str(torch::detail::ListInitTensor(1.1)),
|
||||
"1.1");
|
||||
}
|
||||
{
|
||||
ASSERT_EQ(
|
||||
c10::str(torch::detail::ListInitTensor({1.1, 2.2})),
|
||||
"{1.1, 2.2}");
|
||||
}
|
||||
{
|
||||
ASSERT_EQ(
|
||||
c10::str(torch::detail::ListInitTensor({{1, 2}, {3, 4}})),
|
||||
"{{1, 2}, {3, 4}}");
|
||||
}
|
||||
{
|
||||
ASSERT_EQ(
|
||||
c10::str(torch::detail::ListInitTensor({{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}})),
|
||||
"{{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}}");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorTest, ContainsCorrectValuesWhenConstructedFromVector) {
|
||||
std::vector<int> v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
|
||||
auto tensor = at::tensor(v);
|
||||
|
Reference in New Issue
Block a user