[MPS] Extend index_put to half precision floats (#151869)

By reusing `c10/metal/atomic.h`.
This also fixes `GPUTests.test_index_put_fallback[12]_mps`, which is unrolled by Inductor, so no dedicated atomic_add support is needed there.

TODOs:
 - Get rid of the indexing kernel and compute the offsets directly when the kernel runs
 - Simulate atomic_add for int64 types as a series of int32 atomic-add-and-fetch operations (see the sketch below)
 - Set up tolerances correctly so the float16/bfloat16 tests pass (the CPU reference always accumulates sequentially)
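
Not part of this PR, but for the int64 TODO above, a minimal sketch of how the emulation could look, assuming the 64-bit accumulator is viewed as a pair of low/high 32-bit words (the helper name and the carry handling are illustrative only):

```metal
#include <metal_atomic>
using namespace metal;

// Hypothetical sketch: emulate a 64-bit atomic add with two relaxed 32-bit
// atomic-add-and-fetch operations plus carry propagation. Concurrent readers
// may observe a torn value, which is acceptable for accumulate-then-read-once
// patterns such as index_put with accumulate=true.
inline void atomic_add_int64_sketch(device atomic_uint* lo,
                                    device atomic_uint* hi,
                                    long value) {
  const uint lo_bits = static_cast<uint>(static_cast<ulong>(value));
  const uint hi_bits = static_cast<uint>(static_cast<ulong>(value) >> 32);
  // Add the low word first; unsigned wrap-around signals a carry.
  const uint old_lo =
      atomic_fetch_add_explicit(lo, lo_bits, memory_order_relaxed);
  const uint carry = (old_lo + lo_bits) < old_lo ? 1u : 0u;
  // Fold the carry into the high word together with the high 32 bits.
  atomic_fetch_add_explicit(hi, hi_bits + carry, memory_order_relaxed);
}
```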

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151869
Approved by: https://github.com/Skylion007, https://github.com/dcci
Nikita Shulga
2025-04-22 14:11:26 -07:00
committed by PyTorch MergeBot
parent b8f4dc5a9f
commit 3aecf2dc52
7 changed files with 88 additions and 130 deletions

View File

@ -1,5 +1,5 @@
#include <c10/metal/atomic.h>
#include <c10/metal/indexing.h>
#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;
@ -201,53 +201,8 @@ kernel_index_offsets<packed_uint3, ulong3>(
constant uint& num_dimensions [[buffer(3)]],
uint thread_index [[thread_position_in_grid]]);
template <typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
constant OffsetsT* offsets [[buffer(3)]],
constant void* inputData [[buffer(4)]],
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t* index_sizes = (constant int64_t*)indexSizes;
constant int64_t* index_strides = (constant int64_t*)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device T* out =
(device T*)((device char*)outputData + offsets[thread_index].x + offset);
constant E* in =
(constant E*)((constant char*)inputData + offsets[thread_index].y);
atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}
template <typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(
device void* addr,
T value) {
device atomic_uint* uintAddr = (device atomic_uint*)addr;
uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
T updated = as_type<T>(expected) + value;
while (!atomic_compare_exchange_weak_explicit(
uintAddr,
&expected,
as_type<uint>(updated),
memory_order_relaxed,
memory_order_relaxed)) {
updated = as_type<T>(expected) + value;
}
}
template <typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
kernel void index_put_accumulate(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
@ -258,7 +213,7 @@ kernel void atomic_index_put_accumulate(
uint thread_index [[thread_position_in_grid]]) {
constant int64_t* index_sizes = (constant int64_t*)indexSizes;
constant int64_t* index_strides = (constant int64_t*)indexStrides;
int64_t offset = 0;
int64_t offset = offsets[thread_index].x;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
@ -267,56 +222,38 @@ kernel void atomic_index_put_accumulate(
}
offset += index * index_strides[i];
}
device void* out = (device void*)((device char*)outputData +
offsets[thread_index].x + offset);
constant T* in =
(constant T*)((constant char*)inputData + offsets[thread_index].y);
atomic_fetch_add_relaxed<T>(out, *in);
const auto in =
*(constant T*)((constant char*)inputData + offsets[thread_index].y);
AtomicType<T>::atomic_add(
reinterpret_cast<device AtomicType_t<T>*>(outputData),
offset / sizeof(T),
in);
}
template [[host_name("index_put_accumulate_32bit_float_idx32")]] kernel void
atomic_index_put_accumulate<float, uint3>(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
constant uint3* offsets [[buffer(3)]],
constant void* inputData [[buffer(4)]],
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
#define REGISTER_INDEX_PUT_ACCUMULATE(DTS, DTYPE, IDXS, IDX_DTYPE) \
template [[host_name("index_put_accumulate_" #DTS "_" #DTYPE \
"_" #IDXS)]] kernel void \
index_put_accumulate<DTYPE, IDX_DTYPE>( \
constant IndexAB * indexAB [[buffer(0)]], \
constant void* indexSizes [[buffer(1)]], \
constant void* indexStrides [[buffer(2)]], \
constant IDX_DTYPE* offsets [[buffer(3)]], \
constant void* inputData [[buffer(4)]], \
device void* outputData [[buffer(5)]], \
constant uint32_t& num_indices [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]])
template [[host_name("index_put_accumulate_32bit_float_idx64")]] kernel void
atomic_index_put_accumulate<float, ulong3>(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
constant ulong3* offsets [[buffer(3)]],
constant void* inputData [[buffer(4)]],
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
REGISTER_INDEX_PUT_ACCUMULATE(32bit, float, idx32, uint3);
REGISTER_INDEX_PUT_ACCUMULATE(32bit, float, idx64, ulong3);
REGISTER_INDEX_PUT_ACCUMULATE(32bit, int, idx32, uint3);
REGISTER_INDEX_PUT_ACCUMULATE(32bit, int, idx64, ulong3);
REGISTER_INDEX_PUT_ACCUMULATE(16bit, half, idx32, uint3);
REGISTER_INDEX_PUT_ACCUMULATE(16bit, half, idx64, ulong3);
template [[host_name("index_put_accumulate_32bit_int_idx32")]] kernel void
index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
constant uint3* offsets [[buffer(3)]],
constant void* inputData [[buffer(4)]],
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
template [[host_name("index_put_accumulate_32bit_int_idx64")]] kernel void
index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
constant ulong3* offsets [[buffer(3)]],
constant void* inputData [[buffer(4)]],
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
#if __METAL_VERSION__ >= 310
REGISTER_INDEX_PUT_ACCUMULATE(16bit, bfloat, idx32, uint3);
REGISTER_INDEX_PUT_ACCUMULATE(16bit, bfloat, idx64, ulong3);
#endif
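For reference, each `REGISTER_INDEX_PUT_ACCUMULATE` line above expands to the same explicit-instantiation pattern the file previously spelled out by hand; e.g. the `(16bit, half, idx32, uint3)` registration expands to roughly:

```metal
template [[host_name("index_put_accumulate_16bit_half_idx32")]] kernel void
index_put_accumulate<half, uint3>(
    constant IndexAB* indexAB [[buffer(0)]],
    constant void* indexSizes [[buffer(1)]],
    constant void* indexStrides [[buffer(2)]],
    constant uint3* offsets [[buffer(3)]],
    constant void* inputData [[buffer(4)]],
    device void* outputData [[buffer(5)]],
    constant uint32_t& num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]);
```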
template <typename T>
kernel void masked_fill_scalar_dense(

View File

@ -105,7 +105,7 @@ void upsample_increment_value_bounded(
data,
n * strides.x + c * strides.y + access_y * strides.z +
access_x * strides.w,
value);
static_cast<scalar_t>(value));
}
template <typename T>

View File

@ -108,15 +108,12 @@ static std::string getIndexFunctionName(ScalarType scalar_type,
: (accumulate && (scalar_type != kBool)) ? "index_put_accumulate_"
: (serial ? "index_put_serial_" : "index_put_");
indexFunction += getBitSizeString(scalar_type);
indexFunction.append(getBitSizeString(scalar_type));
if (accumulate) {
TORCH_CHECK(scalar_type == ScalarType::Float || scalar_type == ScalarType::Int,
"Unsupported data type for accumulate case: ",
getMPSTypeString(scalar_type));
string dtypeString = (scalar_type == ScalarType::Float) ? "_float" : "_int";
indexFunction += dtypeString;
indexFunction.append(1, '_');
indexFunction.append(scalarToMetalTypeString(scalar_type));
}
indexFunction += use_64bit_indexing ? "_idx64" : "_idx32";
indexFunction.append(use_64bit_indexing ? "_idx64" : "_idx32");
return indexFunction;
}
@ -206,8 +203,7 @@ static void validateInputData(const TensorIteratorBase& iter,
if (accumulate) {
// No atomic support for the rest of dtypes
TORCH_CHECK(scalar_type == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Int ||
scalar_type == ScalarType::Bool);
TORCH_CHECK(supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool);
} else {
TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) ||
scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf,

View File

@ -32,18 +32,18 @@ struct AtomicType<int> {
// atomic type
template <typename T>
static inline void atomic_add_helper(
device ::metal::atomic<int>* data,
device ::metal::atomic<uint>* data,
long offset,
float value) {
T value) {
auto ptr = data + (offset >> 1);
auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
union {
int i;
uint i;
T t[2];
} val;
do {
val.i = old;
val.t[offset & 1] += static_cast<T>(value);
val.t[offset & 1] += value;
} while (!::metal::atomic_compare_exchange_weak_explicit(
ptr,
&old,
@ -54,8 +54,8 @@ static inline void atomic_add_helper(
template <>
struct AtomicType<half> {
using type = ::metal::atomic<int>;
static inline void atomic_add(device type* data, long offset, float value) {
using type = ::metal::atomic<uint>;
static inline void atomic_add(device type* data, long offset, half value) {
atomic_add_helper<half>(data, offset, value);
}
};
@ -63,8 +63,8 @@ struct AtomicType<half> {
#if __METAL_VERSION__ >= 310
template <>
struct AtomicType<bfloat> {
using type = ::metal::atomic<int>;
static inline void atomic_add(device type* data, long offset, float value) {
using type = ::metal::atomic<uint>;
static inline void atomic_add(device type* data, long offset, bfloat value) {
atomic_add_helper<bfloat>(data, offset, value);
}
};
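The 16-bit specializations above go through `atomic_add_helper`, which updates a half/bfloat element via the enclosing 32-bit atomic word: `offset >> 1` selects the word, `offset & 1` selects the 16-bit lane, and a weak compare-exchange loop retries until the packed update sticks (the kernel obtains that element offset by dividing its byte offset by `sizeof(T)`). A standalone sketch of the same technique, with an illustrative name rather than the library API:

```metal
#include <metal_atomic>
using namespace metal;

// Hypothetical sketch: add `value` to the half element at `elem_idx` of a
// buffer that is reinterpreted as an array of 32-bit atomic words.
inline void atomic_add_half_sketch(device atomic_uint* words,
                                   uint elem_idx,
                                   half value) {
  device atomic_uint* word = words + (elem_idx >> 1); // enclosing 32-bit word
  uint expected = atomic_load_explicit(word, memory_order_relaxed);
  uint desired;
  do {
    union {
      uint i;
      half h[2];
    } packed;
    packed.i = expected;
    packed.h[elem_idx & 1] += value; // update only the addressed 16-bit lane
    desired = packed.i;
  } while (!atomic_compare_exchange_weak_explicit(
      word,
      &expected,
      desired,
      memory_order_relaxed,
      memory_order_relaxed));
}
```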

View File

@ -8015,7 +8015,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
),
)
@xfail_if_mps_unimplemented # RuntimeError: Expected scalar_type == ScalarType::Float
def test_index_put_fallback1(self):
def fn(a, b, c, d):
a = a.clone()
@ -8042,7 +8041,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
),
)
@xfail_if_mps_unimplemented # RuntimeError: Expected scalar_type == ScalarType::Float
def test_index_put_fallback2(self):
def fn(a, b, c, d, e):
a = a.clone()

View File

@ -79,7 +79,6 @@ def xfailIf(condition):
return func
return wrapper
# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
print('MPS not available, skipping tests', file=sys.stderr)
@ -11782,6 +11781,9 @@ class TestConsistency(TestCaseMPS):
'norm', 'masked.normalize',
'arange', 'linspace',
'special.xlog1py',
# CPU accumulates sequentially, but the GPU does it in parallel
'_unsafe_masked_index_put_accumulate',
}
FP32_LOW_PRECISION_LIST = {
@ -11799,7 +11801,7 @@ class TestConsistency(TestCaseMPS):
return (1e-4, 3e-5)
if op.name in self.FP16_LOW_PRECISION_LIST and dtype in [torch.float16, torch.bfloat16]:
return (1e-2, 1e-2) if dtype == torch.float16 else (5e-2, 5e-2)
return (2e-2, 1e-2) if dtype == torch.float16 else (5e-2, 5e-2)
if op.name in self.BF16_LOW_PRECISION_LIST and dtype == torch.bfloat16:
return (5e-2, 5e-2)
@ -11862,6 +11864,12 @@ class TestConsistency(TestCaseMPS):
if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
mps_args[1] = cpu_args[1]
# Order of ops in index_put is not guaranteed, which can lead to large errors if inputs are
# not normalized
if op.name == "_unsafe_masked_index_put_accumulate" and dtype in [torch.bfloat16, torch.float16]:
mps_args[3] = F.normalize(mps_args[3])
cpu_args[3] = F.normalize(cpu_args[3])
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
cpu_out = op(*cpu_args, **cpu_kwargs)
@ -11880,6 +11888,7 @@ class TestConsistency(TestCaseMPS):
if op.name in ["_upsample_bilinear2d_aa", "_upsample_bicubic2d_aa"] and cpu_kwargs.get("scale_factors") == [1.7, 0.9]:
# Similar to the above, float vs double precision results in a slight error
atol, rtol = 2e-5, 2e-6
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
@ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
@ -11900,7 +11909,6 @@ class TestConsistency(TestCaseMPS):
#
# Forward check
#
forward_failed = False
mps_sample = transform_opinfo_sample_to_mps(cpu_sample)
cpu_args = [cpu_sample.input] + list(cpu_sample.args)
@ -11912,6 +11920,12 @@ class TestConsistency(TestCaseMPS):
if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
mps_args[1] = cpu_args[1]
# Order of ops in index_put is not guaranteed, which can lead to large errors if inputs are
# not normalized
if op.name == "_unsafe_masked_index_put_accumulate" and dtype in [torch.bfloat16, torch.float16]:
mps_args[3] = F.normalize(mps_args[3])
cpu_args[3] = F.normalize(cpu_args[3])
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
cpu_out = op(*cpu_args, **cpu_kwargs)
@ -11930,11 +11944,6 @@ class TestConsistency(TestCaseMPS):
#
# Backward check
#
if forward_failed:
# We would've failed immediately anyway, but this error is clearer
# We error instead of continuing so that all_backward_pass would not be True
raise RuntimeError("Forward pass already failed")
cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)
@ -11967,6 +11976,10 @@ class TestConsistency(TestCaseMPS):
):
atol = 1e-5
rtol = 1.5e-3
# Order of ops in unsafe_masked_index backward is not guaranteed
# which leads to larger errors
if op.name == "_unsafe_masked_index" and dtype == torch.float16:
atol, rtol = 3e-3, 3e-3
self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
def test_fmax_mixed_dtypes(self, device):
@ -12206,6 +12219,28 @@ class TestMetalLibrary(TestCaseMPS):
self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5,
f"results are {y}, but all elements should have been {x_sum.item()}")
@parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16])
def test_atomic_add(self, dtype):
if dtype == torch.bfloat16 and MACOS_VERSION < 14.0:
raise unittest.SkipTest("bfloat requires MacOS-14+")
from torch._inductor.codegen.mps import DTYPE_TO_METAL
mdtype = DTYPE_TO_METAL[dtype]
lib = torch.mps.compile_shader(f"""
#include <c10/metal/atomic.h>
using namespace c10::metal;
kernel void atomic_add(device AtomicType<{mdtype}>::type* out,
constant {mdtype}* inc,
uint idx [[thread_position_in_grid]]) {{
AtomicType<{mdtype}>::atomic_add(out, idx & 1 ? 3 : 4, inc[idx]);
}}
""")
x = torch.arange(16, device="mps", dtype=dtype)
y = torch.arange(16, device="mps", dtype=dtype)
lib.atomic_add(x, y)
self.assertEqual(x[3], 67)
self.assertEqual(x[4], 60)
def test_argument_buffers(self):
lib = torch.mps.compile_shader("""
constant constexpr auto nbuffers = 64;

View File

@ -597,10 +597,8 @@ if torch.backends.mps.is_available():
torch.bool,
torch.int8,
torch.uint8,
torch.float16,
torch.int16,
torch.int64,
torch.bfloat16,
],
}
@ -703,8 +701,6 @@ if torch.backends.mps.is_available():
torch.int8,
torch.int16,
torch.int64,
torch.float16,
torch.bfloat16,
],
# zero to negative integer powers are undefined
"__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
@ -867,7 +863,6 @@ if torch.backends.mps.is_available():
"special.polygammaspecial_polygamma_n_0": [torch.float16],
"polygammapolygamma_n_0": [torch.float16],
# Unimplemented ops
"__getitem__": [torch.float16],
"_segment_reduce": [torch.float16, torch.float32],
"_chunk_cat": [torch.float16, torch.float32],
"_upsample_bilinear2d_aa": None, # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
@ -941,9 +936,6 @@ if torch.backends.mps.is_available():
"fmod": [torch.float16],
# round not working properly for float16
"round": [torch.float16],
# atomic operation in backward pass
"_unsafe_masked_index": [torch.float16],
"_unsafe_masked_index_put_accumulate": [torch.float16],
}
MACOS_BEFORE_13_3_XFAILLIST_GRAD = {