Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164405
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
ghstack dependencies: #164350, #164354
This commit is contained in:
Pearu Peterson
2025-10-15 14:54:15 +03:00
committed by PyTorch MergeBot
parent 26f3803433
commit ca8bd5dbed
4 changed files with 44 additions and 19 deletions

View File

@ -52,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
case ScalarType::name: \
return #name;
switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
#undef DEFINE_CASE
}
inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
case ScalarType::name: \
@ -308,12 +295,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
inline std::ostream& operator<<(
std::ostream& stream,
at::ScalarType scalar_type) {
return stream << toString(scalar_type);
}
// Returns a pair of strings representing the names for each dtype.
// The returned pair is (name, legacy_name_if_applicable)
C10_API std::pair<std::string, std::string> getDtypeNames(

View File

@ -53,3 +53,24 @@ TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)
#undef DEFINE_CHECK
#undef TEST_FORALL
TEST(TestScalarType, toString) {
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}
TEST(TestScalarType, operator_left_shift) {
using torch::headeronly::ScalarType;
#define DEFINE_CHECK(_, name) \
{ \
std::stringstream ss; \
ss << ScalarType::name; \
EXPECT_EQ(ss.str(), #name); \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
#undef DEFINE_CHECK
}

View File

@ -133,3 +133,5 @@ AT_FORALL_SCALAR_TYPES_AND7
AT_FORALL_QINT_TYPES
AT_FORALL_FLOAT8_TYPES
AT_FORALL_COMPLEX_TYPES
toString
<<

View File

@ -285,6 +285,25 @@ using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
} // namespace impl
inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name) \
case ScalarType::name: \
return #name;
switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
#undef DEFINE_CASE
}
inline std::ostream& operator<<(
std::ostream& stream,
at::ScalarType scalar_type) {
return stream << toString(scalar_type);
}
} // namespace c10
namespace torch::headeronly {
@ -295,4 +314,6 @@ using c10::ScalarType;
namespace impl {
using c10::impl::ScalarTypeToCPPTypeT;
} // namespace impl
using c10::toString;
using c10::operator<<;
} // namespace torch::headeronly