Compare commits

...

1 Commits

Author SHA1 Message Date
f203b98062 Add check for MacOS26 to use different code path in SDPA 2025-11-17 15:38:19 -08:00
3 changed files with 125 additions and 56 deletions

View File

@ -22,6 +22,7 @@ enum class MacOSVersion : uint32_t {
MACOS_VER_15_0_PLUS,
MACOS_VER_15_1_PLUS,
MACOS_VER_15_2_PLUS,
MACOS_VER_26_0_PLUS,
};
//-----------------------------------------------------------------

View File

@ -65,6 +65,7 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
static bool _macos_26_0_plus = is_os_version_at_least(26, 0);
switch (version) {
case MacOSVersion::MACOS_VER_14_4_PLUS:
@ -75,6 +76,8 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
return _macos_15_1_plus;
case MacOSVersion::MACOS_VER_15_2_PLUS:
return _macos_15_2_plus;
case MacOSVersion::MACOS_VER_26_0_PLUS:
return _macos_26_0_plus;
default:
return false;
}

View File

@ -69,75 +69,139 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
auto out = at::empty({batchSize, num_head, qSize, headSize}, query.options());
auto attn = at::empty({batchSize, num_head, qSize, maxSeqLength}, query.options());
auto scale_factor = sdp::calculate_scale(query, scale).expect_float();
static const bool is_macOS_26_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_26_0_PLUS);
@autoreleasepool {
auto mkey = __func__ + getTensorsStringKey({query, key, value}) + ":" + std::to_string(is_causal) + ":" +
std::to_string(attn_mask.has_value());
auto cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
shape:getMPSShape({1})
dataType:MPSDataTypeFloat32];
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
CachedGraph* cachedGraph;
//if(is_macOS_26_0_or_newer) {
if(true) {
cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
// bug in MacOS15, without this trick SDPA leaks memory, adding 0.0f gets ignored(still takes SDPA sequence
// path which leaks)
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
}
if (is_causal) {
MPSShape* maskShape = @[@(qSize), @(maxSeqLength)];
auto x = [mpsGraph coordinateAlongAxis:-1 withShape:@[@(qSize), @1] name:nil];
auto y = [mpsGraph coordinateAlongAxis:-2 withShape:@[@1, @(maxSeqLength)] name:nil];
auto isLess = [mpsGraph lessThanOrEqualToWithPrimaryTensor:x secondaryTensor:y name:nil];
auto causalMask = [mpsGraph selectWithPredicateTensor:isLess
truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:qTensor.dataType]
falsePredicateTensor:[mpsGraph constantWithScalar:-INFINITY dataType:qTensor.dataType]
name:nil];
graph->maskTensor = causalMask;
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
}
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
// Overwrites expected NANs in sm with zeros.
// auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
// auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
// auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
// auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
// auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
//
// auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
// MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
// truePredicateTensor:zeroTensor
// falsePredicateTensor:sm
// name:nil];
//
// auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
shape:getMPSShape({qSize, maxSeqLength})
dataType:MPSDataTypeBool];
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
truePredicateTensor:maskedMM
falsePredicateTensor:minusInf
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
}
MPSGraphTensor* output;
if(graph->maskTensor != nil) {
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
keyTensor:kTensor
valueTensor:vTensor
maskTensor:graph->maskTensor
scale:scale_factor
name:@"MPSGraph SDPA"];
} else {
output = [mpsGraph scaledDotProductAttentionWithQueryTensor:qTensor
keyTensor:kTensor
valueTensor:vTensor
scale:scale_factor
name:@"MPSGraph SDPA"];
}
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
// graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
} else {
cachedGraph =
LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&, q_ = query, k_ = key, v_ = value](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, q_);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, k_);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, v_);
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor
shape:getMPSShape({1})
dataType:MPSDataTypeFloat32];
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
// Overwrites expected NANs in sm with zeros.
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
truePredicateTensor:zeroTensor
falsePredicateTensor:sm
name:nil];
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
// bug in MacOS15, without this trick SDPA leaks memory, adding 0.0f gets ignored(still takes SDPA sequence
// path which leaks)
auto oneTensor = [mpsGraph constantWithScalar:1e-20f shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:oneTensor name:nil];
}
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
if (is_causal) {
auto causalMask = [mpsGraph constantWithScalar:1.0f
shape:getMPSShape({qSize, maxSeqLength})
dataType:MPSDataTypeBool];
causalMask = [mpsGraph bandPartWithTensor:causalMask numLower:-1 numUpper:0 name:nil];
auto minusInf = [mpsGraph constantWithScalar:-1e20 shape:maskedMM.shape dataType:maskedMM.dataType];
maskedMM = [mpsGraph selectWithPredicateTensor:causalMask
truePredicateTensor:maskedMM
falsePredicateTensor:minusInf
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
name:nil];
}
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
// Overwrites expected NANs in sm with zeros.
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
truePredicateTensor:zeroTensor
falsePredicateTensor:sm
name:nil];
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
graph->qTensor = qTensor;
graph->kTensor = kTensor;
graph->vTensor = vTensor;
graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
});
}
auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
auto vPlaceholder = Placeholder(cachedGraph->vTensor, value);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, out);
auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
// auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
NSDictionary* feeds = nil;
if (!attn_mask) {
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder);
@ -145,7 +209,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *attn_mask);
feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder, mPlaceholder);
}
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
// NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder);
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outs);
}