mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "Migrate ScalarType to headeronly (#159416)"
This reverts commit 1371a98b0e727f8a8916dd473b6dd0cff78c0449. Reverted https://github.com/pytorch/pytorch/pull/159416 on behalf of https://github.com/izaitsevfb due to breaking internal builds, see D79452481 ([comment](https://github.com/pytorch/pytorch/pull/159416#issuecomment-3152138508))
This commit is contained in:
@ -19,16 +19,25 @@
|
|||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include <torch/headeronly/core/ScalarType.h>
|
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
||||||
// [dtype Macros note] For the macros below:
|
// dummy struct for uint1 to uint7, actual functionality
|
||||||
|
// of these dtypes will be implemented in python with Tensor subclass
|
||||||
|
template <unsigned int N>
|
||||||
|
struct dummy_uint1_7_t {};
|
||||||
|
|
||||||
|
// dummy struct for int1 to int7, actual functionality
|
||||||
|
// of these dtypes will be implemented in python with Tensor subclass
|
||||||
|
template <unsigned int N>
|
||||||
|
struct dummy_int1_7_t {};
|
||||||
|
|
||||||
|
// For the macros below:
|
||||||
//
|
//
|
||||||
// For users: If you want to macro some code for all non-QInt scalar types
|
// For users: If you want to macro some code for all non-QInt scalar types
|
||||||
// (i.e. types with complete information, you probably want one of the
|
// (i.e. types with complete information, you probably want one of the
|
||||||
@ -48,6 +57,56 @@ namespace c10 {
|
|||||||
// some old PRs where we added new dtypes (check history of this file) can
|
// some old PRs where we added new dtypes (check history of this file) can
|
||||||
// help give you an idea where to start.
|
// help give you an idea where to start.
|
||||||
|
|
||||||
|
// NB: Order matters for this macro; it is relied upon in
|
||||||
|
// _promoteTypesLookup and the serialization format.
|
||||||
|
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
|
||||||
|
_(uint8_t, Byte) /* 0 */ \
|
||||||
|
_(int8_t, Char) /* 1 */ \
|
||||||
|
_(int16_t, Short) /* 2 */ \
|
||||||
|
_(int, Int) /* 3 */ \
|
||||||
|
_(int64_t, Long) /* 4 */ \
|
||||||
|
_(at::Half, Half) /* 5 */ \
|
||||||
|
_(float, Float) /* 6 */ \
|
||||||
|
_(double, Double) /* 7 */ \
|
||||||
|
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
|
||||||
|
_(c10::complex<float>, ComplexFloat) /* 9 */ \
|
||||||
|
_(c10::complex<double>, ComplexDouble) /* 10 */ \
|
||||||
|
_(bool, Bool) /* 11 */ \
|
||||||
|
_(c10::qint8, QInt8) /* 12 */ \
|
||||||
|
_(c10::quint8, QUInt8) /* 13 */ \
|
||||||
|
_(c10::qint32, QInt32) /* 14 */ \
|
||||||
|
_(at::BFloat16, BFloat16) /* 15 */ \
|
||||||
|
_(c10::quint4x2, QUInt4x2) /* 16 */ \
|
||||||
|
_(c10::quint2x4, QUInt2x4) /* 17 */ \
|
||||||
|
_(c10::bits1x8, Bits1x8) /* 18 */ \
|
||||||
|
_(c10::bits2x4, Bits2x4) /* 19 */ \
|
||||||
|
_(c10::bits4x2, Bits4x2) /* 20 */ \
|
||||||
|
_(c10::bits8, Bits8) /* 21 */ \
|
||||||
|
_(c10::bits16, Bits16) /* 22 */ \
|
||||||
|
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
|
||||||
|
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
|
||||||
|
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
|
||||||
|
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
|
||||||
|
_(uint16_t, UInt16) /* 27 */ \
|
||||||
|
_(uint32_t, UInt32) /* 28 */ \
|
||||||
|
_(uint64_t, UInt64) /* 29 */ \
|
||||||
|
_(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
|
||||||
|
_(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
|
||||||
|
_(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
|
||||||
|
_(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
|
||||||
|
_(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
|
||||||
|
_(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
|
||||||
|
_(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \
|
||||||
|
_(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \
|
||||||
|
_(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \
|
||||||
|
_(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \
|
||||||
|
_(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
|
||||||
|
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
|
||||||
|
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
|
||||||
|
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
|
||||||
|
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
|
||||||
|
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
|
||||||
|
|
||||||
// If you want to support ComplexHalf for real, add ComplexHalf
|
// If you want to support ComplexHalf for real, add ComplexHalf
|
||||||
// into this macro (and change the name). But beware: convert()
|
// into this macro (and change the name). But beware: convert()
|
||||||
// doesn't work for all the conversions you need...
|
// doesn't work for all the conversions you need...
|
||||||
@ -93,6 +152,17 @@ namespace c10 {
|
|||||||
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
|
||||||
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
|
||||||
|
|
||||||
|
enum class ScalarType : int8_t {
|
||||||
|
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
|
||||||
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
|
||||||
|
#undef DEFINE_ENUM_ST_ENUM_VAL_
|
||||||
|
Undefined,
|
||||||
|
NumOptions
|
||||||
|
};
|
||||||
|
|
||||||
|
constexpr uint16_t NumScalarTypes =
|
||||||
|
static_cast<uint16_t>(ScalarType::NumOptions);
|
||||||
|
|
||||||
namespace impl {
|
namespace impl {
|
||||||
|
|
||||||
// These are used to map ScalarTypes to C++ types.
|
// These are used to map ScalarTypes to C++ types.
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
#include <torch/headeronly/core/ScalarType.h>
|
|
||||||
#include <torch/headeronly/util/BFloat16.h>
|
#include <torch/headeronly/util/BFloat16.h>
|
||||||
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
|
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
|
||||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||||
@ -150,60 +149,3 @@ TEST(TestDtype, TestQuintsQintsAndBits) {
|
|||||||
auto i = torch::headeronly::bits8(2);
|
auto i = torch::headeronly::bits8(2);
|
||||||
auto j = torch::headeronly::bits16(6);
|
auto j = torch::headeronly::bits16(6);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TestDtype, TestScalarType) {
|
|
||||||
using torch::headeronly::ScalarType;
|
|
||||||
constexpr ScalarType expected_scalar_types[] = {
|
|
||||||
ScalarType::Byte,
|
|
||||||
ScalarType::Char,
|
|
||||||
ScalarType::Short,
|
|
||||||
ScalarType::Int,
|
|
||||||
ScalarType::Long,
|
|
||||||
ScalarType::Half,
|
|
||||||
ScalarType::Float,
|
|
||||||
ScalarType::Double,
|
|
||||||
ScalarType::ComplexHalf,
|
|
||||||
ScalarType::ComplexFloat,
|
|
||||||
ScalarType::ComplexDouble,
|
|
||||||
ScalarType::Bool,
|
|
||||||
ScalarType::QInt8,
|
|
||||||
ScalarType::QUInt8,
|
|
||||||
ScalarType::QInt32,
|
|
||||||
ScalarType::BFloat16,
|
|
||||||
ScalarType::QUInt4x2,
|
|
||||||
ScalarType::QUInt2x4,
|
|
||||||
ScalarType::Bits1x8,
|
|
||||||
ScalarType::Bits2x4,
|
|
||||||
ScalarType::Bits4x2,
|
|
||||||
ScalarType::Bits8,
|
|
||||||
ScalarType::Bits16,
|
|
||||||
ScalarType::Float8_e5m2,
|
|
||||||
ScalarType::Float8_e4m3fn,
|
|
||||||
ScalarType::Float8_e5m2fnuz,
|
|
||||||
ScalarType::Float8_e4m3fnuz,
|
|
||||||
ScalarType::UInt16,
|
|
||||||
ScalarType::UInt32,
|
|
||||||
ScalarType::UInt64,
|
|
||||||
ScalarType::UInt1,
|
|
||||||
ScalarType::UInt2,
|
|
||||||
ScalarType::UInt3,
|
|
||||||
ScalarType::UInt4,
|
|
||||||
ScalarType::UInt5,
|
|
||||||
ScalarType::UInt6,
|
|
||||||
ScalarType::UInt7,
|
|
||||||
ScalarType::Int1,
|
|
||||||
ScalarType::Int2,
|
|
||||||
ScalarType::Int3,
|
|
||||||
ScalarType::Int4,
|
|
||||||
ScalarType::Int5,
|
|
||||||
ScalarType::Int6,
|
|
||||||
ScalarType::Int7,
|
|
||||||
ScalarType::Float8_e8m0fnu,
|
|
||||||
ScalarType::Float4_e2m1fn_x2,
|
|
||||||
ScalarType::Undefined,
|
|
||||||
};
|
|
||||||
for (int8_t i = 0; i < static_cast<int8_t>(torch::headeronly::NumScalarTypes);
|
|
||||||
i++) {
|
|
||||||
EXPECT_EQ(static_cast<ScalarType>(i), expected_scalar_types[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -91,8 +91,3 @@ bits2x4
|
|||||||
bits4x2
|
bits4x2
|
||||||
bits8
|
bits8
|
||||||
bits16
|
bits16
|
||||||
|
|
||||||
# torch/headeronly/core/ScalarType.h
|
|
||||||
NumScalarTypes
|
|
||||||
ScalarType
|
|
||||||
# dummy_int1_7_t, dummy_uint1_7_t tested through ScalarType
|
|
||||||
|
@ -20,7 +20,6 @@ configure_file(
|
|||||||
|
|
||||||
file(GLOB HEADERONLY_HEADERS
|
file(GLOB HEADERONLY_HEADERS
|
||||||
*.h
|
*.h
|
||||||
core/**/*.h
|
|
||||||
cpu/**/*.h
|
cpu/**/*.h
|
||||||
macros/*.h
|
macros/*.h
|
||||||
util/*.h
|
util/*.h
|
||||||
|
@ -1,103 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <torch/headeronly/util/BFloat16.h>
|
|
||||||
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
|
|
||||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
|
||||||
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
|
|
||||||
#include <torch/headeronly/util/Float8_e5m2.h>
|
|
||||||
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
|
|
||||||
#include <torch/headeronly/util/Float8_e8m0fnu.h>
|
|
||||||
#include <torch/headeronly/util/Half.h>
|
|
||||||
#include <torch/headeronly/util/bits.h>
|
|
||||||
#include <torch/headeronly/util/complex.h>
|
|
||||||
#include <torch/headeronly/util/qint32.h>
|
|
||||||
#include <torch/headeronly/util/qint8.h>
|
|
||||||
#include <torch/headeronly/util/quint2x4.h>
|
|
||||||
#include <torch/headeronly/util/quint4x2.h>
|
|
||||||
#include <torch/headeronly/util/quint8.h>
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
namespace c10 {
|
|
||||||
|
|
||||||
// dummy struct for uint1 to uint7, actual functionality
|
|
||||||
// of these dtypes will be implemented in python with Tensor subclass
|
|
||||||
template <unsigned int N>
|
|
||||||
struct dummy_uint1_7_t {};
|
|
||||||
|
|
||||||
// dummy struct for int1 to int7, actual functionality
|
|
||||||
// of these dtypes will be implemented in python with Tensor subclass
|
|
||||||
template <unsigned int N>
|
|
||||||
struct dummy_int1_7_t {};
|
|
||||||
|
|
||||||
// See [dtype Macros note] in c10/core/ScalarType.h regarding macros
|
|
||||||
|
|
||||||
// NB: Order matters for this macro; it is relied upon in
|
|
||||||
// _promoteTypesLookup and the serialization format.
|
|
||||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
|
|
||||||
_(uint8_t, Byte) /* 0 */ \
|
|
||||||
_(int8_t, Char) /* 1 */ \
|
|
||||||
_(int16_t, Short) /* 2 */ \
|
|
||||||
_(int, Int) /* 3 */ \
|
|
||||||
_(int64_t, Long) /* 4 */ \
|
|
||||||
_(at::Half, Half) /* 5 */ \
|
|
||||||
_(float, Float) /* 6 */ \
|
|
||||||
_(double, Double) /* 7 */ \
|
|
||||||
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
|
|
||||||
_(c10::complex<float>, ComplexFloat) /* 9 */ \
|
|
||||||
_(c10::complex<double>, ComplexDouble) /* 10 */ \
|
|
||||||
_(bool, Bool) /* 11 */ \
|
|
||||||
_(c10::qint8, QInt8) /* 12 */ \
|
|
||||||
_(c10::quint8, QUInt8) /* 13 */ \
|
|
||||||
_(c10::qint32, QInt32) /* 14 */ \
|
|
||||||
_(at::BFloat16, BFloat16) /* 15 */ \
|
|
||||||
_(c10::quint4x2, QUInt4x2) /* 16 */ \
|
|
||||||
_(c10::quint2x4, QUInt2x4) /* 17 */ \
|
|
||||||
_(c10::bits1x8, Bits1x8) /* 18 */ \
|
|
||||||
_(c10::bits2x4, Bits2x4) /* 19 */ \
|
|
||||||
_(c10::bits4x2, Bits4x2) /* 20 */ \
|
|
||||||
_(c10::bits8, Bits8) /* 21 */ \
|
|
||||||
_(c10::bits16, Bits16) /* 22 */ \
|
|
||||||
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
|
|
||||||
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
|
|
||||||
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
|
|
||||||
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
|
|
||||||
_(uint16_t, UInt16) /* 27 */ \
|
|
||||||
_(uint32_t, UInt32) /* 28 */ \
|
|
||||||
_(uint64_t, UInt64) /* 29 */ \
|
|
||||||
_(dummy_uint1_7_t<1>, UInt1) /* 30 */ \
|
|
||||||
_(dummy_uint1_7_t<2>, UInt2) /* 31 */ \
|
|
||||||
_(dummy_uint1_7_t<3>, UInt3) /* 32 */ \
|
|
||||||
_(dummy_uint1_7_t<4>, UInt4) /* 33 */ \
|
|
||||||
_(dummy_uint1_7_t<5>, UInt5) /* 34 */ \
|
|
||||||
_(dummy_uint1_7_t<6>, UInt6) /* 35 */ \
|
|
||||||
_(dummy_uint1_7_t<7>, UInt7) /* 36 */ \
|
|
||||||
_(dummy_int1_7_t<1>, Int1) /* 37 */ \
|
|
||||||
_(dummy_int1_7_t<2>, Int2) /* 38 */ \
|
|
||||||
_(dummy_int1_7_t<3>, Int3) /* 39 */ \
|
|
||||||
_(dummy_int1_7_t<4>, Int4) /* 40 */ \
|
|
||||||
_(dummy_int1_7_t<5>, Int5) /* 41 */ \
|
|
||||||
_(dummy_int1_7_t<6>, Int6) /* 42 */ \
|
|
||||||
_(dummy_int1_7_t<7>, Int7) /* 43 */ \
|
|
||||||
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
|
|
||||||
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
|
|
||||||
|
|
||||||
enum class ScalarType : int8_t {
|
|
||||||
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
|
|
||||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
|
|
||||||
#undef DEFINE_ENUM_ST_ENUM_VAL_
|
|
||||||
Undefined,
|
|
||||||
NumOptions
|
|
||||||
};
|
|
||||||
|
|
||||||
constexpr uint16_t NumScalarTypes =
|
|
||||||
static_cast<uint16_t>(ScalarType::NumOptions);
|
|
||||||
|
|
||||||
} // namespace c10
|
|
||||||
|
|
||||||
namespace torch::headeronly {
|
|
||||||
using c10::dummy_int1_7_t;
|
|
||||||
using c10::dummy_uint1_7_t;
|
|
||||||
using c10::NumScalarTypes;
|
|
||||||
using c10::ScalarType;
|
|
||||||
} // namespace torch::headeronly
|
|
@ -29,7 +29,6 @@ def define_torch_headeronly_ovrsource(name, is_mobile):
|
|||||||
public_include_directories = ["../.."],
|
public_include_directories = ["../.."],
|
||||||
public_preprocessor_flags = pp_flags,
|
public_preprocessor_flags = pp_flags,
|
||||||
public_raw_headers = native.glob([
|
public_raw_headers = native.glob([
|
||||||
"core/**/*.h",
|
|
||||||
"cpu/**/*.h",
|
"cpu/**/*.h",
|
||||||
"macros/*.h",
|
"macros/*.h",
|
||||||
"util/*.h",
|
"util/*.h",
|
||||||
|
Reference in New Issue
Block a user