Use at::kLong for torch::tensor(integer_value) when dtype is not specified (#29066)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29066

This PR is BC-breaking in the following way: previously, C++ `torch::tensor` called with an integer literal or a braced-init-list of integer literals produced a tensor whose dtype was the type of the literal(s). After this PR, it always produces a tensor of dtype `at::kLong` (i.e. `int64_t`), matching the behavior of Python `torch.tensor`.

Test Plan: Imported from OSS

Differential Revision: D18307248

Pulled By: yf225

fbshipit-source-id: 7a8a2eefa113cbb238f23264843bdb3b77fec668
Committed by: Facebook Github Bot
Parent: 1189f559cc
Commit: 026fd36c71
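For illustration, a minimal standalone sketch of the new default-dtype behavior (not part of the patch; it assumes a program linked against libtorch at this commit, and mirrors the assertions in the updated tests below):

```cpp
#include <torch/torch.h>
#include <cassert>

int main() {
  // An integer literal now defaults to kLong (int64_t), matching Python's
  // torch.tensor(123); before this PR it inferred kInt from the literal type.
  auto scalar = torch::tensor(123);
  assert(scalar.dtype() == torch::kLong);

  // Braced-init-lists of integer literals also default to kLong.
  auto matrix = torch::tensor({{1, 2}, {3, 4}});
  assert(matrix.dtype() == torch::kLong);

  // The old dtype can still be requested explicitly via TensorOptions.
  auto ints = torch::tensor({1, 2, 3}, torch::dtype(torch::kInt));
  assert(ints.dtype() == torch::kInt);

  // Floating-point and bool literals are unaffected by this change.
  assert(torch::tensor(123.456f).dtype() == torch::kFloat);
  assert(torch::tensor(true).dtype() == torch::kBool);
}
```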
test/cpp/api/tensor.cpp

@@ -220,14 +220,16 @@ TEST(TensorTest, AtTensorCtorSingleDim) {
   }
 }
 
-TEST(TensorTest, TorchTensorCtorScalar) {
+TEST(TensorTest, TorchTensorCtorScalarIntegralType) {
   auto tensor = torch::tensor(123);
   ASSERT_EQ(tensor.numel(), 1);
   ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
-  ASSERT_EQ(tensor.dtype(), at::kInt);
+  ASSERT_EQ(tensor.dtype(), at::kLong);
   ASSERT_EQ(tensor.item<int32_t>(), 123);
+}
 
-  tensor = torch::tensor(123.456f);
+TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
+  auto tensor = torch::tensor(123.456f);
   ASSERT_EQ(tensor.numel(), 1);
   ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
   ASSERT_EQ(tensor.dtype(), at::kFloat);
@@ -244,8 +246,10 @@ TEST(TensorTest, TorchTensorCtorScalar) {
   ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1}));
   ASSERT_EQ(tensor.dtype(), at::kDouble);
   ASSERT_TRUE(almost_equal(tensor[0], 123.456));
+}
 
-  tensor = torch::tensor(true);
+TEST(TensorTest, TorchTensorCtorScalarBoolType) {
+  auto tensor = torch::tensor(true);
   ASSERT_EQ(tensor.numel(), 1);
   ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
   ASSERT_EQ(tensor.dtype(), at::kBool);
@@ -258,12 +262,12 @@ TEST(TensorTest, TorchTensorCtorScalar) {
   ASSERT_TRUE(exactly_equal(tensor[0], true));
 }
 
-TEST(TensorTest, TorchTensorCtorSingleDim) {
+TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) {
   auto tensor = torch::tensor({1, 2, 3});
   ASSERT_TRUE(tensor.is_variable());
   ASSERT_EQ(tensor.numel(), 3);
   ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
-  ASSERT_EQ(tensor.dtype(), at::kInt);
+  ASSERT_EQ(tensor.dtype(), at::kLong);
   ASSERT_TRUE(exactly_equal(tensor[0], 1));
   ASSERT_TRUE(exactly_equal(tensor[1], 2));
   ASSERT_TRUE(exactly_equal(tensor[2], 3));
@@ -286,7 +290,27 @@ TEST(TensorTest, TorchTensorCtorSingleDim) {
   ASSERT_TRUE(exactly_equal(tensor[1], 2));
   ASSERT_TRUE(exactly_equal(tensor[2], 3));
 
-  tensor = torch::tensor({1.5, 2.25, 3.125});
+  tensor = torch::tensor(at::ArrayRef<int64_t>({1, 2, 3}));
+  ASSERT_TRUE(tensor.is_variable());
+  ASSERT_EQ(tensor.numel(), 3);
+  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
+  ASSERT_EQ(tensor.dtype(), at::kLong);
+  ASSERT_TRUE(exactly_equal(tensor[0], 1));
+  ASSERT_TRUE(exactly_equal(tensor[1], 2));
+  ASSERT_TRUE(exactly_equal(tensor[2], 3));
+
+  tensor = torch::tensor(std::vector<int64_t>({1, 2, 3}));
+  ASSERT_TRUE(tensor.is_variable());
+  ASSERT_EQ(tensor.numel(), 3);
+  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
+  ASSERT_EQ(tensor.dtype(), at::kLong);
+  ASSERT_TRUE(exactly_equal(tensor[0], 1));
+  ASSERT_TRUE(exactly_equal(tensor[1], 2));
+  ASSERT_TRUE(exactly_equal(tensor[2], 3));
+}
+
+TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
+  auto tensor = torch::tensor({1.5, 2.25, 3.125});
   ASSERT_TRUE(tensor.is_variable());
   ASSERT_EQ(tensor.numel(), 3);
   ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
@@ -311,8 +335,10 @@ TEST(TensorTest, TorchTensorCtorSingleDim) {
   ASSERT_TRUE(almost_equal(tensor[0], 1.5));
   ASSERT_TRUE(almost_equal(tensor[1], 2.25));
   ASSERT_TRUE(almost_equal(tensor[2], 3.125));
+}
 
-  tensor = torch::tensor({true, false, true});
+TEST(TensorTest, TorchTensorCtorSingleDimBoolType) {
+  auto tensor = torch::tensor({true, false, true});
   ASSERT_TRUE(tensor.is_variable());
   ASSERT_EQ(tensor.numel(), 3);
   ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
@@ -331,24 +357,59 @@ TEST(TensorTest, TorchTensorCtorSingleDim) {
   ASSERT_TRUE(exactly_equal(tensor[2], true));
 }
 
-TEST(TensorTest, TorchTensorCtorMultiDim) {
+TEST(TensorTest, TorchTensorCtorMultiDimIntegralType) {
   {
     auto tensor = torch::tensor({{1, 2}});
-    ASSERT_EQ(tensor.dtype(), torch::kInt);
+    ASSERT_EQ(tensor.dtype(), torch::kLong);
     ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
-    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes())));
+    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
     ASSERT_FALSE(tensor.requires_grad());
   }
   {
-    auto tensor = torch::tensor({{true, false}});
-    ASSERT_EQ(tensor.dtype(), torch::kBool);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
-    auto expected = torch::empty(tensor.sizes(), torch::kBool);
-    expected[0][0] = true;
-    expected[0][1] = false;
-    ASSERT_TRUE(torch::equal(tensor, expected));
+    auto tensor = torch::tensor({{1}, {2}});
+    ASSERT_EQ(tensor.dtype(), torch::kLong);
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 1}));
+    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
     ASSERT_FALSE(tensor.requires_grad());
   }
   {
+    auto tensor = torch::tensor({{{1, 2}}});
+    ASSERT_EQ(tensor.dtype(), torch::kLong);
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 2}));
+    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
+    ASSERT_FALSE(tensor.requires_grad());
+  }
+  {
+    auto tensor = torch::tensor({{{1}, {2}}});
+    ASSERT_EQ(tensor.dtype(), torch::kLong);
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2, 1}));
+    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
+    ASSERT_FALSE(tensor.requires_grad());
+  }
+  {
+    auto tensor = torch::tensor({{1, 2}, {3, 4}});
+    ASSERT_EQ(tensor.dtype(), torch::kLong);
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 2}));
+    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kLong).view(tensor.sizes())));
+    ASSERT_FALSE(tensor.requires_grad());
+  }
+  {
+    auto tensor = torch::tensor({{{{{{{{{{1}}}}}}}}}});
+    ASSERT_EQ(tensor.dtype(), torch::kLong);
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
+    ASSERT_TRUE(torch::allclose(tensor, torch::full({1}, 1, torch::kLong).view(tensor.sizes())));
+    ASSERT_FALSE(tensor.requires_grad());
+  }
+  {
+    auto tensor = torch::tensor({{{{{{{{{{1, 2}}}}}}}}}});
+    ASSERT_EQ(tensor.dtype(), torch::kLong);
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 2}));
+    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kLong).view(tensor.sizes())));
+    ASSERT_FALSE(tensor.requires_grad());
+  }
+}
+
+TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) {
+  {
     auto tensor = torch::tensor({{1.0, 2.0}});
     ASSERT_EQ(tensor.dtype(), torch::kDouble);
@@ -357,17 +418,23 @@ TEST(TensorTest, TorchTensorCtorMultiDim) {
     ASSERT_FALSE(tensor.requires_grad());
   }
   {
-    auto tensor = torch::tensor({{1, 2}}, torch::dtype(torch::kInt));
-    ASSERT_EQ(tensor.dtype(), torch::kInt);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
-    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes())));
+    auto tensor = torch::tensor({{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}});
+    ASSERT_EQ(tensor.dtype(), torch::kDouble);
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
+    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, torch::kDouble).view(tensor.sizes())));
     ASSERT_FALSE(tensor.requires_grad());
   }
+}
+
+TEST(TensorTest, TorchTensorCtorMultiDimBoolType) {
   {
-    auto tensor = torch::tensor({{1}, {2}});
-    ASSERT_EQ(tensor.dtype(), torch::kInt);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 1}));
-    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes())));
+    auto tensor = torch::tensor({{true, false}});
+    ASSERT_EQ(tensor.dtype(), torch::kBool);
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
+    auto expected = torch::empty(tensor.sizes(), torch::kBool);
+    expected[0][0] = true;
+    expected[0][1] = false;
+    ASSERT_TRUE(torch::equal(tensor, expected));
     ASSERT_FALSE(tensor.requires_grad());
   }
   {
@@ -380,27 +447,16 @@ TEST(TensorTest, TorchTensorCtorMultiDim) {
     ASSERT_TRUE(torch::equal(tensor, expected));
     ASSERT_FALSE(tensor.requires_grad());
   }
+}
+
+TEST(TensorTest, TorchTensorCtorMultiDimWithOptions) {
   {
-    auto tensor = torch::tensor({{{1, 2}}});
+    auto tensor = torch::tensor({{1, 2}}, torch::dtype(torch::kInt));
     ASSERT_EQ(tensor.dtype(), torch::kInt);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 2}));
+    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
     ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes())));
     ASSERT_FALSE(tensor.requires_grad());
   }
   {
-    auto tensor = torch::tensor({{{1}, {2}}});
-    ASSERT_EQ(tensor.dtype(), torch::kInt);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2, 1}));
-    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes())));
-    ASSERT_FALSE(tensor.requires_grad());
-  }
-  {
-    auto tensor = torch::tensor({{1, 2}, {3, 4}});
-    ASSERT_EQ(tensor.dtype(), torch::kInt);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 2}));
-    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kInt).view(tensor.sizes())));
-    ASSERT_FALSE(tensor.requires_grad());
-  }
-  {
     auto tensor = torch::tensor({{1, 2}, {3, 4}}, torch::dtype(torch::kFloat).requires_grad(true));
     ASSERT_EQ(tensor.dtype(), torch::kFloat);
@@ -408,20 +464,16 @@ TEST(TensorTest, TorchTensorCtorMultiDim) {
     ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kFloat).view(tensor.sizes())));
     ASSERT_TRUE(tensor.requires_grad());
   }
-  {
-    auto tensor = torch::tensor({{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}});
-    ASSERT_EQ(tensor.dtype(), torch::kDouble);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
-    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, torch::kDouble).view(tensor.sizes())));
-    ASSERT_FALSE(tensor.requires_grad());
-  }
+}
+
+TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
   {
     ASSERT_THROWS_WITH(torch::tensor({{{2, 3, 4}, {{5, 6}, {7}}}}),
       "Expected all sub-lists to have sizes: 2 (e.g. {5, 6}), but got sub-list {7} with sizes: 1");
   }
   {
     ASSERT_THROWS_WITH(torch::tensor({{{1, 2.0}, {1, 2.0}}}),
-      "Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double");
+      "Expected all elements of the tensor to have the same scalar type: Long, but got element of scalar type: Double");
   }
   {
     ASSERT_THROWS_WITH(torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}),
@@ -429,25 +481,11 @@ TEST(TensorTest, TorchTensorCtorMultiDim) {
   }
   {
     ASSERT_THROWS_WITH(torch::tensor({{{true}, {2}}}),
-      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int");
+      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Long");
   }
   {
     ASSERT_THROWS_WITH(torch::tensor({{{true, 2}}}),
-      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int");
-  }
-  {
-    auto tensor = torch::tensor({{{{{{{{{{1}}}}}}}}}});
-    ASSERT_EQ(tensor.dtype(), torch::kInt);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
-    ASSERT_TRUE(torch::allclose(tensor, torch::full({1}, 1, torch::kInt).view(tensor.sizes())));
-    ASSERT_FALSE(tensor.requires_grad());
-  }
-  {
-    auto tensor = torch::tensor({{{{{{{{{{1, 2}}}}}}}}}});
-    ASSERT_EQ(tensor.dtype(), torch::kInt);
-    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 2}));
-    ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kInt).view(tensor.sizes())));
-    ASSERT_FALSE(tensor.requires_grad());
+      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Long");
   }
 }
 
@@ -518,14 +556,14 @@ TEST(TensorTest, TorchTensorCtorZeroSizedDim) {
   }
 }
 
-TEST(TensorTest, TorchTensorCtorPreservesInitListDtype) {
-  ASSERT_EQ(torch::tensor({1, 2, 3}).dtype(), torch::kInt);
-  ASSERT_EQ(torch::tensor({{1, 2, 3}}).dtype(), torch::kInt);
+TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) {
+  ASSERT_EQ(torch::tensor({1, 2, 3}).dtype(), torch::kLong);
+  ASSERT_EQ(torch::tensor({{1, 2, 3}}).dtype(), torch::kLong);
   ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), torch::kDouble);
   ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), torch::kDouble);
 
-  ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kInt);
-  ASSERT_EQ(torch::tensor({{1, 2, 3}}, torch::TensorOptions()).dtype(), torch::kInt);
+  ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong);
+  ASSERT_EQ(torch::tensor({{1, 2, 3}}, torch::TensorOptions()).dtype(), torch::kLong);
   ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), torch::kDouble);
   ASSERT_EQ(torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), torch::kDouble);
 }