[Reland] Migrate ScalarType to headeronly (#159911)

The non ghstack version of #159416, to make sure we don't get reverted again
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159911
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Jane Xu
2025-08-06 07:36:37 +00:00
committed by PyTorch MergeBot
parent e9d27aa8fd
commit 1690c0c3a0
6 changed files with 171 additions and 73 deletions

View File

@ -19,25 +19,16 @@
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>
#include <unordered_map>
#include <torch/headeronly/core/ScalarType.h>
namespace c10 {
// dummy struct for uint1 to uint7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_uint1_7_t {};
// dummy struct for int1 to int7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_int1_7_t {};
// For the macros below:
// [dtype Macros note] For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information, you probably want one of the
@ -57,56 +48,6 @@ struct dummy_int1_7_t {};
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
_(c10::complex<float>, ComplexFloat) /* 9 */ \
_(c10::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */ \
_(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
_(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
_(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
_(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
_(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
_(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
_(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \
_(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \
_(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \
_(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \
_(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
@ -152,17 +93,6 @@ struct dummy_int1_7_t {};
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
enum class ScalarType : int8_t {
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
#undef DEFINE_ENUM_ST_ENUM_VAL_
Undefined,
NumOptions
};
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
namespace impl {
// These are used to map ScalarTypes to C++ types.

View File

@ -1,5 +1,6 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
#include <torch/headeronly/util/Float8_e4m3fn.h>
@ -149,3 +150,60 @@ TEST(TestDtype, TestQuintsQintsAndBits) {
auto i = torch::headeronly::bits8(2);
auto j = torch::headeronly::bits16(6);
}
TEST(TestDtype, TestScalarType) {
using torch::headeronly::ScalarType;
constexpr ScalarType expected_scalar_types[] = {
ScalarType::Byte,
ScalarType::Char,
ScalarType::Short,
ScalarType::Int,
ScalarType::Long,
ScalarType::Half,
ScalarType::Float,
ScalarType::Double,
ScalarType::ComplexHalf,
ScalarType::ComplexFloat,
ScalarType::ComplexDouble,
ScalarType::Bool,
ScalarType::QInt8,
ScalarType::QUInt8,
ScalarType::QInt32,
ScalarType::BFloat16,
ScalarType::QUInt4x2,
ScalarType::QUInt2x4,
ScalarType::Bits1x8,
ScalarType::Bits2x4,
ScalarType::Bits4x2,
ScalarType::Bits8,
ScalarType::Bits16,
ScalarType::Float8_e5m2,
ScalarType::Float8_e4m3fn,
ScalarType::Float8_e5m2fnuz,
ScalarType::Float8_e4m3fnuz,
ScalarType::UInt16,
ScalarType::UInt32,
ScalarType::UInt64,
ScalarType::UInt1,
ScalarType::UInt2,
ScalarType::UInt3,
ScalarType::UInt4,
ScalarType::UInt5,
ScalarType::UInt6,
ScalarType::UInt7,
ScalarType::Int1,
ScalarType::Int2,
ScalarType::Int3,
ScalarType::Int4,
ScalarType::Int5,
ScalarType::Int6,
ScalarType::Int7,
ScalarType::Float8_e8m0fnu,
ScalarType::Float4_e2m1fn_x2,
ScalarType::Undefined,
};
for (int8_t i = 0; i < static_cast<int8_t>(torch::headeronly::NumScalarTypes);
i++) {
EXPECT_EQ(static_cast<ScalarType>(i), expected_scalar_types[i]);
}
}

View File

@ -94,3 +94,8 @@ bits2x4
bits4x2
bits8
bits16
# torch/headeronly/core/ScalarType.h
NumScalarTypes
ScalarType
# dummy_int1_7_t, dummy_uint1_7_t tested through ScalarType

View File

@ -20,6 +20,7 @@ configure_file(
file(GLOB HEADERONLY_HEADERS
*.h
core/**/*.h
cpu/**/*.h
macros/*.h
util/*.h

View File

@ -0,0 +1,103 @@
#pragma once
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
#include <torch/headeronly/util/Float8_e4m3fn.h>
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
#include <torch/headeronly/util/Float8_e5m2.h>
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
#include <torch/headeronly/util/Float8_e8m0fnu.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/bits.h>
#include <torch/headeronly/util/complex.h>
#include <torch/headeronly/util/qint32.h>
#include <torch/headeronly/util/qint8.h>
#include <torch/headeronly/util/quint2x4.h>
#include <torch/headeronly/util/quint4x2.h>
#include <torch/headeronly/util/quint8.h>
#include <cstdint>
namespace c10 {
// dummy struct for uint1 to uint7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_uint1_7_t {};
// dummy struct for int1 to int7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_int1_7_t {};
// See [dtype Macros note] in c10/core/ScalarType.h regarding macros
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
_(c10::complex<float>, ComplexFloat) /* 9 */ \
_(c10::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */ \
_(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
_(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
_(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
_(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
_(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
_(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
_(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \
_(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \
_(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \
_(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \
_(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
enum class ScalarType : int8_t {
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
#undef DEFINE_ENUM_ST_ENUM_VAL_
Undefined,
NumOptions
};
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
} // namespace c10
namespace torch::headeronly {
using c10::dummy_int1_7_t;
using c10::dummy_uint1_7_t;
using c10::NumScalarTypes;
using c10::ScalarType;
} // namespace torch::headeronly

View File

@ -29,6 +29,7 @@ def define_torch_headeronly_ovrsource(name, is_mobile):
public_include_directories = ["../.."],
public_preprocessor_flags = pp_flags,
public_raw_headers = native.glob([
"core/**/*.h",
"cpu/**/*.h",
"macros/*.h",
"util/*.h",