[MPS] Extend atomic operations to all int types (#158179)

That fixes `index_put(..., accumulate=True)` for all dtypes

The int64 operation is not truly atomic, but it is eventually consistent from the `index_put_accumulate` kernel's point of view: i.e., by the end of the operation, the results in global memory are indeed the accumulation of the operands at the given indices.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158179
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #158064, #158178
This commit is contained in:
Nikita Shulga
2025-07-13 17:14:46 -10:00
committed by PyTorch MergeBot
parent 1ea9cde598
commit 9ca080db87
4 changed files with 58 additions and 19 deletions

View File

@ -5,6 +5,29 @@
using namespace metal;
using namespace c10::metal;
namespace c10 {
namespace metal {
// Metal has no 64-bit atomic add yet, so this specialization emulates one with
// two 32-bit atomic adds. The result is only eventually consistent: if
// multiple threads modify the same 64-bit value concurrently, intermediate
// reads may observe a torn value, but once all threads finish, the value
// stored at the address equals its original value plus the sum of all
// operands.
template <>
struct AtomicType<long> {
using type = ::metal::atomic<uint>;
static inline void atomic_add(device type* data, long offset, long value) {
// Reinterpret the signed 64-bit operand as raw bits and split it into
// low and high 32-bit halves.
const auto value_bits = as_type<ulong>(value);
const uint low = static_cast<uint>(value_bits);
uint high = static_cast<uint>(value_bits >> 32);
// Each 64-bit element occupies two consecutive 32-bit atomic words,
// hence the offset is doubled.
auto ptr = data + (offset << 1);
auto old_low = atomic_fetch_add_explicit(ptr, low, memory_order_relaxed);
// Unsigned wrap-around check: if the low-word add carried, propagate the
// carry into the high word before adding it.
high += (old_low + low < old_low) ? 1 : 0;
atomic_fetch_add_explicit(ptr + 1, high, memory_order_relaxed);
}
};
} // namespace metal
} // namespace c10
// Wrapper for a single index buffer consumed by the index kernels.
// NOTE(review): presumably one IndexAB per indexed dimension — confirm
// against the host-side kernel setup code.
struct IndexAB {
constant int64_t* indexArray;
};
@ -211,7 +234,11 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial);
// Register index_put accumulate (atomic add) kernels for every dtype that has
// an AtomicType specialization — floating types plus all integral types.
REGISTER_INDEX_OP(put_accumulate, float, float);
REGISTER_INDEX_OP(put_accumulate, half, half);
REGISTER_INDEX_OP(put_accumulate, long, long);
REGISTER_INDEX_OP(put_accumulate, int, int);
REGISTER_INDEX_OP(put_accumulate, short, short);
REGISTER_INDEX_OP(put_accumulate, char, char);
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
REGISTER_INDEX_OP(put_accumulate, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);

View File

@ -121,8 +121,8 @@ static void validateInputData(const TensorIteratorBase& iter,
const auto scalar_type = inputTensor.scalar_type();
if (accumulate) {
// No atomic support for the rest of dtypes
TORCH_CHECK(supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool);
// No atomic support for the complex dtypes
TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type));
} else {
TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) ||
scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf,

View File

@ -35,15 +35,16 @@ static inline void atomic_add_helper(
device ::metal::atomic<uint>* data,
long offset,
T value) {
auto ptr = data + (offset >> 1);
constexpr auto elem_per_enum = sizeof(uint) / sizeof(T);
auto ptr = data + (offset / elem_per_enum);
auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
union {
uint i;
T t[2];
T t[elem_per_enum];
} val;
do {
val.i = old;
val.t[offset & 1] += value;
val.t[offset & (elem_per_enum - 1)] += value;
} while (!::metal::atomic_compare_exchange_weak_explicit(
ptr,
&old,
@ -56,7 +57,31 @@ template <>
struct AtomicType<half> {
using type = ::metal::atomic<uint>;
static inline void atomic_add(device type* data, long offset, half value) {
// NOTE(review): the first call below is the pre-diff (removed) line with an
// explicit template argument; the second is its replacement relying on
// template argument deduction — only one is present in the real file.
atomic_add_helper<half>(data, offset, value);
atomic_add_helper(data, offset, value);
}
};
template <>
struct AtomicType<short> {
using type = ::metal::atomic<uint>;
static inline void atomic_add(device type* data, long offset, short value) {
atomic_add_helper(data, offset, value);
}
};
template <>
struct AtomicType<char> {
using type = ::metal::atomic<uint>;
static inline void atomic_add(device type* data, long offset, char value) {
atomic_add_helper(data, offset, value);
}
};
// 8-bit unsigned integer atomic add, delegated to the shared 32-bit
// compare-exchange helper that operates on the containing uint word.
template <>
struct AtomicType<uchar> {
  using type = ::metal::atomic<uint>;
  // Fix: the value parameter was declared `char`, which made the helper
  // deduce T = char for an unsigned-byte element. Modular 8-bit arithmetic
  // happens to produce the same bits, but the signature is now consistent
  // with the stored dtype and with the sibling specializations. Callers
  // passing char still convert implicitly, so this is backward-compatible.
  static inline void atomic_add(device type* data, long offset, uchar value) {
    atomic_add_helper(data, offset, value);
  }
};

View File

@ -541,13 +541,6 @@ if torch.backends.mps.is_available():
# round not working properly for float16 and bfloat16
"round": [torch.float16, torch.bfloat16],
"rounddecimals_0": [torch.bfloat16],
# atomic operations not supported
"_unsafe_masked_index_put_accumulate": [
torch.int8,
torch.uint8,
torch.int16,
torch.int64,
],
}
if MACOS_VERSION < 14.0:
@ -642,12 +635,6 @@ if torch.backends.mps.is_available():
torch.float16,
torch.bfloat16,
],
"index_put": [
torch.uint8,
torch.int8,
torch.int16,
torch.int64,
],
# zero to negative integer powers are undefined
"__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
"resize_": [torch.float16, torch.float32, torch.bfloat16],