mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Add some useful utils (#148448)
Like `is_compex_v`, `is_scalar_intergral_v`, `result_of` etc Pull Request resolved: https://github.com/pytorch/pytorch/pull/148448 Approved by: https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #148398, #148399
This commit is contained in:
committed by
PyTorch MergeBot
parent
f859722f70
commit
e8900fbe4f
@ -128,5 +128,22 @@ using vec4type_t = typename detail::vectypes<T>::type4;
|
||||
|
||||
template <typename T>
|
||||
using opmath_t = typename detail::OpMathType<T>::type;
|
||||
|
||||
// TODO: Move it to type_traits header may be
|
||||
template <typename F, typename... Args>
|
||||
using result_of = decltype(::metal::declval<F>()(::metal::declval<Args>()...));
|
||||
|
||||
template <typename T>
|
||||
constexpr constant bool is_complex_v =
|
||||
::metal::is_same_v<T, float2> || ::metal::is_same_v<T, half2>;
|
||||
|
||||
template <typename T>
|
||||
constexpr constant bool is_scalar_floating_point_v =
|
||||
::metal::is_floating_point_v<T> && ::metal::is_scalar_v<T>;
|
||||
|
||||
template <typename T>
|
||||
constexpr constant bool is_scalar_integral_v =
|
||||
::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
Reference in New Issue
Block a user