Address NaNs if SDPA is called with all values masked from query (#157727)

Fixes #156707

Detect whether all values along the softmax axis are -inf and overwrite the softmax outputs for those rows with zeros before the final matmul, since softmax over a row that is entirely -inf yields 0/0 = NaN. This aligns the behavior with the CPU implementation.
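A quick illustration of the failure in isolation (a sketch; `dim=-1` stands in for the softmax axis used in the kernel):

```python
import torch

row = torch.full((4,), float("-inf"))
# exp(-inf) = 0 for every element, so the normalizer is 0 and each entry is 0/0:
print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan])
```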

According to the original issue, these cases, where every value along a dimension of the attention mask is False and the softmax output is therefore undefined, occur with left-padded batches during generation in HF transformers.
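A minimal repro sketch of that failure mode, mirroring the regression test added below (shapes and mask values are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 4)
k = torch.randn(1, 1, 2, 4)
v = torch.randn(1, 1, 2, 4)
# Row 0 of the mask is all False (fully masked), as with a left-padded sequence.
mask = torch.tensor([[[[False, False], [True, True]]]], dtype=torch.bool)

out_cpu = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_mps = F.scaled_dot_product_attention(
    q.to("mps"), k.to("mps"), v.to("mps"), attn_mask=mask.to("mps"))
# Before this fix, out_mps contained NaNs in the fully masked row;
# with it, both backends agree.
assert not out_mps.isnan().any()
```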
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157727
Approved by: https://github.com/malfet
Joona Havukainen
2025-07-14 22:09:31 +00:00
committed by PyTorch MergeBot
parent bcf50636ba
commit 194539e9c3
2 changed files with 27 additions and 1 deletion


@@ -114,8 +114,22 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
     graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
     maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
   }
+  // Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
+  // Overwrites expected NANs in sm with zeros.
+  auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
+  auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
+  auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
+  auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
+  auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
   auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
-  auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil];
+  MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
+                                                truePredicateTensor:zeroTensor
+                                               falsePredicateTensor:sm
+                                                               name:nil];
+  auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
   graph->qTensor = qTensor;
   graph->kTensor = kTensor;
   graph->vTensor = vTensor;
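For readers less familiar with MPSGraph, a rough PyTorch sketch of what the graph ops above compute (variable names mirror the Objective-C++ code; this is an illustration, not the shipped kernel):

```python
import torch

def corrected_softmax_matmul(masked_mm: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # equalWithPrimaryTensor against a -inf constant tensor
    elem_neg_inf = masked_mm == float("-inf")
    # reductionAndWithTensor along axis 3, then broadcastTensor back to the full shape
    zero_mask = elem_neg_inf.all(dim=3, keepdim=True).expand_as(masked_mm)
    # softMaxWithTensor along axis 3; fully masked rows come out as NaN (0/0)
    sm = masked_mm.softmax(dim=3)
    # selectWithPredicateTensor: replace those NaN rows with zeros
    corrected_sm = torch.where(zero_mask, torch.zeros_like(sm), sm)
    # matrixMultiplicationWithPrimaryTensor against the value tensor
    return corrected_sm @ v
```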


@@ -9257,6 +9257,18 @@ class TestSDPA(TestCaseMPS):
     def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
         self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
 
+    # Regression test from: https://github.com/pytorch/pytorch/issues/156707
+    @parametrize("dtype", [torch.float16, torch.float32])
+    def test_sdpa_full_mask(self, dtype):
+        q = torch.randn(1, 1, 2, 4, dtype=dtype)
+        k = torch.randn(1, 1, 2, 4, dtype=dtype)
+        v = torch.randn(1, 1, 2, 4, dtype=dtype)
+        mask = torch.tensor([[[[False, False], [True, True]]]], dtype=torch.bool)
+        out_cpu = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+        out_mps = F.scaled_dot_product_attention(q.to('mps'), k.to('mps'), v.to('mps'), attn_mask=mask.to('mps'))
+        self._compare_tensors(out_mps.cpu(), out_cpu)
+
     @parametrize("dtype", [torch.float16, torch.float32])
     def test_sdpa_3d_input(self, dtype):
         head_num, seq_len, embed_dim = 16, 16, 80