mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Migrate ScalarType to headeronly (#159416)"
This reverts commit 1371a98b0e727f8a8916dd473b6dd0cff78c0449. Reverted https://github.com/pytorch/pytorch/pull/159416 on behalf of https://github.com/izaitsevfb due to breaking internal builds, see D79452481 ([comment](https://github.com/pytorch/pytorch/pull/159416#issuecomment-3152138508))
This commit is contained in:
@ -19,16 +19,25 @@
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <ostream>
|
||||
#include <type_traits>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// [dtype Macros note] For the macros below:
|
||||
// dummy struct for uint1 to uint7, actual functionality
|
||||
// of these dtypes will be implemented in python with Tensor subclass
|
||||
template <unsigned int N>
|
||||
struct dummy_uint1_7_t {};
|
||||
|
||||
// dummy struct for int1 to int7, actual functionality
|
||||
// of these dtypes will be implemented in python with Tensor subclass
|
||||
template <unsigned int N>
|
||||
struct dummy_int1_7_t {};
|
||||
|
||||
// For the macros below:
|
||||
//
|
||||
// For users: If you want to macro some code for all non-QInt scalar types
|
||||
// (i.e. types with complete information, you probably want one of the
|
||||
@ -48,6 +57,56 @@ namespace c10 {
|
||||
// some old PRs where we added new dtypes (check history of this file) can
|
||||
// help give you an idea where to start.
|
||||
|
||||
// NB: Order matters for this macro; it is relied upon in
|
||||
// _promoteTypesLookup and the serialization format.
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
|
||||
_(uint8_t, Byte) /* 0 */ \
|
||||
_(int8_t, Char) /* 1 */ \
|
||||
_(int16_t, Short) /* 2 */ \
|
||||
_(int, Int) /* 3 */ \
|
||||
_(int64_t, Long) /* 4 */ \
|
||||
_(at::Half, Half) /* 5 */ \
|
||||
_(float, Float) /* 6 */ \
|
||||
_(double, Double) /* 7 */ \
|
||||
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
|
||||
_(c10::complex<float>, ComplexFloat) /* 9 */ \
|
||||
_(c10::complex<double>, ComplexDouble) /* 10 */ \
|
||||
_(bool, Bool) /* 11 */ \
|
||||
_(c10::qint8, QInt8) /* 12 */ \
|
||||
_(c10::quint8, QUInt8) /* 13 */ \
|
||||
_(c10::qint32, QInt32) /* 14 */ \
|
||||
_(at::BFloat16, BFloat16) /* 15 */ \
|
||||
_(c10::quint4x2, QUInt4x2) /* 16 */ \
|
||||
_(c10::quint2x4, QUInt2x4) /* 17 */ \
|
||||
_(c10::bits1x8, Bits1x8) /* 18 */ \
|
||||
_(c10::bits2x4, Bits2x4) /* 19 */ \
|
||||
_(c10::bits4x2, Bits4x2) /* 20 */ \
|
||||
_(c10::bits8, Bits8) /* 21 */ \
|
||||
_(c10::bits16, Bits16) /* 22 */ \
|
||||
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
|
||||
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
|
||||
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
|
||||
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
|
||||
_(uint16_t, UInt16) /* 27 */ \
|
||||
_(uint32_t, UInt32) /* 28 */ \
|
||||
_(uint64_t, UInt64) /* 29 */ \
|
||||
_(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
|
||||
_(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
|
||||
_(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
|
||||
_(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
|
||||
_(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
|
||||
_(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
|
||||
_(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \
|
||||
_(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \
|
||||
_(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \
|
||||
_(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \
|
||||
_(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
|
||||
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
|
||||
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
|
||||
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
|
||||
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
|
||||
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
|
||||
|
||||
// If you want to support ComplexHalf for real, add ComplexHalf
|
||||
// into this macro (and change the name). But beware: convert()
|
||||
// doesn't work for all the conversions you need...
|
||||
@ -93,6 +152,17 @@ namespace c10 {
|
||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||
|
||||
enum class ScalarType : int8_t {
|
||||
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
|
||||
#undef DEFINE_ENUM_ST_ENUM_VAL_
|
||||
Undefined,
|
||||
NumOptions
|
||||
};
|
||||
|
||||
constexpr uint16_t NumScalarTypes =
|
||||
static_cast<uint16_t>(ScalarType::NumOptions);
|
||||
|
||||
namespace impl {
|
||||
|
||||
// These are used to map ScalarTypes to C++ types.
|
||||
|
@ -1,6 +1,5 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
@ -150,60 +149,3 @@ TEST(TestDtype, TestQuintsQintsAndBits) {
|
||||
auto i = torch::headeronly::bits8(2);
|
||||
auto j = torch::headeronly::bits16(6);
|
||||
}
|
||||
|
||||
TEST(TestDtype, TestScalarType) {
|
||||
using torch::headeronly::ScalarType;
|
||||
constexpr ScalarType expected_scalar_types[] = {
|
||||
ScalarType::Byte,
|
||||
ScalarType::Char,
|
||||
ScalarType::Short,
|
||||
ScalarType::Int,
|
||||
ScalarType::Long,
|
||||
ScalarType::Half,
|
||||
ScalarType::Float,
|
||||
ScalarType::Double,
|
||||
ScalarType::ComplexHalf,
|
||||
ScalarType::ComplexFloat,
|
||||
ScalarType::ComplexDouble,
|
||||
ScalarType::Bool,
|
||||
ScalarType::QInt8,
|
||||
ScalarType::QUInt8,
|
||||
ScalarType::QInt32,
|
||||
ScalarType::BFloat16,
|
||||
ScalarType::QUInt4x2,
|
||||
ScalarType::QUInt2x4,
|
||||
ScalarType::Bits1x8,
|
||||
ScalarType::Bits2x4,
|
||||
ScalarType::Bits4x2,
|
||||
ScalarType::Bits8,
|
||||
ScalarType::Bits16,
|
||||
ScalarType::Float8_e5m2,
|
||||
ScalarType::Float8_e4m3fn,
|
||||
ScalarType::Float8_e5m2fnuz,
|
||||
ScalarType::Float8_e4m3fnuz,
|
||||
ScalarType::UInt16,
|
||||
ScalarType::UInt32,
|
||||
ScalarType::UInt64,
|
||||
ScalarType::UInt1,
|
||||
ScalarType::UInt2,
|
||||
ScalarType::UInt3,
|
||||
ScalarType::UInt4,
|
||||
ScalarType::UInt5,
|
||||
ScalarType::UInt6,
|
||||
ScalarType::UInt7,
|
||||
ScalarType::Int1,
|
||||
ScalarType::Int2,
|
||||
ScalarType::Int3,
|
||||
ScalarType::Int4,
|
||||
ScalarType::Int5,
|
||||
ScalarType::Int6,
|
||||
ScalarType::Int7,
|
||||
ScalarType::Float8_e8m0fnu,
|
||||
ScalarType::Float4_e2m1fn_x2,
|
||||
ScalarType::Undefined,
|
||||
};
|
||||
for (int8_t i = 0; i < static_cast<int8_t>(torch::headeronly::NumScalarTypes);
|
||||
i++) {
|
||||
EXPECT_EQ(static_cast<ScalarType>(i), expected_scalar_types[i]);
|
||||
}
|
||||
}
|
||||
|
@ -91,8 +91,3 @@ bits2x4
|
||||
bits4x2
|
||||
bits8
|
||||
bits16
|
||||
|
||||
# torch/headeronly/core/ScalarType.h
|
||||
NumScalarTypes
|
||||
ScalarType
|
||||
# dummy_int1_7_t, dummy_uint1_7_t tested through ScalarType
|
||||
|
@ -20,7 +20,6 @@ configure_file(
|
||||
|
||||
file(GLOB HEADERONLY_HEADERS
|
||||
*.h
|
||||
core/**/*.h
|
||||
cpu/**/*.h
|
||||
macros/*.h
|
||||
util/*.h
|
||||
|
@ -1,103 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
|
||||
#include <torch/headeronly/util/Float8_e5m2.h>
|
||||
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
|
||||
#include <torch/headeronly/util/Float8_e8m0fnu.h>
|
||||
#include <torch/headeronly/util/Half.h>
|
||||
#include <torch/headeronly/util/bits.h>
|
||||
#include <torch/headeronly/util/complex.h>
|
||||
#include <torch/headeronly/util/qint32.h>
|
||||
#include <torch/headeronly/util/qint8.h>
|
||||
#include <torch/headeronly/util/quint2x4.h>
|
||||
#include <torch/headeronly/util/quint4x2.h>
|
||||
#include <torch/headeronly/util/quint8.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// dummy struct for uint1 to uint7, actual functionality
|
||||
// of these dtypes will be implemented in python with Tensor subclass
|
||||
template <unsigned int N>
|
||||
struct dummy_uint1_7_t {};
|
||||
|
||||
// dummy struct for int1 to int7, actual functionality
|
||||
// of these dtypes will be implemented in python with Tensor subclass
|
||||
template <unsigned int N>
|
||||
struct dummy_int1_7_t {};
|
||||
|
||||
// See [dtype Macros note] in c10/core/ScalarType.h regarding macros
|
||||
|
||||
// NB: Order matters for this macro; it is relied upon in
|
||||
// _promoteTypesLookup and the serialization format.
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
|
||||
_(uint8_t, Byte) /* 0 */ \
|
||||
_(int8_t, Char) /* 1 */ \
|
||||
_(int16_t, Short) /* 2 */ \
|
||||
_(int, Int) /* 3 */ \
|
||||
_(int64_t, Long) /* 4 */ \
|
||||
_(at::Half, Half) /* 5 */ \
|
||||
_(float, Float) /* 6 */ \
|
||||
_(double, Double) /* 7 */ \
|
||||
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
|
||||
_(c10::complex<float>, ComplexFloat) /* 9 */ \
|
||||
_(c10::complex<double>, ComplexDouble) /* 10 */ \
|
||||
_(bool, Bool) /* 11 */ \
|
||||
_(c10::qint8, QInt8) /* 12 */ \
|
||||
_(c10::quint8, QUInt8) /* 13 */ \
|
||||
_(c10::qint32, QInt32) /* 14 */ \
|
||||
_(at::BFloat16, BFloat16) /* 15 */ \
|
||||
_(c10::quint4x2, QUInt4x2) /* 16 */ \
|
||||
_(c10::quint2x4, QUInt2x4) /* 17 */ \
|
||||
_(c10::bits1x8, Bits1x8) /* 18 */ \
|
||||
_(c10::bits2x4, Bits2x4) /* 19 */ \
|
||||
_(c10::bits4x2, Bits4x2) /* 20 */ \
|
||||
_(c10::bits8, Bits8) /* 21 */ \
|
||||
_(c10::bits16, Bits16) /* 22 */ \
|
||||
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
|
||||
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
|
||||
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
|
||||
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
|
||||
_(uint16_t, UInt16) /* 27 */ \
|
||||
_(uint32_t, UInt32) /* 28 */ \
|
||||
_(uint64_t, UInt64) /* 29 */ \
|
||||
_(dummy_uint1_7_t<1>, UInt1) /* 30 */ \
|
||||
_(dummy_uint1_7_t<2>, UInt2) /* 31 */ \
|
||||
_(dummy_uint1_7_t<3>, UInt3) /* 32 */ \
|
||||
_(dummy_uint1_7_t<4>, UInt4) /* 33 */ \
|
||||
_(dummy_uint1_7_t<5>, UInt5) /* 34 */ \
|
||||
_(dummy_uint1_7_t<6>, UInt6) /* 35 */ \
|
||||
_(dummy_uint1_7_t<7>, UInt7) /* 36 */ \
|
||||
_(dummy_int1_7_t<1>, Int1) /* 37 */ \
|
||||
_(dummy_int1_7_t<2>, Int2) /* 38 */ \
|
||||
_(dummy_int1_7_t<3>, Int3) /* 39 */ \
|
||||
_(dummy_int1_7_t<4>, Int4) /* 40 */ \
|
||||
_(dummy_int1_7_t<5>, Int5) /* 41 */ \
|
||||
_(dummy_int1_7_t<6>, Int6) /* 42 */ \
|
||||
_(dummy_int1_7_t<7>, Int7) /* 43 */ \
|
||||
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
|
||||
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
|
||||
|
||||
enum class ScalarType : int8_t {
|
||||
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
|
||||
#undef DEFINE_ENUM_ST_ENUM_VAL_
|
||||
Undefined,
|
||||
NumOptions
|
||||
};
|
||||
|
||||
constexpr uint16_t NumScalarTypes =
|
||||
static_cast<uint16_t>(ScalarType::NumOptions);
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace torch::headeronly {
|
||||
using c10::dummy_int1_7_t;
|
||||
using c10::dummy_uint1_7_t;
|
||||
using c10::NumScalarTypes;
|
||||
using c10::ScalarType;
|
||||
} // namespace torch::headeronly
|
@ -29,7 +29,6 @@ def define_torch_headeronly_ovrsource(name, is_mobile):
|
||||
public_include_directories = ["../.."],
|
||||
public_preprocessor_flags = pp_flags,
|
||||
public_raw_headers = native.glob([
|
||||
"core/**/*.h",
|
||||
"cpu/**/*.h",
|
||||
"macros/*.h",
|
||||
"util/*.h",
|
||||
|
Reference in New Issue
Block a user