Compare commits

...

1 Commits

Author SHA1 Message Date
994ff637bc Move CppTypeToScalarType to torch/headeronly
[ghstack-poisoned]
2025-11-11 19:23:38 -08:00
4 changed files with 28 additions and 14 deletions

View File

@ -33,20 +33,6 @@ namespace c10 {
// See [dtype Macros note] in torch/headeronly/core/ScalarType.h
// regarding macros.
template <typename T>
struct CppTypeToScalarType;
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;

View File

@ -13,6 +13,17 @@ TEST(TestScalarType, ScalarTypeToCPPTypeT) {
#undef DEFINE_CHECK
}
TEST(TestScalarType, CppTypeToScalarType) {
using c10::CppTypeToScalarType;
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
EXPECT_EQ(CppTypeToScalarType<TYPE>::value, ScalarType::SCALARTYPE);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
#define DEFINE_CHECK(TYPE, SCALARTYPE) \
{ \
EXPECT_EQ( \

View File

@ -161,6 +161,7 @@ COMPILE_TIME_MAX_DEVICE_TYPES
NumScalarTypes
ScalarType
# dummy_int1_7_t, dummy_uint1_7_t tested through ScalarType
CppTypeToScalarType
ScalarTypeToCPPTypeT
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX

View File

@ -266,6 +266,21 @@ enum class ScalarType : int8_t {
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
// Map from C++ type to ScalarType enum
template <typename T>
struct CppTypeToScalarType;
#define SPECIALIZE_CppTypeToScalarType(cpp_type, scalar_type) \
template <> \
struct CppTypeToScalarType<cpp_type> \
: std:: \
integral_constant<c10::ScalarType, c10::ScalarType::scalar_type> { \
};
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
#undef SPECIALIZE_CppTypeToScalarType
namespace impl {
// These are used to map ScalarTypes to C++ types.
@ -340,6 +355,7 @@ inline ScalarType toUnderlying(ScalarType t) {
} // namespace c10
HIDDEN_NAMESPACE_BEGIN(torch, headeronly)
using c10::CppTypeToScalarType;
using c10::dummy_int1_7_t;
using c10::dummy_uint1_7_t;
using c10::NumScalarTypes;