mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164405 Approved by: https://github.com/Skylion007, https://github.com/janeyx99 ghstack dependencies: #164350, #164354
This commit is contained in:
committed by
PyTorch MergeBot
parent
26f3803433
commit
ca8bd5dbed
@ -52,19 +52,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
|||||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
|
||||||
#undef DEFINE_CONSTANT
|
#undef DEFINE_CONSTANT
|
||||||
|
|
||||||
inline const char* toString(ScalarType t) {
|
|
||||||
#define DEFINE_CASE(_, name) \
|
|
||||||
case ScalarType::name: \
|
|
||||||
return #name;
|
|
||||||
|
|
||||||
switch (t) {
|
|
||||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
|
|
||||||
default:
|
|
||||||
return "UNKNOWN_SCALAR";
|
|
||||||
}
|
|
||||||
#undef DEFINE_CASE
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elementSize(ScalarType t) {
|
inline size_t elementSize(ScalarType t) {
|
||||||
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
|
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
|
||||||
case ScalarType::name: \
|
case ScalarType::name: \
|
||||||
@ -308,12 +295,6 @@ inline bool canCast(const ScalarType from, const ScalarType to) {
|
|||||||
|
|
||||||
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
|
C10_API ScalarType promoteTypes(ScalarType a, ScalarType b);
|
||||||
|
|
||||||
inline std::ostream& operator<<(
|
|
||||||
std::ostream& stream,
|
|
||||||
at::ScalarType scalar_type) {
|
|
||||||
return stream << toString(scalar_type);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a pair of strings representing the names for each dtype.
|
// Returns a pair of strings representing the names for each dtype.
|
||||||
// The returned pair is (name, legacy_name_if_applicable)
|
// The returned pair is (name, legacy_name_if_applicable)
|
||||||
C10_API std::pair<std::string, std::string> getDtypeNames(
|
C10_API std::pair<std::string, std::string> getDtypeNames(
|
||||||
|
@ -53,3 +53,24 @@ TEST_FORALL(AT_FORALL_COMPLEX_TYPES, 2)
|
|||||||
|
|
||||||
#undef DEFINE_CHECK
|
#undef DEFINE_CHECK
|
||||||
#undef TEST_FORALL
|
#undef TEST_FORALL
|
||||||
|
|
||||||
|
TEST(TestScalarType, toString) {
|
||||||
|
using torch::headeronly::ScalarType;
|
||||||
|
|
||||||
|
#define DEFINE_CHECK(_, name) EXPECT_EQ(toString(ScalarType::name), #name);
|
||||||
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||||
|
#undef DEFINE_CHECK
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TestScalarType, operator_left_shift) {
|
||||||
|
using torch::headeronly::ScalarType;
|
||||||
|
|
||||||
|
#define DEFINE_CHECK(_, name) \
|
||||||
|
{ \
|
||||||
|
std::stringstream ss; \
|
||||||
|
ss << ScalarType::name; \
|
||||||
|
EXPECT_EQ(ss.str(), #name); \
|
||||||
|
}
|
||||||
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CHECK);
|
||||||
|
#undef DEFINE_CHECK
|
||||||
|
}
|
||||||
|
@ -133,3 +133,5 @@ AT_FORALL_SCALAR_TYPES_AND7
|
|||||||
AT_FORALL_QINT_TYPES
|
AT_FORALL_QINT_TYPES
|
||||||
AT_FORALL_FLOAT8_TYPES
|
AT_FORALL_FLOAT8_TYPES
|
||||||
AT_FORALL_COMPLEX_TYPES
|
AT_FORALL_COMPLEX_TYPES
|
||||||
|
toString
|
||||||
|
<<
|
||||||
|
@ -285,6 +285,25 @@ using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType<N>::type;
|
|||||||
|
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
|
|
||||||
|
inline const char* toString(ScalarType t) {
|
||||||
|
#define DEFINE_CASE(_, name) \
|
||||||
|
case ScalarType::name: \
|
||||||
|
return #name;
|
||||||
|
|
||||||
|
switch (t) {
|
||||||
|
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
|
||||||
|
default:
|
||||||
|
return "UNKNOWN_SCALAR";
|
||||||
|
}
|
||||||
|
#undef DEFINE_CASE
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::ostream& operator<<(
|
||||||
|
std::ostream& stream,
|
||||||
|
at::ScalarType scalar_type) {
|
||||||
|
return stream << toString(scalar_type);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
||||||
namespace torch::headeronly {
|
namespace torch::headeronly {
|
||||||
@ -295,4 +314,6 @@ using c10::ScalarType;
|
|||||||
namespace impl {
|
namespace impl {
|
||||||
using c10::impl::ScalarTypeToCPPTypeT;
|
using c10::impl::ScalarTypeToCPPTypeT;
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
|
using c10::toString;
|
||||||
|
using c10::operator<<;
|
||||||
} // namespace torch::headeronly
|
} // namespace torch::headeronly
|
||||||
|
Reference in New Issue
Block a user