[Reland] Migrate ScalarType to headeronly (#159911)

The non ghstack version of #159416, to make sure we don't get reverted again
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159911
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Jane Xu
2025-08-06 07:36:37 +00:00
committed by PyTorch MergeBot
parent e9d27aa8fd
commit 1690c0c3a0
6 changed files with 171 additions and 73 deletions

View File

@ -1,5 +1,6 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
#include <torch/headeronly/util/Float8_e4m3fn.h>
@ -149,3 +150,60 @@ TEST(TestDtype, TestQuintsQintsAndBits) {
auto i = torch::headeronly::bits8(2);
auto j = torch::headeronly::bits16(6);
}
TEST(TestDtype, TestScalarType) {
using torch::headeronly::ScalarType;
constexpr ScalarType expected_scalar_types[] = {
ScalarType::Byte,
ScalarType::Char,
ScalarType::Short,
ScalarType::Int,
ScalarType::Long,
ScalarType::Half,
ScalarType::Float,
ScalarType::Double,
ScalarType::ComplexHalf,
ScalarType::ComplexFloat,
ScalarType::ComplexDouble,
ScalarType::Bool,
ScalarType::QInt8,
ScalarType::QUInt8,
ScalarType::QInt32,
ScalarType::BFloat16,
ScalarType::QUInt4x2,
ScalarType::QUInt2x4,
ScalarType::Bits1x8,
ScalarType::Bits2x4,
ScalarType::Bits4x2,
ScalarType::Bits8,
ScalarType::Bits16,
ScalarType::Float8_e5m2,
ScalarType::Float8_e4m3fn,
ScalarType::Float8_e5m2fnuz,
ScalarType::Float8_e4m3fnuz,
ScalarType::UInt16,
ScalarType::UInt32,
ScalarType::UInt64,
ScalarType::UInt1,
ScalarType::UInt2,
ScalarType::UInt3,
ScalarType::UInt4,
ScalarType::UInt5,
ScalarType::UInt6,
ScalarType::UInt7,
ScalarType::Int1,
ScalarType::Int2,
ScalarType::Int3,
ScalarType::Int4,
ScalarType::Int5,
ScalarType::Int6,
ScalarType::Int7,
ScalarType::Float8_e8m0fnu,
ScalarType::Float4_e2m1fn_x2,
ScalarType::Undefined,
};
for (int8_t i = 0; i < static_cast<int8_t>(torch::headeronly::NumScalarTypes);
i++) {
EXPECT_EQ(static_cast<ScalarType>(i), expected_scalar_types[i]);
}
}