[MPS] Add some useful utils (#148448)

Like `is_compex_v`, `is_scalar_intergral_v`, `result_of` etc

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148448
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #148398, #148399
This commit is contained in:
Nikita Shulga
2025-03-04 07:46:39 -08:00
committed by PyTorch MergeBot
parent f859722f70
commit e8900fbe4f

View File

@ -128,5 +128,22 @@ using vec4type_t = typename detail::vectypes<T>::type4;
template <typename T>
using opmath_t = typename detail::OpMathType<T>::type;
// TODO: Move it to type_traits header may be
template <typename F, typename... Args>
using result_of = decltype(::metal::declval<F>()(::metal::declval<Args>()...));
template <typename T>
constexpr constant bool is_complex_v =
::metal::is_same_v<T, float2> || ::metal::is_same_v<T, half2>;
template <typename T>
constexpr constant bool is_scalar_floating_point_v =
::metal::is_floating_point_v<T> && ::metal::is_scalar_v<T>;
template <typename T>
constexpr constant bool is_scalar_integral_v =
::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
} // namespace metal
} // namespace c10