mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:30:26 +08:00
Add more tests for torch::arange
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29689 Test Plan: Imported from OSS Differential Revision: D18465818 Pulled By: yf225 fbshipit-source-id: 0cf0aaa7febcf4318abdaae7d17a43ab3acde017
This commit is contained in:
committed by
Facebook Github Bot
parent
2bcac59a30
commit
65f691f2c2
@ -634,11 +634,19 @@ TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) {
|
||||
test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(/*default_dtype=*/torch::kDouble);
|
||||
}
|
||||
|
||||
void test_Arange_expected_dtype(c10::ScalarType default_dtype) {
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
|
||||
ASSERT_EQ(torch::arange(0., 5).dtype(), default_dtype);
|
||||
}
|
||||
|
||||
TEST(TensorTest, Arange) {
|
||||
{ // Test #1
|
||||
{
|
||||
auto x = torch::arange(0, 5);
|
||||
TORCH_INTERNAL_ASSERT(x.dtype() == at::ScalarType::Long);
|
||||
ASSERT_EQ(x.dtype(), torch::kLong);
|
||||
}
|
||||
test_Arange_expected_dtype(torch::kFloat);
|
||||
test_Arange_expected_dtype(torch::kDouble);
|
||||
}
|
||||
|
||||
TEST(TensorTest, PrettyPrintTensorDataContainer) {
|
||||
|
Reference in New Issue
Block a user