mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Use default dtype for torch::tensor(floating_point_values) and torch::tensor(empty braced-init-list) when dtype is not specified (#29632)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29632 This PR is BC-breaking in the following way: Previously, C++ `torch::tensor` with a floating-point literal with no suffix (e.g. `torch::tensor(1.1)`) or a (nested) braced-init-list of floating-point literals with no suffix (e.g. `torch::tensor({{1.1, 2.2}})` produces a tensor with dtype `at::kDouble`. After this PR, it produces a tensor with dtype `torch::get_default_dtype()`, matching Python `torch.tensor` behavior. Test Plan: Imported from OSS Differential Revision: D18465819 Pulled By: yf225 fbshipit-source-id: 6834fe50335c677bc3832f2a5e9cf8d1ede9f665
This commit is contained in:
committed by
Facebook Github Bot
parent
3fb9bbc99b
commit
2bcac59a30
@ -22,6 +22,7 @@ set(TORCH_API_TEST_SOURCES
|
||||
${TORCH_API_TEST_DIR}/sequential.cpp
|
||||
${TORCH_API_TEST_DIR}/serialize.cpp
|
||||
${TORCH_API_TEST_DIR}/static.cpp
|
||||
${TORCH_API_TEST_DIR}/support.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_cuda.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_options_cuda.cpp
|
||||
${TORCH_API_TEST_DIR}/tensor_options.cpp
|
||||
|
@ -462,7 +462,7 @@ TEST_F(FunctionalTest, GridSample) {
|
||||
TEST_F(FunctionalTest, AffineGrid) {
|
||||
{
|
||||
// 2D affine.
|
||||
auto theta = torch::arange(1, 13, torch::kDouble)
|
||||
auto theta = torch::arange(1., 13)
|
||||
.view(std::vector<int64_t>({2, 2, 3}));
|
||||
auto size = std::vector<int64_t>({2, 3, 2, 2});
|
||||
auto align_corners = true;
|
||||
@ -480,7 +480,7 @@ TEST_F(FunctionalTest, AffineGrid) {
|
||||
}
|
||||
{
|
||||
// 3D affine.
|
||||
auto theta = torch::arange(1, 13, torch::kDouble)
|
||||
auto theta = torch::arange(1., 13)
|
||||
.view(std::vector<int64_t>({1, 3, 4}));
|
||||
auto size = std::vector<int64_t>({1, 1, 3, 2, 2});
|
||||
auto align_corners = true;
|
||||
|
@ -29,8 +29,9 @@ void check_exact_values(
|
||||
}
|
||||
|
||||
for (size_t p = 0; p < layerParameters.size(0); p++) {
|
||||
auto tensor = layerParameters[p];
|
||||
auto expectedTensor = expectedLayerParameters[p];
|
||||
// Always compare using double dtype, regardless of the original dtype of the tensors
|
||||
auto tensor = layerParameters[p].to(torch::kFloat64);
|
||||
auto expectedTensor = expectedLayerParameters[p].to(torch::kFloat64);
|
||||
|
||||
if (!tensor.allclose(expectedTensor, /*rtol=*/1e-3, /*atol=*/5e-4)) {
|
||||
std::cout << "layer " << i << ": " << tensor << " != " << expectedTensor
|
||||
|
@ -92,16 +92,16 @@ void check_exact_values(
|
||||
assign_parameter(
|
||||
parameters,
|
||||
"0.weight",
|
||||
torch::tensor({-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}));
|
||||
torch::tensor({-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}, torch::kFloat64));
|
||||
assign_parameter(
|
||||
parameters, "0.bias", torch::tensor({-0.1085, -0.2979, 0.6892}));
|
||||
parameters, "0.bias", torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
|
||||
assign_parameter(
|
||||
parameters, "2.weight", torch::tensor({-0.0508, -0.3941, -0.2843}));
|
||||
assign_parameter(parameters, "2.bias", torch::tensor({-0.0711}));
|
||||
parameters, "2.weight", torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
|
||||
assign_parameter(parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));
|
||||
|
||||
auto optimizer = OptimizerClass(parameters.values(), options);
|
||||
torch::Tensor input =
|
||||
torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}).reshape({3, 2});
|
||||
torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64).reshape({3, 2});
|
||||
|
||||
for (size_t i = 0; i < kIterations; ++i) {
|
||||
optimizer.zero_grad();
|
||||
@ -116,8 +116,9 @@ void check_exact_values(
|
||||
expected_parameters.at(i / kSampleEvery).size() == parameters.size());
|
||||
for (size_t p = 0; p < parameters.size(); ++p) {
|
||||
ASSERT_TRUE(parameters[p]->defined());
|
||||
auto computed = parameters[p]->flatten();
|
||||
auto expected = expected_parameters.at(i / kSampleEvery).at(p);
|
||||
// Always compare using double dtype, regardless of the original dtype of the tensors
|
||||
auto computed = parameters[p]->flatten().to(torch::kFloat64);
|
||||
auto expected = expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
|
||||
if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
|
||||
std::cout << "Iteration " << i << ": " << computed
|
||||
<< " != " << expected << " (parameter " << p << ")"
|
||||
|
9
test/cpp/api/support.cpp
Normal file
9
test/cpp/api/support.cpp
Normal file
@ -0,0 +1,9 @@
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
namespace torch {
|
||||
namespace test {
|
||||
|
||||
std::mutex AutoDefaultDtypeMode::default_dtype_mutex;
|
||||
|
||||
} // namespace test
|
||||
} // namespace torch
|
@ -61,5 +61,24 @@ inline int count_substr_occurrences(const std::string& str, const std::string& s
|
||||
return count;
|
||||
}
|
||||
|
||||
// A RAII, thread local (!) guard that changes default dtype upon
|
||||
// construction, and sets it back to the original dtype upon destruction.
|
||||
//
|
||||
// Usage of this guard is synchronized across threads, so that at any given time,
|
||||
// only one guard can take effect.
|
||||
struct AutoDefaultDtypeMode {
|
||||
static std::mutex default_dtype_mutex;
|
||||
|
||||
AutoDefaultDtypeMode(c10::ScalarType default_dtype) : prev_default_dtype(torch::typeMetaToScalarType(torch::get_default_dtype())) {
|
||||
default_dtype_mutex.lock();
|
||||
torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype));
|
||||
}
|
||||
~AutoDefaultDtypeMode() {
|
||||
default_dtype_mutex.unlock();
|
||||
torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype));
|
||||
}
|
||||
c10::ScalarType prev_default_dtype;
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace torch
|
||||
|
@ -9,6 +9,8 @@
|
||||
|
||||
#include <test/cpp/common/support.h>
|
||||
|
||||
using namespace torch::test;
|
||||
|
||||
template <typename T>
|
||||
bool exactly_equal(at::Tensor left, T right) {
|
||||
return left.item<T>() == right;
|
||||
@ -228,7 +230,9 @@ TEST(TensorTest, TorchTensorCtorScalarIntegralType) {
|
||||
ASSERT_EQ(tensor.item<int32_t>(), 123);
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
|
||||
void test_TorchTensorCtorScalarFloatingType_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
|
||||
auto tensor = torch::tensor(123.456f);
|
||||
ASSERT_EQ(tensor.numel(), 1);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
|
||||
@ -238,16 +242,21 @@ TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
|
||||
tensor = torch::tensor(123.456);
|
||||
ASSERT_EQ(tensor.numel(), 1);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kDouble);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor, 123.456));
|
||||
|
||||
tensor = torch::tensor({123.456});
|
||||
ASSERT_EQ(tensor.numel(), 1);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kDouble);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 123.456));
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
|
||||
test_TorchTensorCtorScalarFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat);
|
||||
test_TorchTensorCtorScalarFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorScalarBoolType) {
|
||||
auto tensor = torch::tensor(true);
|
||||
ASSERT_EQ(tensor.numel(), 1);
|
||||
@ -309,12 +318,40 @@ TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) {
|
||||
ASSERT_TRUE(exactly_equal(tensor[2], 3));
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
|
||||
void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
|
||||
auto tensor = torch::tensor({1.5, 2.25, 3.125});
|
||||
ASSERT_TRUE(tensor.is_variable());
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kDouble);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
|
||||
tensor = torch::tensor({1.5f, 2.25f, 3.125f});
|
||||
ASSERT_TRUE(tensor.is_variable());
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kFloat);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5f));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25f));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125f));
|
||||
|
||||
tensor = torch::tensor(at::ArrayRef<float>({1.5, 2.25, 3.125}));
|
||||
ASSERT_TRUE(tensor.is_variable());
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.dtype(), at::kFloat);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
|
||||
tensor = torch::tensor(std::vector<float>({1.5, 2.25, 3.125}));
|
||||
ASSERT_TRUE(tensor.is_variable());
|
||||
ASSERT_EQ(tensor.numel(), 3);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
|
||||
ASSERT_EQ(tensor.dtype(), at::kFloat);
|
||||
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
|
||||
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
@ -337,6 +374,11 @@ TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
|
||||
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
|
||||
test_TorchTensorCtorSingleDimFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat);
|
||||
test_TorchTensorCtorSingleDimFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorSingleDimBoolType) {
|
||||
auto tensor = torch::tensor({true, false, true});
|
||||
ASSERT_TRUE(tensor.is_variable());
|
||||
@ -409,23 +451,29 @@ TEST(TensorTest, TorchTensorCtorMultiDimIntegralType) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) {
|
||||
void test_TorchTensorCtorMultiDimFloatingType_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
{
|
||||
auto tensor = torch::tensor({{1.0, 2.0}});
|
||||
ASSERT_EQ(tensor.dtype(), torch::kDouble);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
|
||||
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kDouble).view(tensor.sizes())));
|
||||
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, default_dtype).view(tensor.sizes())));
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}});
|
||||
ASSERT_EQ(tensor.dtype(), torch::kDouble);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
|
||||
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, torch::kDouble).view(tensor.sizes())));
|
||||
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, default_dtype).view(tensor.sizes())));
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) {
|
||||
test_TorchTensorCtorMultiDimFloatingType_expected_dtype(/*default_dtype=*/torch::kFloat);
|
||||
test_TorchTensorCtorMultiDimFloatingType_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorMultiDimBoolType) {
|
||||
{
|
||||
auto tensor = torch::tensor({{true, false}});
|
||||
@ -473,11 +521,11 @@ TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{1, 2.0}, {1, 2.0}}}),
|
||||
"Expected all elements of the tensor to have the same scalar type: Long, but got element of scalar type: Double");
|
||||
"Expected all elements of the tensor to have the same scalar type: Long, but got element of scalar type: Float");
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}),
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double");
|
||||
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Float");
|
||||
}
|
||||
{
|
||||
ASSERT_THROWS_WITH(torch::tensor({{{true}, {2}}}),
|
||||
@ -489,83 +537,101 @@ TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorMultiDim_CUDA) {
|
||||
{
|
||||
void test_TorchTensorCtorMultiDim_CUDA_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
|
||||
auto tensor = torch::tensor(
|
||||
{{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}},
|
||||
torch::dtype(torch::kDouble).device(torch::kCUDA));
|
||||
torch::dtype(default_dtype).device(torch::kCUDA));
|
||||
ASSERT_TRUE(tensor.device().is_cuda());
|
||||
ASSERT_EQ(tensor.dtype(), torch::kDouble);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
|
||||
ASSERT_TRUE(torch::allclose(
|
||||
tensor,
|
||||
torch::arange(1, 10, torch::kDouble).view(tensor.sizes()).to(torch::kCUDA)));
|
||||
torch::arange(1, 10, default_dtype).view(tensor.sizes()).to(torch::kCUDA)));
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorMultiDim_CUDA) {
|
||||
test_TorchTensorCtorMultiDim_CUDA_expected_dtype(/*default_dtype=*/torch::kFloat);
|
||||
test_TorchTensorCtorMultiDim_CUDA_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorZeroSizedDim) {
|
||||
void test_TorchTensorCtorZeroSizedDim_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
{
|
||||
auto tensor = torch::tensor({});
|
||||
ASSERT_EQ(tensor.numel(), 0);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({0}));
|
||||
ASSERT_EQ(tensor.dtype(), torch::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{}, {}});
|
||||
ASSERT_EQ(tensor.numel(), 0);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 0}));
|
||||
ASSERT_EQ(tensor.dtype(), torch::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{{}, {}}});
|
||||
ASSERT_EQ(tensor.numel(), 0);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2, 0}));
|
||||
ASSERT_EQ(tensor.dtype(), torch::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{{}}});
|
||||
ASSERT_EQ(tensor.numel(), 0);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 0}));
|
||||
ASSERT_EQ(tensor.dtype(), torch::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{{{{{{{}}}}}}}});
|
||||
ASSERT_EQ(tensor.numel(), 0);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 0}));
|
||||
ASSERT_EQ(tensor.dtype(), torch::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{{{{{{{}}}}, {{{{}}}}}}}});
|
||||
ASSERT_EQ(tensor.numel(), 0);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 2, 1, 1, 1, 0}));
|
||||
ASSERT_EQ(tensor.dtype(), torch::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
{
|
||||
auto tensor = torch::tensor({{{{{{{{{{}}}}}}}}}});
|
||||
ASSERT_EQ(tensor.numel(), 0);
|
||||
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 0}));
|
||||
ASSERT_EQ(tensor.dtype(), torch::kFloat);
|
||||
ASSERT_EQ(tensor.dtype(), default_dtype);
|
||||
ASSERT_FALSE(tensor.requires_grad());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorZeroSizedDim) {
|
||||
test_TorchTensorCtorZeroSizedDim_expected_dtype(/*default_dtype=*/torch::kFloat);
|
||||
test_TorchTensorCtorZeroSizedDim_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
|
||||
ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), default_dtype);
|
||||
}
|
||||
|
||||
TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) {
|
||||
ASSERT_EQ(torch::tensor({1, 2, 3}).dtype(), torch::kLong);
|
||||
ASSERT_EQ(torch::tensor({{1, 2, 3}}).dtype(), torch::kLong);
|
||||
ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), torch::kDouble);
|
||||
ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), torch::kDouble);
|
||||
|
||||
ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong);
|
||||
ASSERT_EQ(torch::tensor({{1, 2, 3}}, torch::TensorOptions()).dtype(), torch::kLong);
|
||||
ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), torch::kDouble);
|
||||
ASSERT_EQ(torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), torch::kDouble);
|
||||
|
||||
test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kFloat);
|
||||
test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
TEST(TensorTest, Arange) {
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <test/cpp/api/support.h>
|
||||
|
||||
#include <torch/torch.h>
|
||||
|
||||
@ -6,6 +7,7 @@
|
||||
#include <vector>
|
||||
|
||||
using namespace at;
|
||||
using namespace torch::test;
|
||||
|
||||
// A macro so we don't lose location information when an assertion fails.
|
||||
#define REQUIRE_OPTIONS(device_, index_, type_, layout_) \
|
||||
@ -119,29 +121,25 @@ TEST(DeviceTest, ParsesCorrectlyFromString) {
|
||||
}
|
||||
}
|
||||
|
||||
struct DefaultDtypeTest : ::testing::Test {
|
||||
DefaultDtypeTest() {
|
||||
set_default_dtype(caffe2::TypeMeta::Make<float>());
|
||||
}
|
||||
~DefaultDtypeTest() override {
|
||||
set_default_dtype(caffe2::TypeMeta::Make<float>());
|
||||
}
|
||||
};
|
||||
TEST(DefaultDtypeTest, CanSetAndGetDefaultDtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(kFloat);
|
||||
|
||||
TEST_F(DefaultDtypeTest, CanSetAndGetDefaultDtype) {
|
||||
ASSERT_EQ(at::get_default_dtype(), kFloat);
|
||||
set_default_dtype(caffe2::TypeMeta::Make<int>());
|
||||
ASSERT_EQ(at::get_default_dtype(), kInt);
|
||||
}
|
||||
|
||||
TEST_F(DefaultDtypeTest, NewTensorOptionsHasCorrectDefault) {
|
||||
TEST(DefaultDtypeTest, NewTensorOptionsHasCorrectDefault) {
|
||||
AutoDefaultDtypeMode dtype_mode(kFloat);
|
||||
|
||||
set_default_dtype(caffe2::TypeMeta::Make<int>());
|
||||
ASSERT_EQ(at::get_default_dtype(), kInt);
|
||||
TensorOptions options;
|
||||
ASSERT_EQ(options.dtype(), kInt);
|
||||
}
|
||||
|
||||
TEST_F(DefaultDtypeTest, NewTensorsHaveCorrectDefaultDtype) {
|
||||
TEST(DefaultDtypeTest, NewTensorsHaveCorrectDefaultDtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(kFloat);
|
||||
set_default_dtype(caffe2::TypeMeta::Make<int>());
|
||||
{
|
||||
auto tensor = torch::ones(5);
|
||||
|
@ -24,17 +24,22 @@ inline std::ostream& operator<<(std::ostream& stream, c10::BFloat16 value) {
|
||||
}
|
||||
|
||||
inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) {
|
||||
// NOTE: the dtype computation in this function only takes effect when the user passes
|
||||
// an integer literal / floating-point literal or a braced-init-list to `torch::tensor`
|
||||
// constructor. It doesn't affect `torch::tensor(at::ArrayRef<T>)` and `torch::tensor(std::vector<T>)`
|
||||
// as the specified dtype `T` is always respected.
|
||||
if (scalar_type == at::kInt || scalar_type == at::kLong) {
|
||||
// In C++, an integer literal without suffix (e.g. `1` instead of `1u`) can be one of
|
||||
// `int` / `long int` / `long long int` types. When we find that `scalar_type` is one
|
||||
// of those types, we always use `torch.int64` type, because In Python `torch.tensor(1)`
|
||||
// always gives a tensor of `torch.int64` dtype.
|
||||
//
|
||||
// Note that this dtype computation only takes effect when the user passes an integer
|
||||
// literal or a braced-init-list to `torch::tensor` constructor. It doesn't affect
|
||||
// `torch::tensor(at::ArrayRef<T>)` and `torch::tensor(std::vector<T>)` as the specified
|
||||
// dtype `T` is always respected.
|
||||
return at::kLong;
|
||||
} else if (scalar_type == at::kDouble) {
|
||||
// When `scalar_type == at::kDouble`, we know that the user is passing in
|
||||
// a floating-point literal without specifying its type (e.g. `1.0` instead of `1.0f`).
|
||||
// In Python, the dtype of `torch.tensor(1.0)` depends on the value of
|
||||
// `torch.get_default_dtype()`, and we should do the same for C++ `torch::tensor(1.0)`.
|
||||
return at::typeMetaToScalarType(at::get_default_dtype());
|
||||
} else {
|
||||
return scalar_type;
|
||||
}
|
||||
@ -94,7 +99,9 @@ struct TensorDataContainer {
|
||||
// the innermost `TensorDataContainer`.
|
||||
TensorDataContainer() :
|
||||
sizes_({0}),
|
||||
scalar_type_(at::kFloat),
|
||||
// NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g. `torch.tensor([[], []])`)
|
||||
// depends on the value of `torch.get_default_dtype()`, and we should do the same for the C++ equivalent.
|
||||
scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())),
|
||||
type_(TensorDataContainerType::InitList) {}
|
||||
#define TENSOR(T, S) \
|
||||
TensorDataContainer(T value) : \
|
||||
|
Reference in New Issue
Block a user