[Reland] Migrate ScalarType to headeronly (#159911)

The non ghstack version of #159416, to make sure we don't get reverted again
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159911
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Jane Xu
2025-08-06 07:36:37 +00:00
committed by PyTorch MergeBot
parent e9d27aa8fd
commit 1690c0c3a0
6 changed files with 171 additions and 73 deletions

View File

@ -20,6 +20,7 @@ configure_file(
file(GLOB HEADERONLY_HEADERS
*.h
core/**/*.h
cpu/**/*.h
macros/*.h
util/*.h

View File

@ -0,0 +1,103 @@
#pragma once
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
#include <torch/headeronly/util/Float8_e4m3fn.h>
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
#include <torch/headeronly/util/Float8_e5m2.h>
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
#include <torch/headeronly/util/Float8_e8m0fnu.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/bits.h>
#include <torch/headeronly/util/complex.h>
#include <torch/headeronly/util/qint32.h>
#include <torch/headeronly/util/qint8.h>
#include <torch/headeronly/util/quint2x4.h>
#include <torch/headeronly/util/quint4x2.h>
#include <torch/headeronly/util/quint8.h>
#include <cstdint>
namespace c10 {
// dummy struct for uint1 to uint7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_uint1_7_t {};
// dummy struct for int1 to int7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_int1_7_t {};
// See [dtype Macros note] in c10/core/ScalarType.h regarding macros
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
_(c10::complex<float>, ComplexFloat) /* 9 */ \
_(c10::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */ \
_(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
_(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
_(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
_(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
_(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
_(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
_(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \
_(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \
_(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \
_(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \
_(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
enum class ScalarType : int8_t {
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
#undef DEFINE_ENUM_ST_ENUM_VAL_
Undefined,
NumOptions
};
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
} // namespace c10
namespace torch::headeronly {
using c10::dummy_int1_7_t;
using c10::dummy_uint1_7_t;
using c10::NumScalarTypes;
using c10::ScalarType;
} // namespace torch::headeronly

View File

@ -29,6 +29,7 @@ def define_torch_headeronly_ovrsource(name, is_mobile):
public_include_directories = ["../.."],
public_preprocessor_flags = pp_flags,
public_raw_headers = native.glob([
"core/**/*.h",
"cpu/**/*.h",
"macros/*.h",
"util/*.h",