And delete confusing supported types check. Move all pseudo atomic (but eventually consistent) ops to the `c10/metal/atomic.h` header.

Fixes https://github.com/pytorch/pytorch/issues/160034
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160159
Approved by: https://github.com/manuelcandales, https://github.com/dcci, https://github.com/Skylion007

#pragma once

#include <metal_atomic>

namespace c10 {
namespace metal {

// Atomic operations helper
template <typename T>
struct AtomicType {};
template <typename T>
using AtomicType_t = typename AtomicType<T>::type;
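
// Usage sketch (hypothetical caller, not part of this header): a kernel doing
// a scattered accumulation would reinterpret its output buffer to the trait's
// storage type and route the add through the specialization, e.g.
//   device auto* out = reinterpret_cast<device AtomicType_t<float>*>(out_ptr);
//   AtomicType<float>::atomic_add(out, idx, val);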

template <>
struct AtomicType<float> {
  using type = ::metal::atomic<float>;
  static inline void atomic_add(device type* data, long offset, float value) {
    ::metal::atomic_fetch_add_explicit(
        data + offset, value, ::metal::memory_order_relaxed);
  }
};

template <>
struct AtomicType<int> {
  using type = ::metal::atomic<int>;
  static inline void atomic_add(device type* data, long offset, int value) {
    ::metal::atomic_fetch_add_explicit(
        data + offset, value, ::metal::memory_order_relaxed);
  }
};
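
// float and int are the only dtypes above with a native atomic fetch-add; the
// specializations below build narrower or wider types on top of 32-bit
// atomics.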

// As of Metal 3.2, atomic operations are not supported on half-precision
// floats, so they must be simulated using atomic compare-and-exchange over a
// 32-bit atomic type
template <typename T>
static inline void atomic_add_helper(
    device ::metal::atomic<uint>* data,
    long offset,
    T value) {
  constexpr auto elem_per_enum = sizeof(uint) / sizeof(T);
  auto ptr = data + (offset / elem_per_enum);
  auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
  union {
    uint i;
    T t[elem_per_enum];
  } val;
  do {
    val.i = old;
    val.t[offset & (elem_per_enum - 1)] += value;
  } while (!::metal::atomic_compare_exchange_weak_explicit(
      ptr,
      &old,
      val.i,
      ::metal::memory_order_relaxed,
      ::metal::memory_order_relaxed));
}
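
// Worked example for the helper above (illustrative indices): with T = half,
// elem_per_enum == 2, so logical offset 5 maps to 32-bit word 5 / 2 == 2 and
// lane 5 & 1 == 1. The loop snapshots the word, bumps the target lane through
// the union, and retries the compare-exchange until no other thread has raced
// on the same word (atomic_compare_exchange_weak_explicit reloads `old` on
// failure).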

template <>
struct AtomicType<half> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, half value) {
    atomic_add_helper(data, offset, value);
  }
};

template <>
struct AtomicType<short> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, short value) {
    atomic_add_helper(data, offset, value);
  }
};

template <>
struct AtomicType<char> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, char value) {
    atomic_add_helper(data, offset, value);
  }
};

template <>
struct AtomicType<uchar> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, uchar value) {
    atomic_add_helper(data, offset, value);
  }
};

template <>
struct AtomicType<bfloat> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, bfloat value) {
    atomic_add_helper<bfloat>(data, offset, value);
  }
};

// Metal supports atomic_store_explicit for bools, but
// sizeof(::metal::atomic_bool) is 4, so it cannot be used to atomically modify
// unaligned memory; fall back to the compare-and-exchange trick instead. As
// accumulation over booleans is just an OR operation, do nothing if value is
// false
template <>
struct AtomicType<bool> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, bool value) {
    if (!value) {
      return;
    }
    auto ptr = data + (offset >> 2);
    auto old =
        ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
    union {
      uint i;
      bool t[4];
    } val;
    do {
      val.i = old;
      val.t[offset & 3] = true;
    } while (!::metal::atomic_compare_exchange_weak_explicit(
        ptr,
        &old,
        val.i,
        ::metal::memory_order_relaxed,
        ::metal::memory_order_relaxed));
  }
};
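
// The word/lane arithmetic above matches atomic_add_helper with four one-byte
// lanes per 32-bit word: offset >> 2 picks the word and offset & 3 picks the
// lane. Setting the lane unconditionally to true is safe because boolean
// accumulation is an OR, and the false case already returned early.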

// ComplexHalf atomic op
template <>
struct AtomicType<half2> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, half2 value) {
    auto ptr = data + offset;
    auto old =
        ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
    while (!::metal::atomic_compare_exchange_weak_explicit(
        ptr,
        &old,
        as_type<uint>(as_type<half2>(old) + value),
        ::metal::memory_order_relaxed,
        ::metal::memory_order_relaxed))
      ;
  }
};
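
// Both half components fit in a single 32-bit word, so the complex add above
// needs only one compare-exchange loop: bit-cast the packed word to half2, add
// componentwise, and publish the bit-cast result. On failure,
// atomic_compare_exchange_weak_explicit refreshes `old` with the current word
// and the loop retries.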

// There is no atomic 64-bit add in Metal yet, but the template below
// implements an eventually consistent add, i.e. if multiple threads modify the
// same 64-bit value, the result stored at the address will eventually be equal
// to its original value plus the sum of all operands
template <>
struct AtomicType<long> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, long value) {
    const auto value_bits = as_type<ulong>(value);
    const uint low = static_cast<uint>(value_bits);
    uint high = static_cast<uint>(value_bits >> 32);
    auto ptr = data + (offset << 1);
    auto old_low =
        atomic_fetch_add_explicit(ptr, low, ::metal::memory_order_relaxed);
    high += (old_low + low < old_low) ? 1 : 0;
    atomic_fetch_add_explicit(ptr + 1, high, ::metal::memory_order_relaxed);
  }
};
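
// Worked carry example (illustrative values): each long occupies two uint
// words, hence offset << 1. If the low word holds 0xFFFFFFFF and low == 1, the
// fetch-add returns old_low == 0xFFFFFFFF and the stored low word wraps to 0;
// old_low + low < old_low detects the wrap, so one carry is folded into `high`
// before the second fetch-add. Between the two fetch-adds another thread can
// observe a torn value, which is why this is only eventually consistent.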

// ComplexFloat atomic op, which again is not really atomic, but eventually
// consistent
template <>
struct AtomicType<float2> {
  using type = ::metal::atomic<float>;
  static inline void atomic_add(device type* data, long offset, float2 value) {
    auto ptr = data + (offset << 1);
    atomic_fetch_add_explicit(ptr + 0, value.x, ::metal::memory_order_relaxed);
    atomic_fetch_add_explicit(ptr + 1, value.y, ::metal::memory_order_relaxed);
  }
};
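
// The real and imaginary parts are updated by two independent float adds, so a
// concurrent reader may momentarily observe a value where only one component
// has been applied; only the final result (original value plus the sum of all
// operands) is guaranteed.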

} // namespace metal
} // namespace c10