Files
pytorch/c10/metal/common.h
Nikita Shulga e2a5c42e7e [BE][MPS] Build metal kernels of MacOS-14+ (#159733)
This makes the `#if __METAL_VERSION__ >= 310` guards around `bfloat` usage unnecessary.
Rename `kernels_bfloat.metallib` into `kernels_basic` and remove custom build/selection logic.

Part of https://github.com/pytorch/pytorch/issues/159275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159733
Approved by: https://github.com/dcci
ghstack dependencies: #159731, #159732
2025-08-03 20:53:58 +00:00

46 lines
1.2 KiB
C++

#pragma once
// Global constants and type definitions that can be shared between CPU and Metal code
// Select the fixed-size array header and the qualifier applied to shared
// constants, based on whether `__METAL__` is defined for this translation unit.
#if !defined(__METAL__)
// `__METAL__` not defined (CPU-side build): use the standard library array and
// plain `constexpr` constants.
#include <array>
#define C10_METAL_CONSTEXPR constexpr
#else
// `__METAL__` defined (Metal build): use Metal's array header; constants are
// additionally marked with Metal's `constant` qualifier.
#include <metal_array>
#define C10_METAL_CONSTEXPR constant constexpr
#endif
// X-macro table of scalar types. `ENTRY` is invoked once per row as
// ENTRY(<type name>, <numeric id>); pass a two-argument macro to generate
// per-type code (for example, enumerators).
// NOTE(review): the ids are not contiguous (7, 10 and 12-14 are absent) and
// presumably mirror a type numbering defined elsewhere -- confirm before
// adding or renumbering rows.
#define C10_METAL_ALL_TYPES_FUNCTOR(ENTRY) \
  ENTRY(Byte, 0) \
  ENTRY(Char, 1) \
  ENTRY(Short, 2) \
  ENTRY(Int, 3) \
  ENTRY(Long, 4) \
  ENTRY(Half, 5) \
  ENTRY(Float, 6) \
  ENTRY(ComplexHalf, 8) \
  ENTRY(ComplexFloat, 9) \
  ENTRY(Bool, 11) \
  ENTRY(BFloat16, 15)
namespace c10 {
namespace metal {

// Constants declared with C10_METAL_CONSTEXPR so the same declaration works in
// both CPU and Metal translation units.
// NOTE(review): presumably the maximum number of tensor dimensions supported
// by the Metal kernels -- confirm against the kernels that use it.
C10_METAL_CONSTEXPR unsigned max_ndim = 16;
// NOTE(review): presumably the number of threads in a Metal SIMD-group --
// confirm against the kernels that use it.
C10_METAL_CONSTEXPR unsigned simdgroup_size = 32;

// Fixed-size array alias: resolves to `::metal::array` when `__METAL__` is
// defined and to `std::array` otherwise, so shared code can spell
// `c10::metal::array<T, N>` in either environment.
#ifdef __METAL__
template <typename T, unsigned N>
using array = ::metal::array<T, N>;
#else
template <typename T, unsigned N>
using array = std::array<T, N>;
#endif

// Scalar type tags: one enumerator per row of C10_METAL_ALL_TYPES_FUNCTOR,
// with the row's numeric id as the enumerator's value.
enum class ScalarType {
// The helper macro uses a project-prefixed name rather than one starting with
// an underscore followed by an uppercase letter, since such identifiers are
// reserved to the implementation. It is undefined again right after use.
#define C10_METAL_DEFINE_ENUM_VAL(name, value) name = value,
  C10_METAL_ALL_TYPES_FUNCTOR(C10_METAL_DEFINE_ENUM_VAL)
#undef C10_METAL_DEFINE_ENUM_VAL
};

} // namespace metal
} // namespace c10