[MPS] Extend index_put to half precision floats (#151869)
By reusing `c10/metal/atomic.h`. This also fixes `GPUTests.test_index_put_fallback[12]_mps`, which is unrolled by inductor, so no dedicated atomic_add support is needed there.

TODOs:
- Get rid of the indexing kernel and compute offsets directly when the kernel is run
- Simulate atomic_add for int64 types as a series of int32 atomic-add-and-fetch operations
- Set up tolerances correctly to pass float16/bfloat16 tests (as CPU always takes the sequential strategy)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151869
Approved by: https://github.com/Skylion007, https://github.com/dcci
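For context, a minimal usage sketch of what this enables on the MPS backend (not part of the diff; assumes an MPS-capable machine): an accumulating index_put on a half-precision tensor, where duplicate indices must be summed atomically. Previously the accumulate path on MPS only covered float and int (see the validateInputData change below).

import torch

t = torch.zeros(4, device="mps", dtype=torch.float16)
idx = torch.tensor([0, 1, 1, 3], device="mps")
src = torch.ones(4, device="mps", dtype=torch.float16)
t.index_put_((idx,), src, accumulate=True)  # duplicate index 1 is accumulated twice
print(t)  # expected: tensor([1., 2., 0., 1.], device='mps:0', dtype=torch.float16)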
committed by PyTorch MergeBot
parent b8f4dc5a9f
commit 3aecf2dc52
@@ -1,5 +1,5 @@
#include <c10/metal/atomic.h>
#include <c10/metal/indexing.h>
#include <metal_atomic>
#include <metal_stdlib>

using namespace metal;
@@ -201,53 +201,8 @@ kernel_index_offsets<packed_uint3, ulong3>(
    constant uint& num_dimensions [[buffer(3)]],
    uint thread_index [[thread_position_in_grid]]);

template <typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
    constant IndexAB* indexAB [[buffer(0)]],
    constant void* indexSizes [[buffer(1)]],
    constant void* indexStrides [[buffer(2)]],
    constant OffsetsT* offsets [[buffer(3)]],
    constant void* inputData [[buffer(4)]],
    device void* outputData [[buffer(5)]],
    constant uint32_t& num_indices [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t* index_sizes = (constant int64_t*)indexSizes;
  constant int64_t* index_strides = (constant int64_t*)indexStrides;
  int64_t offset = 0;
  for (uint32_t i = 0; i < num_indices; i++) {
    constant int64_t* indexArray = indexAB[i].indexArray;
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
    if (index < 0) {
      index += index_sizes[i];
    }
    offset += index * index_strides[i];
  }
  device T* out =
      (device T*)((device char*)outputData + offsets[thread_index].x + offset);
  constant E* in =
      (constant E*)((constant char*)inputData + offsets[thread_index].y);
  atomic_fetch_add_explicit(out, *in, memory_order_relaxed);
}

template <typename T>
__attribute__((__always_inline__)) void atomic_fetch_add_relaxed(
    device void* addr,
    T value) {
  device atomic_uint* uintAddr = (device atomic_uint*)addr;
  uint expected = atomic_load_explicit(uintAddr, memory_order_relaxed);
  T updated = as_type<T>(expected) + value;
  while (!atomic_compare_exchange_weak_explicit(
      uintAddr,
      &expected,
      as_type<uint>(updated),
      memory_order_relaxed,
      memory_order_relaxed)) {
    updated = as_type<T>(expected) + value;
  }
}

template <typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
kernel void index_put_accumulate(
    constant IndexAB* indexAB [[buffer(0)]],
    constant void* indexSizes [[buffer(1)]],
    constant void* indexStrides [[buffer(2)]],
@@ -258,7 +213,7 @@ kernel void atomic_index_put_accumulate(
    uint thread_index [[thread_position_in_grid]]) {
  constant int64_t* index_sizes = (constant int64_t*)indexSizes;
  constant int64_t* index_strides = (constant int64_t*)indexStrides;
  int64_t offset = 0;
  int64_t offset = offsets[thread_index].x;
  for (uint32_t i = 0; i < num_indices; i++) {
    constant int64_t* indexArray = indexAB[i].indexArray;
    int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
@@ -267,56 +222,38 @@ kernel void atomic_index_put_accumulate(
    }
    offset += index * index_strides[i];
  }
  device void* out = (device void*)((device char*)outputData +
      offsets[thread_index].x + offset);
  constant T* in =
      (constant T*)((constant char*)inputData + offsets[thread_index].y);
  atomic_fetch_add_relaxed<T>(out, *in);
  const auto in =
      *(constant T*)((constant char*)inputData + offsets[thread_index].y);
  AtomicType<T>::atomic_add(
      reinterpret_cast<device AtomicType_t<T>*>(outputData),
      offset / sizeof(T),
      in);
}

template [[host_name("index_put_accumulate_32bit_float_idx32")]] kernel void
|
||||
atomic_index_put_accumulate<float, uint3>(
|
||||
constant IndexAB* indexAB [[buffer(0)]],
|
||||
constant void* indexSizes [[buffer(1)]],
|
||||
constant void* indexStrides [[buffer(2)]],
|
||||
constant uint3* offsets [[buffer(3)]],
|
||||
constant void* inputData [[buffer(4)]],
|
||||
device void* outputData [[buffer(5)]],
|
||||
constant uint32_t& num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
#define REGISTER_INDEX_PUT_ACCUMULATE(DTS, DTYPE, IDXS, IDX_DTYPE) \
|
||||
template [[host_name("index_put_accumulate_" #DTS "_" #DTYPE \
|
||||
"_" #IDXS)]] kernel void \
|
||||
index_put_accumulate<DTYPE, IDX_DTYPE>( \
|
||||
constant IndexAB * indexAB [[buffer(0)]], \
|
||||
constant void* indexSizes [[buffer(1)]], \
|
||||
constant void* indexStrides [[buffer(2)]], \
|
||||
constant IDX_DTYPE* offsets [[buffer(3)]], \
|
||||
constant void* inputData [[buffer(4)]], \
|
||||
device void* outputData [[buffer(5)]], \
|
||||
constant uint32_t& num_indices [[buffer(6)]], \
|
||||
uint thread_index [[thread_position_in_grid]])
|
||||
|
||||
template [[host_name("index_put_accumulate_32bit_float_idx64")]] kernel void
|
||||
atomic_index_put_accumulate<float, ulong3>(
|
||||
constant IndexAB* indexAB [[buffer(0)]],
|
||||
constant void* indexSizes [[buffer(1)]],
|
||||
constant void* indexStrides [[buffer(2)]],
|
||||
constant ulong3* offsets [[buffer(3)]],
|
||||
constant void* inputData [[buffer(4)]],
|
||||
device void* outputData [[buffer(5)]],
|
||||
constant uint32_t& num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
REGISTER_INDEX_PUT_ACCUMULATE(32bit, float, idx32, uint3);
|
||||
REGISTER_INDEX_PUT_ACCUMULATE(32bit, float, idx64, ulong3);
|
||||
REGISTER_INDEX_PUT_ACCUMULATE(32bit, int, idx32, uint3);
|
||||
REGISTER_INDEX_PUT_ACCUMULATE(32bit, int, idx64, ulong3);
|
||||
REGISTER_INDEX_PUT_ACCUMULATE(16bit, half, idx32, uint3);
|
||||
REGISTER_INDEX_PUT_ACCUMULATE(16bit, half, idx64, ulong3);
|
||||
|
||||
template [[host_name("index_put_accumulate_32bit_int_idx32")]] kernel void
|
||||
index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
|
||||
constant IndexAB* indexAB [[buffer(0)]],
|
||||
constant void* indexSizes [[buffer(1)]],
|
||||
constant void* indexStrides [[buffer(2)]],
|
||||
constant uint3* offsets [[buffer(3)]],
|
||||
constant void* inputData [[buffer(4)]],
|
||||
device void* outputData [[buffer(5)]],
|
||||
constant uint32_t& num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("index_put_accumulate_32bit_int_idx64")]] kernel void
|
||||
index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
|
||||
constant IndexAB* indexAB [[buffer(0)]],
|
||||
constant void* indexSizes [[buffer(1)]],
|
||||
constant void* indexStrides [[buffer(2)]],
|
||||
constant ulong3* offsets [[buffer(3)]],
|
||||
constant void* inputData [[buffer(4)]],
|
||||
device void* outputData [[buffer(5)]],
|
||||
constant uint32_t& num_indices [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_INDEX_PUT_ACCUMULATE(16bit, bfloat, idx32, uint3);
|
||||
REGISTER_INDEX_PUT_ACCUMULATE(16bit, bfloat, idx64, ulong3);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
kernel void masked_fill_scalar_dense(
|
||||
|
@@ -105,7 +105,7 @@ void upsample_increment_value_bounded(
      data,
      n * strides.x + c * strides.y + access_y * strides.z +
          access_x * strides.w,
      value);
      static_cast<scalar_t>(value));
}

template <typename T>
@@ -108,15 +108,12 @@ static std::string getIndexFunctionName(ScalarType scalar_type,
      : (accumulate && (scalar_type != kBool)) ? "index_put_accumulate_"
      : (serial ? "index_put_serial_" : "index_put_");

  indexFunction += getBitSizeString(scalar_type);
  indexFunction.append(getBitSizeString(scalar_type));
  if (accumulate) {
    TORCH_CHECK(scalar_type == ScalarType::Float || scalar_type == ScalarType::Int,
                "Unsupported data type for accumulate case: ",
                getMPSTypeString(scalar_type));
    string dtypeString = (scalar_type == ScalarType::Float) ? "_float" : "_int";
    indexFunction += dtypeString;
    indexFunction.append(1, '_');
    indexFunction.append(scalarToMetalTypeString(scalar_type));
  }
  indexFunction += use_64bit_indexing ? "_idx64" : "_idx32";
  indexFunction.append(use_64bit_indexing ? "_idx64" : "_idx32");
  return indexFunction;
}

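A note on the name format (not part of the diff): the accumulate kernel name is now assembled from the bit-size string, the Metal dtype name, and the index width, which matches the host names generated by REGISTER_INDEX_PUT_ACCUMULATE above. An illustrative sketch, with the helper outputs assumed:

# Illustrative only: float16 with 32-bit offsets
bitsize, metal_dtype, idx_width = "16bit", "half", "idx32"  # assumed helper outputs
name = "index_put_accumulate_" + bitsize + "_" + metal_dtype + "_" + idx_width
assert name == "index_put_accumulate_16bit_half_idx32"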
@@ -206,8 +203,7 @@ static void validateInputData(const TensorIteratorBase& iter,

  if (accumulate) {
    // No atomic support for the rest of dtypes
    TORCH_CHECK(scalar_type == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Int ||
                scalar_type == ScalarType::Bool);
    TORCH_CHECK(supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool);
  } else {
    TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) ||
                scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf,
@@ -32,18 +32,18 @@ struct AtomicType<int> {
// atomic type
template <typename T>
static inline void atomic_add_helper(
    device ::metal::atomic<int>* data,
    device ::metal::atomic<uint>* data,
    long offset,
    float value) {
    T value) {
  auto ptr = data + (offset >> 1);
  auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
  union {
    int i;
    uint i;
    T t[2];
  } val;
  do {
    val.i = old;
    val.t[offset & 1] += static_cast<T>(value);
    val.t[offset & 1] += value;
  } while (!::metal::atomic_compare_exchange_weak_explicit(
      ptr,
      &old,
@@ -54,8 +54,8 @@ static inline void atomic_add_helper(

template <>
struct AtomicType<half> {
  using type = ::metal::atomic<int>;
  static inline void atomic_add(device type* data, long offset, float value) {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, half value) {
    atomic_add_helper<half>(data, offset, value);
  }
};
@@ -63,8 +63,8 @@ struct AtomicType<half> {
#if __METAL_VERSION__ >= 310
template <>
struct AtomicType<bfloat> {
  using type = ::metal::atomic<int>;
  static inline void atomic_add(device type* data, long offset, float value) {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, bfloat value) {
    atomic_add_helper<bfloat>(data, offset, value);
  }
};
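A side note on atomic_add_helper above (not part of the diff): two 16-bit lanes share one 32-bit word, so the word index is offset >> 1, the lane is offset & 1, and the add is retried with a compare-and-swap until the containing word was not modified concurrently. A minimal single-threaded Python sketch of the same packing, assuming little-endian float16 lanes (names are illustrative):

import struct

def emulated_fp16_atomic_add(words, offset, value):
    # words: list of 32-bit integers, each packing two float16 lanes
    word_idx, lane = offset >> 1, offset & 1
    while True:
        old = words[word_idx]
        lanes = list(struct.unpack("<2e", struct.pack("<I", old)))
        lanes[lane] += value
        new = struct.unpack("<I", struct.pack("<2e", *lanes))[0]
        # a real GPU uses atomic_compare_exchange_weak here; single-threaded
        # Python cannot race, so this "CAS" always succeeds on the first try
        if words[word_idx] == old:
            words[word_idx] = new
            return

words = [0]                              # one word holds lanes 0 and 1, both 0.0
emulated_fp16_atomic_add(words, 1, 2.5)  # add 2.5 to lane 1
assert struct.unpack("<2e", struct.pack("<I", words[0]))[1] == 2.5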
@@ -8015,7 +8015,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
            ),
        )

    @xfail_if_mps_unimplemented  # RuntimeError: Expected scalar_type == ScalarType::Float
    def test_index_put_fallback1(self):
        def fn(a, b, c, d):
            a = a.clone()
@@ -8042,7 +8041,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
            ),
        )

    @xfail_if_mps_unimplemented  # RuntimeError: Expected scalar_type == ScalarType::Float
    def test_index_put_fallback2(self):
        def fn(a, b, c, d, e):
            a = a.clone()
@@ -79,7 +79,6 @@ def xfailIf(condition):
        return func
    return wrapper


# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
@@ -11782,6 +11781,9 @@ class TestConsistency(TestCaseMPS):
        'norm', 'masked.normalize',
        'arange', 'linspace',
        'special.xlog1py',

        # CPU accumulates sequentially, but GPU does it in parallel
        '_unsafe_masked_index_put_accumulate',
    }

    FP32_LOW_PRECISION_LIST = {
@@ -11799,7 +11801,7 @@ class TestConsistency(TestCaseMPS):
            return (1e-4, 3e-5)

        if op.name in self.FP16_LOW_PRECISION_LIST and dtype in [torch.float16, torch.bfloat16]:
            return (1e-2, 1e-2) if dtype == torch.float16 else (5e-2, 5e-2)
            return (2e-2, 1e-2) if dtype == torch.float16 else (5e-2, 5e-2)

        if op.name in self.BF16_LOW_PRECISION_LIST and dtype == torch.bfloat16:
            return (5e-2, 5e-2)
@@ -11862,6 +11864,12 @@ class TestConsistency(TestCaseMPS):
        if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
            mps_args[1] = cpu_args[1]

        # Order of ops in index_put is not guaranteed, which can lead to large errors if inputs are
        # not normalized
        if op.name == "_unsafe_masked_index_put_accumulate" and dtype in [torch.bfloat16, torch.float16]:
            mps_args[3] = F.normalize(mps_args[3])
            cpu_args[3] = F.normalize(cpu_args[3])

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            cpu_out = op(*cpu_args, **cpu_kwargs)
@@ -11880,6 +11888,7 @@ class TestConsistency(TestCaseMPS):
        if op.name in ["_upsample_bilinear2d_aa", "_upsample_bicubic2d_aa"] and cpu_kwargs.get("scale_factors") == [1.7, 0.9]:
            # Similar to the above, float vs double precision results in slight error
            atol, rtol = 2e-5, 2e-6

        self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
@@ -11900,7 +11909,6 @@ class TestConsistency(TestCaseMPS):
        #
        # Forward check
        #
        forward_failed = False
        mps_sample = transform_opinfo_sample_to_mps(cpu_sample)

        cpu_args = [cpu_sample.input] + list(cpu_sample.args)
@@ -11912,6 +11920,12 @@ class TestConsistency(TestCaseMPS):
        if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
            mps_args[1] = cpu_args[1]

        # Order of ops in index_put is not guaranteed, which can lead to large errors if inputs are
        # not normalized
        if op.name == "_unsafe_masked_index_put_accumulate" and dtype in [torch.bfloat16, torch.float16]:
            mps_args[3] = F.normalize(mps_args[3])
            cpu_args[3] = F.normalize(cpu_args[3])

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            cpu_out = op(*cpu_args, **cpu_kwargs)
@@ -11930,11 +11944,6 @@ class TestConsistency(TestCaseMPS):
        #
        # Backward check
        #
        if forward_failed:
            # We would've failed immediately anyway, but this error is clearer
            # We error instead of continuing so that all_backward_pass would not be True
            raise RuntimeError("Forward pass already failed")

        cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
        mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

@@ -11967,6 +11976,10 @@ class TestConsistency(TestCaseMPS):
        ):
            atol = 1e-5
            rtol = 1.5e-3
        # Order of ops in unsafe_masked_index backward is not guaranteed
        # which leads to larger errors
        if op.name == "_unsafe_masked_index" and dtype == torch.float16:
            atol, rtol = 3e-3, 3e-3
        self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)

    def test_fmax_mixed_dtypes(self, device):
|
||||
self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5,
|
||||
f"results are {y}, but all elements should have been {x_sum.item()}")
|
||||
|
||||
@parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16])
|
||||
def test_atomic_add(self, dtype):
|
||||
if dtype == torch.bfloat16 and MACOS_VERSION < 14.0:
|
||||
raise unittest.SkipTest("bfloat requires MacOS-14+")
|
||||
from torch._inductor.codegen.mps import DTYPE_TO_METAL
|
||||
mdtype = DTYPE_TO_METAL[dtype]
|
||||
lib = torch.mps.compile_shader(f"""
|
||||
#include <c10/metal/atomic.h>
|
||||
using namespace c10::metal;
|
||||
kernel void atomic_add(device AtomicType<{mdtype}>::type* out,
|
||||
constant {mdtype}* inc,
|
||||
uint idx [[thread_position_in_grid]]) {{
|
||||
AtomicType<{mdtype}>::atomic_add(out, idx & 1 ? 3 : 4, inc[idx]);
|
||||
}}
|
||||
|
||||
""")
|
||||
x = torch.arange(16, device="mps", dtype=dtype)
|
||||
y = torch.arange(16, device="mps", dtype=dtype)
|
||||
lib.atomic_add(x, y)
|
||||
self.assertEqual(x[3], 67)
|
||||
self.assertEqual(x[4], 60)
|
||||
|
||||
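A note on the expected values in test_atomic_add (not part of the diff): x and y are both arange(16); threads with an odd idx accumulate inc[idx] into out[3] and even threads into out[4], so x[3] ends up as 3 + (1 + 3 + ... + 15) = 67 and x[4] as 4 + (0 + 2 + ... + 14) = 60. A quick check:

assert 3 + sum(range(1, 16, 2)) == 67  # odd increments land on index 3
assert 4 + sum(range(0, 16, 2)) == 60  # even increments land on index 4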
    def test_argument_buffers(self):
        lib = torch.mps.compile_shader("""
        constant constexpr auto nbuffers = 64;
@@ -597,10 +597,8 @@ if torch.backends.mps.is_available():
            torch.bool,
            torch.int8,
            torch.uint8,
            torch.float16,
            torch.int16,
            torch.int64,
            torch.bfloat16,
        ],
    }

@@ -703,8 +701,6 @@ if torch.backends.mps.is_available():
            torch.int8,
            torch.int16,
            torch.int64,
            torch.float16,
            torch.bfloat16,
        ],
        # zero to negative integer powers are undefined
        "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
@@ -867,7 +863,6 @@ if torch.backends.mps.is_available():
        "special.polygammaspecial_polygamma_n_0": [torch.float16],
        "polygammapolygamma_n_0": [torch.float16],
        # Unimplemented ops
        "__getitem__": [torch.float16],
        "_segment_reduce": [torch.float16, torch.float32],
        "_chunk_cat": [torch.float16, torch.float32],
        "_upsample_bilinear2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
@@ -941,9 +936,6 @@ if torch.backends.mps.is_available():
        "fmod": [torch.float16],
        # round not working properly for float16
        "round": [torch.float16],
        # atomic operation in backward pass
        "_unsafe_masked_index": [torch.float16],
        "_unsafe_masked_index_put_accumulate": [torch.float16],
    }

    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {