mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Migrate ScalarType to headeronly (#159416)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159416 Approved by: https://github.com/albanD ghstack dependencies: #159415, #159411
This commit is contained in:
committed by
PyTorch MergeBot
parent
2a286cbdf4
commit
1371a98b0e
@ -1,5 +1,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
@ -149,3 +150,60 @@ TEST(TestDtype, TestQuintsQintsAndBits) {
|
||||
auto i = torch::headeronly::bits8(2);
|
||||
auto j = torch::headeronly::bits16(6);
|
||||
}
|
||||
|
||||
TEST(TestDtype, TestScalarType) {
|
||||
using torch::headeronly::ScalarType;
|
||||
constexpr ScalarType expected_scalar_types[] = {
|
||||
ScalarType::Byte,
|
||||
ScalarType::Char,
|
||||
ScalarType::Short,
|
||||
ScalarType::Int,
|
||||
ScalarType::Long,
|
||||
ScalarType::Half,
|
||||
ScalarType::Float,
|
||||
ScalarType::Double,
|
||||
ScalarType::ComplexHalf,
|
||||
ScalarType::ComplexFloat,
|
||||
ScalarType::ComplexDouble,
|
||||
ScalarType::Bool,
|
||||
ScalarType::QInt8,
|
||||
ScalarType::QUInt8,
|
||||
ScalarType::QInt32,
|
||||
ScalarType::BFloat16,
|
||||
ScalarType::QUInt4x2,
|
||||
ScalarType::QUInt2x4,
|
||||
ScalarType::Bits1x8,
|
||||
ScalarType::Bits2x4,
|
||||
ScalarType::Bits4x2,
|
||||
ScalarType::Bits8,
|
||||
ScalarType::Bits16,
|
||||
ScalarType::Float8_e5m2,
|
||||
ScalarType::Float8_e4m3fn,
|
||||
ScalarType::Float8_e5m2fnuz,
|
||||
ScalarType::Float8_e4m3fnuz,
|
||||
ScalarType::UInt16,
|
||||
ScalarType::UInt32,
|
||||
ScalarType::UInt64,
|
||||
ScalarType::UInt1,
|
||||
ScalarType::UInt2,
|
||||
ScalarType::UInt3,
|
||||
ScalarType::UInt4,
|
||||
ScalarType::UInt5,
|
||||
ScalarType::UInt6,
|
||||
ScalarType::UInt7,
|
||||
ScalarType::Int1,
|
||||
ScalarType::Int2,
|
||||
ScalarType::Int3,
|
||||
ScalarType::Int4,
|
||||
ScalarType::Int5,
|
||||
ScalarType::Int6,
|
||||
ScalarType::Int7,
|
||||
ScalarType::Float8_e8m0fnu,
|
||||
ScalarType::Float4_e2m1fn_x2,
|
||||
ScalarType::Undefined,
|
||||
};
|
||||
for (int8_t i = 0; i < static_cast<int8_t>(torch::headeronly::NumScalarTypes);
|
||||
i++) {
|
||||
EXPECT_EQ(static_cast<ScalarType>(i), expected_scalar_types[i]);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user