Use default dtype for torch::tensor(floating_point_values) and torch::tensor(empty braced-init-list) when dtype is not specified (#29632)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29632

This PR is BC-breaking in the following way:

Previously, C++ `torch::tensor` with a floating-point literal with no suffix (e.g. `torch::tensor(1.1)`) or a (nested) braced-init-list of
floating-point literals with no suffix (e.g. `torch::tensor({{1.1, 2.2}})`) produced a tensor with dtype `at::kDouble`. After this PR, it produces a tensor with dtype `torch::get_default_dtype()`, matching Python `torch.tensor` behavior.

Test Plan: Imported from OSS

Differential Revision: D18465819

Pulled By: yf225

fbshipit-source-id: 6834fe50335c677bc3832f2a5e9cf8d1ede9f665
This commit is contained in:
Will Feng
2019-11-13 15:14:08 -08:00
committed by Facebook Github Bot
parent 3fb9bbc99b
commit 2bcac59a30
9 changed files with 174 additions and 72 deletions

View File

@ -22,6 +22,7 @@ set(TORCH_API_TEST_SOURCES
${TORCH_API_TEST_DIR}/sequential.cpp
${TORCH_API_TEST_DIR}/serialize.cpp
${TORCH_API_TEST_DIR}/static.cpp
${TORCH_API_TEST_DIR}/support.cpp
${TORCH_API_TEST_DIR}/tensor_cuda.cpp
${TORCH_API_TEST_DIR}/tensor_options_cuda.cpp
${TORCH_API_TEST_DIR}/tensor_options.cpp

View File

@ -462,7 +462,7 @@ TEST_F(FunctionalTest, GridSample) {
TEST_F(FunctionalTest, AffineGrid) {
{
// 2D affine.
auto theta = torch::arange(1, 13, torch::kDouble)
auto theta = torch::arange(1., 13)
.view(std::vector<int64_t>({2, 2, 3}));
auto size = std::vector<int64_t>({2, 3, 2, 2});
auto align_corners = true;
@ -480,7 +480,7 @@ TEST_F(FunctionalTest, AffineGrid) {
}
{
// 3D affine.
auto theta = torch::arange(1, 13, torch::kDouble)
auto theta = torch::arange(1., 13)
.view(std::vector<int64_t>({1, 3, 4}));
auto size = std::vector<int64_t>({1, 1, 3, 2, 2});
auto align_corners = true;

View File

@ -29,8 +29,9 @@ void check_exact_values(
}
for (size_t p = 0; p < layerParameters.size(0); p++) {
auto tensor = layerParameters[p];
auto expectedTensor = expectedLayerParameters[p];
// Always compare using double dtype, regardless of the original dtype of the tensors
auto tensor = layerParameters[p].to(torch::kFloat64);
auto expectedTensor = expectedLayerParameters[p].to(torch::kFloat64);
if (!tensor.allclose(expectedTensor, /*rtol=*/1e-3, /*atol=*/5e-4)) {
std::cout << "layer " << i << ": " << tensor << " != " << expectedTensor

View File

@ -92,16 +92,16 @@ void check_exact_values(
assign_parameter(
parameters,
"0.weight",
torch::tensor({-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}));
torch::tensor({-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}, torch::kFloat64));
assign_parameter(
parameters, "0.bias", torch::tensor({-0.1085, -0.2979, 0.6892}));
parameters, "0.bias", torch::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64));
assign_parameter(
parameters, "2.weight", torch::tensor({-0.0508, -0.3941, -0.2843}));
assign_parameter(parameters, "2.bias", torch::tensor({-0.0711}));
parameters, "2.weight", torch::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64));
assign_parameter(parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));
auto optimizer = OptimizerClass(parameters.values(), options);
torch::Tensor input =
torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}).reshape({3, 2});
torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64).reshape({3, 2});
for (size_t i = 0; i < kIterations; ++i) {
optimizer.zero_grad();
@ -116,8 +116,9 @@ void check_exact_values(
expected_parameters.at(i / kSampleEvery).size() == parameters.size());
for (size_t p = 0; p < parameters.size(); ++p) {
ASSERT_TRUE(parameters[p]->defined());
auto computed = parameters[p]->flatten();
auto expected = expected_parameters.at(i / kSampleEvery).at(p);
// Always compare using double dtype, regardless of the original dtype of the tensors
auto computed = parameters[p]->flatten().to(torch::kFloat64);
auto expected = expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
std::cout << "Iteration " << i << ": " << computed
<< " != " << expected << " (parameter " << p << ")"

9
test/cpp/api/support.cpp Normal file
View File

@ -0,0 +1,9 @@
#include <test/cpp/api/support.h>
namespace torch {
namespace test {
// Out-of-class definition of the static `default_dtype_mutex` member of
// `AutoDefaultDtypeMode`. The struct only declares the member; defining it
// here, in a single translation unit, gives it exactly one definition no
// matter how many test files include the header.
std::mutex AutoDefaultDtypeMode::default_dtype_mutex;
} // namespace test
} // namespace torch

View File

@ -61,5 +61,24 @@ inline int count_substr_occurrences(const std::string& str, const std::string& s
return count;
}
// A RAII guard that changes the default dtype upon construction, and sets it
// back to the previously active default dtype upon destruction.
//
// Usage of this guard is synchronized across threads through
// `default_dtype_mutex`, so that at any given time only one guard can take
// effect: constructing a second guard on another thread blocks until the first
// guard has been destroyed.
//
// The mutex is held for the guard's entire lifetime. In particular it is
// acquired *before* the previous default dtype is read, and released only
// *after* the previous default dtype has been restored. Reading first / unlocking
// first would let one guard capture (and later "restore") the temporary dtype
// installed by another guard, or let its restore overwrite the dtype a newly
// unblocked guard has just installed.
//
// NOTE(review): nothing in this struct is thread-local; the need for a mutex
// suggests the default dtype is shared between threads - confirm against the
// implementation of `torch::set_default_dtype`.
//
// NOTE: `std::mutex` is not recursive, so two guards must not be nested on the
// same thread.
struct AutoDefaultDtypeMode {
  static std::mutex default_dtype_mutex;

  // Acquires `default_dtype_mutex`, remembers the currently active default
  // dtype, then installs `default_dtype` as the new default.
  AutoDefaultDtypeMode(c10::ScalarType default_dtype)
      : dtype_lock_(default_dtype_mutex),
        prev_default_dtype(
            torch::typeMetaToScalarType(torch::get_default_dtype())) {
    // If this throws, `dtype_lock_` is already constructed and is therefore
    // destroyed during unwinding, so the mutex is not left locked.
    torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype));
  }

  // Restores the default dtype that was active when the guard was created.
  // The mutex is released afterwards, when `dtype_lock_` is destroyed.
  ~AutoDefaultDtypeMode() {
    torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype));
  }

 private:
  // Must be declared before `prev_default_dtype`: members are initialized in
  // declaration order and destroyed in reverse, which is what makes the lock
  // cover both the read in the constructor and the restore in the destructor.
  // Holding a `std::lock_guard` also makes the guard non-copyable, so the
  // mutex can never be unlocked twice.
  std::lock_guard<std::mutex> dtype_lock_;

 public:
  // Default dtype that was active when this guard was constructed.
  c10::ScalarType prev_default_dtype;
};
} // namespace test
} // namespace torch

View File

@ -9,14 +9,16 @@
#include <test/cpp/common/support.h>
template <typename T>
bool exactly_equal(at::Tensor left, T right) {
return left.item<T>() == right;
}
using namespace torch::test;
template <typename T>
bool almost_equal(at::Tensor left, T right, T tolerance = 1e-4) {
return std::abs(left.item<T>() - right) < tolerance;
bool exactly_equal(at::Tensor left, T right) {
return left.item<T>() == right;
}
template <typename T>
bool almost_equal(at::Tensor left, T right, T tolerance = 1e-4) {
return std::abs(left.item<T>() - right) < tolerance;
}
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
@ -228,7 +230,9 @@ TEST(TensorTest, TorchTensorCtorScalarIntegralType) {
ASSERT_EQ(tensor.item<int32_t>(), 123);
}
TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
void test_TorchTensorCtorScalarFloatingType_expected_dtype(c10::ScalarType default_dtype) {
AutoDefaultDtypeMode dtype_mode(default_dtype);
auto tensor = torch::tensor(123.456f);
ASSERT_EQ(tensor.numel(), 1);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
@ -238,16 +242,21 @@ TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
tensor = torch::tensor(123.456);
ASSERT_EQ(tensor.numel(), 1);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
ASSERT_EQ(tensor.dtype(), at::kDouble);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_TRUE(almost_equal(tensor, 123.456));
tensor = torch::tensor({123.456});
ASSERT_EQ(tensor.numel(), 1);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1}));
ASSERT_EQ(tensor.dtype(), at::kDouble);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_TRUE(almost_equal(tensor[0], 123.456));
}
// Runs the scalar floating-point `torch::tensor` checks once per default dtype
// under test: first with float as the default, then with double.
TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
  for (const auto dtype_under_test : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorScalarFloatingType_expected_dtype(dtype_under_test);
  }
}
TEST(TensorTest, TorchTensorCtorScalarBoolType) {
auto tensor = torch::tensor(true);
ASSERT_EQ(tensor.numel(), 1);
@ -309,12 +318,40 @@ TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) {
ASSERT_TRUE(exactly_equal(tensor[2], 3));
}
TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(c10::ScalarType default_dtype) {
AutoDefaultDtypeMode dtype_mode(default_dtype);
auto tensor = torch::tensor({1.5, 2.25, 3.125});
ASSERT_TRUE(tensor.is_variable());
ASSERT_EQ(tensor.numel(), 3);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
ASSERT_EQ(tensor.dtype(), at::kDouble);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
tensor = torch::tensor({1.5f, 2.25f, 3.125f});
ASSERT_TRUE(tensor.is_variable());
ASSERT_EQ(tensor.numel(), 3);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
ASSERT_EQ(tensor.dtype(), at::kFloat);
ASSERT_TRUE(almost_equal(tensor[0], 1.5f));
ASSERT_TRUE(almost_equal(tensor[1], 2.25f));
ASSERT_TRUE(almost_equal(tensor[2], 3.125f));
tensor = torch::tensor(at::ArrayRef<float>({1.5, 2.25, 3.125}));
ASSERT_TRUE(tensor.is_variable());
ASSERT_EQ(tensor.numel(), 3);
ASSERT_EQ(tensor.dtype(), at::kFloat);
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
tensor = torch::tensor(std::vector<float>({1.5, 2.25, 3.125}));
ASSERT_TRUE(tensor.is_variable());
ASSERT_EQ(tensor.numel(), 3);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
ASSERT_EQ(tensor.dtype(), at::kFloat);
ASSERT_TRUE(almost_equal(tensor[0], 1.5));
ASSERT_TRUE(almost_equal(tensor[1], 2.25));
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
@ -337,6 +374,11 @@ TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
ASSERT_TRUE(almost_equal(tensor[2], 3.125));
}
// Runs the one-dimensional floating-point `torch::tensor` checks once per
// default dtype under test: first with float as the default, then with double.
TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
  for (const auto dtype_under_test : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorSingleDimFloatingType_expected_dtype(dtype_under_test);
  }
}
TEST(TensorTest, TorchTensorCtorSingleDimBoolType) {
auto tensor = torch::tensor({true, false, true});
ASSERT_TRUE(tensor.is_variable());
@ -409,23 +451,29 @@ TEST(TensorTest, TorchTensorCtorMultiDimIntegralType) {
}
}
TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) {
void test_TorchTensorCtorMultiDimFloatingType_expected_dtype(c10::ScalarType default_dtype) {
AutoDefaultDtypeMode dtype_mode(default_dtype);
{
auto tensor = torch::tensor({{1.0, 2.0}});
ASSERT_EQ(tensor.dtype(), torch::kDouble);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, torch::kDouble).view(tensor.sizes())));
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 3, default_dtype).view(tensor.sizes())));
ASSERT_FALSE(tensor.requires_grad());
}
{
auto tensor = torch::tensor({{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}});
ASSERT_EQ(tensor.dtype(), torch::kDouble);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, torch::kDouble).view(tensor.sizes())));
ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, default_dtype).view(tensor.sizes())));
ASSERT_FALSE(tensor.requires_grad());
}
}
// Runs the multi-dimensional floating-point `torch::tensor` checks once per
// default dtype under test: first with float as the default, then with double.
TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) {
  for (const auto dtype_under_test : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorMultiDimFloatingType_expected_dtype(dtype_under_test);
  }
}
TEST(TensorTest, TorchTensorCtorMultiDimBoolType) {
{
auto tensor = torch::tensor({{true, false}});
@ -473,11 +521,11 @@ TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
}
{
ASSERT_THROWS_WITH(torch::tensor({{{1, 2.0}, {1, 2.0}}}),
"Expected all elements of the tensor to have the same scalar type: Long, but got element of scalar type: Double");
"Expected all elements of the tensor to have the same scalar type: Long, but got element of scalar type: Float");
}
{
ASSERT_THROWS_WITH(torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}),
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double");
"Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Float");
}
{
ASSERT_THROWS_WITH(torch::tensor({{{true}, {2}}}),
@ -489,83 +537,101 @@ TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
}
}
TEST(TensorTest, TorchTensorCtorMultiDim_CUDA) {
{
auto tensor = torch::tensor(
{{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}},
torch::dtype(torch::kDouble).device(torch::kCUDA));
ASSERT_TRUE(tensor.device().is_cuda());
ASSERT_EQ(tensor.dtype(), torch::kDouble);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
ASSERT_TRUE(torch::allclose(
tensor,
torch::arange(1, 10, torch::kDouble).view(tensor.sizes()).to(torch::kCUDA)));
ASSERT_FALSE(tensor.requires_grad());
}
// Builds an 8-dimensional tensor on the CUDA device from a nested
// braced-init-list while `default_dtype` is installed as the default dtype
// (and also passed explicitly in the tensor options), then checks the
// resulting device, dtype, shape, values and `requires_grad` flag.
void test_TorchTensorCtorMultiDim_CUDA_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  auto cuda_tensor = torch::tensor(
      {{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}},
      torch::dtype(default_dtype).device(torch::kCUDA));

  ASSERT_TRUE(cuda_tensor.device().is_cuda());
  ASSERT_EQ(cuda_tensor.dtype(), default_dtype);
  ASSERT_EQ(cuda_tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));

  // Reference values 1..9 in the same dtype and shape, moved to the same device.
  auto reference = torch::arange(1, 10, default_dtype)
                       .view(cuda_tensor.sizes())
                       .to(torch::kCUDA);
  ASSERT_TRUE(torch::allclose(cuda_tensor, reference));

  ASSERT_FALSE(cuda_tensor.requires_grad());
}
TEST(TensorTest, TorchTensorCtorZeroSizedDim) {
// Runs the CUDA multi-dimensional `torch::tensor` checks once per default
// dtype under test: first with float as the default, then with double.
TEST(TensorTest, TorchTensorCtorMultiDim_CUDA) {
  for (const auto dtype_under_test : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorMultiDim_CUDA_expected_dtype(dtype_under_test);
  }
}
void test_TorchTensorCtorZeroSizedDim_expected_dtype(c10::ScalarType default_dtype) {
AutoDefaultDtypeMode dtype_mode(default_dtype);
{
auto tensor = torch::tensor({});
ASSERT_EQ(tensor.numel(), 0);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({0}));
ASSERT_EQ(tensor.dtype(), torch::kFloat);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_FALSE(tensor.requires_grad());
}
{
auto tensor = torch::tensor({{}, {}});
ASSERT_EQ(tensor.numel(), 0);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 0}));
ASSERT_EQ(tensor.dtype(), torch::kFloat);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_FALSE(tensor.requires_grad());
}
{
auto tensor = torch::tensor({{{}, {}}});
ASSERT_EQ(tensor.numel(), 0);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2, 0}));
ASSERT_EQ(tensor.dtype(), torch::kFloat);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_FALSE(tensor.requires_grad());
}
{
auto tensor = torch::tensor({{{}}});
ASSERT_EQ(tensor.numel(), 0);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 0}));
ASSERT_EQ(tensor.dtype(), torch::kFloat);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_FALSE(tensor.requires_grad());
}
{
auto tensor = torch::tensor({{{{{{{{}}}}}}}});
ASSERT_EQ(tensor.numel(), 0);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 0}));
ASSERT_EQ(tensor.dtype(), torch::kFloat);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_FALSE(tensor.requires_grad());
}
{
auto tensor = torch::tensor({{{{{{{{}}}}, {{{{}}}}}}}});
ASSERT_EQ(tensor.numel(), 0);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 2, 1, 1, 1, 0}));
ASSERT_EQ(tensor.dtype(), torch::kFloat);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_FALSE(tensor.requires_grad());
}
{
auto tensor = torch::tensor({{{{{{{{{{}}}}}}}}}});
ASSERT_EQ(tensor.numel(), 0);
ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 0}));
ASSERT_EQ(tensor.dtype(), torch::kFloat);
ASSERT_EQ(tensor.dtype(), default_dtype);
ASSERT_FALSE(tensor.requires_grad());
}
}
// Runs the zero-sized-dimension `torch::tensor` checks once per default dtype
// under test: first with float as the default, then with double.
TEST(TensorTest, TorchTensorCtorZeroSizedDim) {
  for (const auto dtype_under_test : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorZeroSizedDim_expected_dtype(dtype_under_test);
  }
}
// With `default_dtype` installed as the default dtype, checks that
// `torch::tensor` built from unsuffixed floating-point literals reports
// `default_dtype` - for flat and nested braced-init-lists, both without
// options and with default-constructed `torch::TensorOptions`.
void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  // No options argument.
  {
    auto flat = torch::tensor({1., 2., 3.});
    ASSERT_EQ(flat.dtype(), default_dtype);
  }
  {
    auto nested = torch::tensor({{1., 2., 3.}});
    ASSERT_EQ(nested.dtype(), default_dtype);
  }

  // Default-constructed options, i.e. still no explicit dtype.
  {
    auto flat = torch::tensor({1., 2., 3.}, torch::TensorOptions());
    ASSERT_EQ(flat.dtype(), default_dtype);
  }
  {
    auto nested = torch::tensor({{1., 2., 3.}}, torch::TensorOptions());
    ASSERT_EQ(nested.dtype(), default_dtype);
  }
}
TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) {
ASSERT_EQ(torch::tensor({1, 2, 3}).dtype(), torch::kLong);
ASSERT_EQ(torch::tensor({{1, 2, 3}}).dtype(), torch::kLong);
ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), torch::kDouble);
ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), torch::kDouble);
ASSERT_EQ(torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong);
ASSERT_EQ(torch::tensor({{1, 2, 3}}, torch::TensorOptions()).dtype(), torch::kLong);
ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), torch::kDouble);
ASSERT_EQ(torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), torch::kDouble);
test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kFloat);
test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kDouble);
}
TEST(TensorTest, Arange) {

View File

@ -1,4 +1,5 @@
#include <gtest/gtest.h>
#include <test/cpp/api/support.h>
#include <torch/torch.h>
@ -6,6 +7,7 @@
#include <vector>
using namespace at;
using namespace torch::test;
// A macro so we don't lose location information when an assertion fails.
#define REQUIRE_OPTIONS(device_, index_, type_, layout_) \
@ -119,29 +121,25 @@ TEST(DeviceTest, ParsesCorrectlyFromString) {
}
}
struct DefaultDtypeTest : ::testing::Test {
DefaultDtypeTest() {
set_default_dtype(caffe2::TypeMeta::Make<float>());
}
~DefaultDtypeTest() override {
set_default_dtype(caffe2::TypeMeta::Make<float>());
}
};
TEST(DefaultDtypeTest, CanSetAndGetDefaultDtype) {
AutoDefaultDtypeMode dtype_mode(kFloat);
TEST_F(DefaultDtypeTest, CanSetAndGetDefaultDtype) {
ASSERT_EQ(at::get_default_dtype(), kFloat);
set_default_dtype(caffe2::TypeMeta::Make<int>());
ASSERT_EQ(at::get_default_dtype(), kInt);
}
TEST_F(DefaultDtypeTest, NewTensorOptionsHasCorrectDefault) {
TEST(DefaultDtypeTest, NewTensorOptionsHasCorrectDefault) {
AutoDefaultDtypeMode dtype_mode(kFloat);
set_default_dtype(caffe2::TypeMeta::Make<int>());
ASSERT_EQ(at::get_default_dtype(), kInt);
TensorOptions options;
ASSERT_EQ(options.dtype(), kInt);
}
TEST_F(DefaultDtypeTest, NewTensorsHaveCorrectDefaultDtype) {
TEST(DefaultDtypeTest, NewTensorsHaveCorrectDefaultDtype) {
AutoDefaultDtypeMode dtype_mode(kFloat);
set_default_dtype(caffe2::TypeMeta::Make<int>());
{
auto tensor = torch::ones(5);

View File

@ -24,17 +24,22 @@ inline std::ostream& operator<<(std::ostream& stream, c10::BFloat16 value) {
}
inline c10::ScalarType compute_desired_dtype(c10::ScalarType scalar_type) {
// NOTE: the dtype computation in this function only takes effect when the user passes
// an integer literal / floating-point literal or a braced-init-list to `torch::tensor`
// constructor. It doesn't affect `torch::tensor(at::ArrayRef<T>)` and `torch::tensor(std::vector<T>)`
// as the specified dtype `T` is always respected.
if (scalar_type == at::kInt || scalar_type == at::kLong) {
// In C++, an integer literal without suffix (e.g. `1` instead of `1u`) can be one of
// `int` / `long int` / `long long int` types. When we find that `scalar_type` is one
// of those types, we always use `torch.int64` type, because In Python `torch.tensor(1)`
// always gives a tensor of `torch.int64` dtype.
//
// Note that this dtype computation only takes effect when the user passes an integer
// literal or a braced-init-list to `torch::tensor` constructor. It doesn't affect
// `torch::tensor(at::ArrayRef<T>)` and `torch::tensor(std::vector<T>)` as the specified
// dtype `T` is always respected.
return at::kLong;
} else if (scalar_type == at::kDouble) {
// When `scalar_type == at::kDouble`, we know that the user is passing in
// a floating-point literal without specifying its type (e.g. `1.0` instead of `1.0f`).
// In Python, the dtype of `torch.tensor(1.0)` depends on the value of
// `torch.get_default_dtype()`, and we should do the same for C++ `torch::tensor(1.0)`.
return at::typeMetaToScalarType(at::get_default_dtype());
} else {
return scalar_type;
}
@ -94,7 +99,9 @@ struct TensorDataContainer {
// the innermost `TensorDataContainer`.
TensorDataContainer() :
sizes_({0}),
scalar_type_(at::kFloat),
// NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g. `torch.tensor([[], []])`)
// depends on the value of `torch.get_default_dtype()`, and we should do the same for the C++ equivalent.
scalar_type_(at::typeMetaToScalarType(at::get_default_dtype())),
type_(TensorDataContainerType::InitList) {}
#define TENSOR(T, S) \
TensorDataContainer(T value) : \