[MPS] Extend atomic operations to all int types (#158179)
This fixes `index_put(..., accumulate=True)` for all dtypes. The int64 operation is not truly atomic, but it is eventually consistent from the `index_put_accumulate` kernel's point of view: by the end of the operation, the results in global memory are indeed the accumulation of the operands at the given indices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158179
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #158064, #158178
committed by PyTorch MergeBot
parent 1ea9cde598
commit 9ca080db87
@@ -5,6 +5,29 @@
 using namespace metal;
 using namespace c10::metal;
 
+namespace c10 {
+namespace metal {
+// There is no atomic 64-bit add in Metal yet, but this implements a consistent
+// add, i.e. if multiple threads modify the same 64-bit value, the result stored
+// at the address will eventually be equal to its original value plus the sum of
+// all operands
+template <>
+struct AtomicType<long> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, long value) {
+    const auto value_bits = as_type<ulong>(value);
+    const uint low = static_cast<uint>(value_bits);
+    uint high = static_cast<uint>(value_bits >> 32);
+    auto ptr = data + (offset << 1);
+    auto old_low = atomic_fetch_add_explicit(ptr, low, memory_order_relaxed);
+    high += (old_low + low < old_low) ? 1 : 0;
+    atomic_fetch_add_explicit(ptr + 1, high, memory_order_relaxed);
+  }
+};
+
+} // namespace metal
+} // namespace c10
+
 struct IndexAB {
   constant int64_t* indexArray;
 };
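For readers less familiar with the trick above, here is a minimal host-side sketch of the same split 64-bit add, written against std::atomic<uint32_t> rather than Metal's device atomics; the function name atomic_add_i64 and the test harness are illustrative, not part of the PR. The 64-bit value is kept as two 32-bit halves: the low half is added first with fetch_add, and the returned old value tells the thread whether its own addition wrapped, in which case one carry goes into the high half. Intermediate states are not atomic, but once all threads have finished the stored sum is exact, which is the "eventually consistent" behavior the commit message describes.

// Host-side illustration (not the Metal kernel) of the split 64-bit atomic add.
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

void atomic_add_i64(std::atomic<uint32_t>* halves, int64_t value) {
  const auto bits = static_cast<uint64_t>(value);
  const uint32_t low = static_cast<uint32_t>(bits);
  uint32_t high = static_cast<uint32_t>(bits >> 32);
  // fetch_add returns the previous low half; if old_low + low wrapped around,
  // propagate one carry into the high half.
  const uint32_t old_low = halves[0].fetch_add(low, std::memory_order_relaxed);
  high += (old_low + low < old_low) ? 1 : 0;
  halves[1].fetch_add(high, std::memory_order_relaxed);
}

int main() {
  std::atomic<uint32_t> halves[2];
  halves[0].store(0);
  halves[1].store(0);
  const int64_t addend = 0x80000000LL;  // low half wraps often: exercises the carry path
  std::vector<std::thread> threads;
  for (int t = 0; t < 8; ++t) {
    threads.emplace_back([&] {
      for (int i = 0; i < 100000; ++i) {
        atomic_add_i64(halves, addend);
      }
    });
  }
  for (auto& th : threads) {
    th.join();
  }
  const uint64_t result =
      (static_cast<uint64_t>(halves[1].load()) << 32) | halves[0].load();
  std::printf("sum = %llu (expected %llu)\n",
              static_cast<unsigned long long>(result),
              static_cast<unsigned long long>(8ull * 100000ull * 0x80000000ull));
}

The carry count stays exact because each addition contributes less than 2^32 to the low half, so it crosses at most one 2^32 boundary, and that crossing happens exactly when old_low + low wraps.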
@@ -211,7 +234,11 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial);
 
 REGISTER_INDEX_OP(put_accumulate, float, float);
 REGISTER_INDEX_OP(put_accumulate, half, half);
+REGISTER_INDEX_OP(put_accumulate, long, long);
 REGISTER_INDEX_OP(put_accumulate, int, int);
+REGISTER_INDEX_OP(put_accumulate, short, short);
+REGISTER_INDEX_OP(put_accumulate, char, char);
+REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
 REGISTER_INDEX_OP(put_accumulate, bool, bool);
 #if __METAL_VERSION__ >= 310
 REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
@@ -121,8 +121,8 @@ static void validateInputData(const TensorIteratorBase& iter,
   const auto scalar_type = inputTensor.scalar_type();
 
   if (accumulate) {
-    // No atomic support for the rest of dtypes
-    TORCH_CHECK(supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool);
+    // No atomic support for the complex dtypes
+    TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type));
   } else {
     TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) ||
                     scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf,
@@ -35,15 +35,16 @@ static inline void atomic_add_helper(
     device ::metal::atomic<uint>* data,
     long offset,
     T value) {
-  auto ptr = data + (offset >> 1);
+  constexpr auto elem_per_enum = sizeof(uint) / sizeof(T);
+  auto ptr = data + (offset / elem_per_enum);
   auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
   union {
     uint i;
-    T t[2];
+    T t[elem_per_enum];
   } val;
   do {
     val.i = old;
-    val.t[offset & 1] += value;
+    val.t[offset & (elem_per_enum - 1)] += value;
   } while (!::metal::atomic_compare_exchange_weak_explicit(
       ptr,
       &old,
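As a companion to the generalized helper above, here is a minimal host-side sketch of the same CAS-based sub-word add, using std::atomic<uint32_t> in place of ::metal::atomic<uint> and memcpy in place of the kernel's union (type punning through a union is conventional in Metal shaders but undefined behavior in standard C++). The names atomic_add_subword and elems_per_word are illustrative only.

// Host-side sketch (not the Metal code) of a relaxed CAS loop that adds a
// sub-word value inside a 32-bit atomic word.
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <cstring>

template <typename T>
void atomic_add_subword(std::atomic<uint32_t>* data, long offset, T value) {
  constexpr auto elems_per_word = sizeof(uint32_t) / sizeof(T);
  auto* word = data + offset / elems_per_word;  // 32-bit word holding the element
  uint32_t old_bits = word->load(std::memory_order_relaxed);
  uint32_t new_bits;
  do {
    T lanes[elems_per_word];
    std::memcpy(lanes, &old_bits, sizeof(old_bits));  // split the word into T-sized lanes
    lanes[offset & (elems_per_word - 1)] += value;    // bump only the addressed lane
    std::memcpy(&new_bits, lanes, sizeof(new_bits));
    // On failure, compare_exchange_weak refreshes old_bits and the loop retries.
  } while (!word->compare_exchange_weak(
      old_bits, new_bits, std::memory_order_relaxed, std::memory_order_relaxed));
}

int main() {
  // Four int16_t elements packed into two 32-bit words.
  std::atomic<uint32_t> storage[2];
  storage[0].store(0);
  storage[1].store(0);
  for (long i = 0; i < 4; ++i) {
    atomic_add_subword<int16_t>(storage, i, static_cast<int16_t>(i + 1));
  }
  int16_t out[4];
  const uint32_t words[2] = {storage[0].load(), storage[1].load()};
  std::memcpy(out, words, sizeof(out));
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // prints: 1 2 3 4
}

The loop is per-word lock-free: compare_exchange_weak publishes the rewritten word only if no other thread has touched it since the load, otherwise old_bits is refreshed and the lane update is recomputed.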
@@ -56,7 +57,31 @@ template <>
 struct AtomicType<half> {
   using type = ::metal::atomic<uint>;
   static inline void atomic_add(device type* data, long offset, half value) {
-    atomic_add_helper<half>(data, offset, value);
+    atomic_add_helper(data, offset, value);
   }
 };
 
+template <>
+struct AtomicType<short> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, short value) {
+    atomic_add_helper(data, offset, value);
+  }
+};
+
+template <>
+struct AtomicType<char> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, char value) {
+    atomic_add_helper(data, offset, value);
+  }
+};
+
+template <>
+struct AtomicType<uchar> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, char value) {
+    atomic_add_helper(data, offset, value);
+  }
+};
+
@@ -541,13 +541,6 @@ if torch.backends.mps.is_available():
        # round not working properly for float16 and bfloat16
        "round": [torch.float16, torch.bfloat16],
        "rounddecimals_0": [torch.bfloat16],
-        # atomic operations not supported
-        "_unsafe_masked_index_put_accumulate": [
-            torch.int8,
-            torch.uint8,
-            torch.int16,
-            torch.int64,
-        ],
    }

    if MACOS_VERSION < 14.0:
@@ -642,12 +635,6 @@ if torch.backends.mps.is_available():
            torch.float16,
            torch.bfloat16,
        ],
-        "index_put": [
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int64,
-        ],
        # zero to negative integer powers are undefined
        "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
        "resize_": [torch.float16, torch.float32, torch.bfloat16],