[MPS] Extend atomic operations to all int types (#158179)
This fixes `index_put(..., accumulate=True)` for all dtypes. The int64 operation is not truly atomic, but it is eventually consistent from the `index_put_accumulate` kernel's point of view: by the end of the operation, the results in global memory are indeed the accumulation of the operands at the given indices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158179
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #158064, #158178
Commit: 9ca080db87
Parent: 1ea9cde598
Committed by: PyTorch MergeBot
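As a usage-level illustration (not part of the change itself), the behavior this commit fixes can be exercised from Python roughly as follows. This is a minimal sketch using only the public `torch` API; the CPU fallback is just so the snippet runs on machines without MPS.

import torch

# Minimal sketch: accumulate into an int64 tensor via index_put_ on MPS.
# Per the validation change below, the accumulate path previously only
# accepted floating types, int32 and bool.
device = "mps" if torch.backends.mps.is_available() else "cpu"

target = torch.zeros(4, dtype=torch.int64, device=device)
indices = torch.tensor([0, 1, 1, 3], device=device)
values = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device=device)

# With accumulate=True, repeated indices must sum their operands,
# so the expected result is [1, 5, 0, 4].
target.index_put_((indices,), values, accumulate=True)
print(target.cpu())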
@@ -5,6 +5,29 @@
 using namespace metal;
 using namespace c10::metal;
 
+namespace c10 {
+namespace metal {
+// There is no atomic 64-bit add in Metal yet, but this implements a consistent
+// add, i.e. if multiple threads modify the same 64-bit value, the result stored
+// at the address will eventually be equal to its original value plus the sum of
+// all operands
+template <>
+struct AtomicType<long> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, long value) {
+    const auto value_bits = as_type<ulong>(value);
+    const uint low = static_cast<uint>(value_bits);
+    uint high = static_cast<uint>(value_bits >> 32);
+    auto ptr = data + (offset << 1);
+    auto old_low = atomic_fetch_add_explicit(ptr, low, memory_order_relaxed);
+    high += (old_low + low < old_low) ? 1 : 0;
+    atomic_fetch_add_explicit(ptr + 1, high, memory_order_relaxed);
+  }
+};
+
+} // namespace metal
+} // namespace c10
+
 struct IndexAB {
   constant int64_t* indexArray;
 };
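A plain-Python model of the scheme above may help illustrate the "eventually consistent" claim from the commit message: each 64-bit addend is split into 32-bit halves, the low half is added first, and a carry is folded into the high half whenever the low add wraps. Function and variable names here are illustrative only; this is not the Metal code.

MASK32 = (1 << 32) - 1

def add64_as_two_32bit_adds(mem, idx, value):
    # mem is a flat list of 32-bit words; the 64-bit slot at index `idx`
    # occupies words 2*idx (low half) and 2*idx + 1 (high half).
    low = value & MASK32
    high = (value >> 32) & MASK32
    lo_addr, hi_addr = 2 * idx, 2 * idx + 1

    # Fetch-add on the low word, detecting 32-bit wrap-around (the carry).
    old_low = mem[lo_addr]
    mem[lo_addr] = (old_low + low) & MASK32
    carry = 1 if old_low + low > MASK32 else 0

    # Fold the carry into the high word's addend.
    mem[hi_addr] = (mem[hi_addr] + high + carry) & MASK32

mem = [0, 0]  # one 64-bit slot stored as two 32-bit words
add64_as_two_32bit_adds(mem, 0, (1 << 33) + 5)
add64_as_two_32bit_adds(mem, 0, MASK32)  # forces a low-word carry
assert mem[0] | (mem[1] << 32) == (1 << 33) + 5 + MASK32

Because the two 32-bit adds are not issued as a single atomic transaction, a concurrent reader may briefly observe a state where the carry has not yet landed in the high word; once all adds retire, the sum is correct, which is exactly the "not really atomic, but eventually consistent" caveat.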
@@ -211,7 +234,11 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial);
 
 REGISTER_INDEX_OP(put_accumulate, float, float);
 REGISTER_INDEX_OP(put_accumulate, half, half);
+REGISTER_INDEX_OP(put_accumulate, long, long);
 REGISTER_INDEX_OP(put_accumulate, int, int);
+REGISTER_INDEX_OP(put_accumulate, short, short);
+REGISTER_INDEX_OP(put_accumulate, char, char);
+REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
 REGISTER_INDEX_OP(put_accumulate, bool, bool);
 #if __METAL_VERSION__ >= 310
 REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
@@ -121,8 +121,8 @@ static void validateInputData(const TensorIteratorBase& iter,
   const auto scalar_type = inputTensor.scalar_type();
 
   if (accumulate) {
-    // No atomic support for the rest of dtypes
-    TORCH_CHECK(supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool);
+    // No atomic support for the complex dtypes
+    TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type));
   } else {
     TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) ||
                     scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf,
@@ -35,15 +35,16 @@ static inline void atomic_add_helper(
     device ::metal::atomic<uint>* data,
     long offset,
     T value) {
-  auto ptr = data + (offset >> 1);
+  constexpr auto elem_per_enum = sizeof(uint) / sizeof(T);
+  auto ptr = data + (offset / elem_per_enum);
   auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
   union {
     uint i;
-    T t[2];
+    T t[elem_per_enum];
   } val;
   do {
     val.i = old;
-    val.t[offset & 1] += value;
+    val.t[offset & (elem_per_enum - 1)] += value;
   } while (!::metal::atomic_compare_exchange_weak_explicit(
       ptr,
       &old,
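For readers unfamiliar with the packing in `atomic_add_helper`: `sizeof(uint) / sizeof(T)` elements share one 32-bit word, `offset / elem_per_enum` selects the word, and `offset & (elem_per_enum - 1)` selects the lane that the compare-and-swap loop updates. The sketch below models just that addressing arithmetic in plain Python with `struct`; there are no real atomics here and the names are illustrative only.

import struct

def subword_add(words, offset, value, elem_size):
    # words: packed 32-bit words, each holding 4 // elem_size elements.
    elems_per_word = 4 // elem_size          # sizeof(uint) / sizeof(T)
    word_index = offset // elems_per_word    # which 32-bit word to update
    lane = offset & (elems_per_word - 1)     # element slot inside that word

    fmt = {1: "<4b", 2: "<2h"}[elem_size]    # char-sized or short-sized lanes
    lanes = list(struct.unpack(fmt, struct.pack("<I", words[word_index])))
    lanes[lane] += value                     # the += done inside the CAS loop
    words[word_index] = struct.unpack("<I", struct.pack(fmt, *lanes))[0]

words = [0]                                  # one packed 32-bit word
subword_add(words, 0, 5, elem_size=2)        # short lane 0 += 5
subword_add(words, 1, -2, elem_size=2)       # short lane 1 += -2
assert struct.unpack("<2h", struct.pack("<I", words[0])) == (5, -2)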
@@ -56,7 +57,31 @@ template <>
 struct AtomicType<half> {
   using type = ::metal::atomic<uint>;
   static inline void atomic_add(device type* data, long offset, half value) {
-    atomic_add_helper<half>(data, offset, value);
+    atomic_add_helper(data, offset, value);
+  }
+};
+
+template <>
+struct AtomicType<short> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, short value) {
+    atomic_add_helper(data, offset, value);
+  }
+};
+
+template <>
+struct AtomicType<char> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, char value) {
+    atomic_add_helper(data, offset, value);
+  }
+};
+
+template <>
+struct AtomicType<uchar> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, char value) {
+    atomic_add_helper(data, offset, value);
   }
 };
 
@@ -541,13 +541,6 @@ if torch.backends.mps.is_available():
        # round not working properly for float16 and bfloat16
        "round": [torch.float16, torch.bfloat16],
        "rounddecimals_0": [torch.bfloat16],
-        # atomic operations not supported
-        "_unsafe_masked_index_put_accumulate": [
-            torch.int8,
-            torch.uint8,
-            torch.int16,
-            torch.int64,
-        ],
     }
 
     if MACOS_VERSION < 14.0:
@@ -642,12 +635,6 @@ if torch.backends.mps.is_available():
            torch.float16,
            torch.bfloat16,
        ],
-        "index_put": [
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int64,
-        ],
        # zero to negative integer powers are undefined
        "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
        "resize_": [torch.float16, torch.float32, torch.bfloat16],