mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Address NaNs if SDPA is called with all values masked from query (#157727)
Fixes #156707 Detect if all values along the softmax axis are infs and overwrite the outputs for those computations with zeros before the final matmul. The behavior should be aligned with the CPU implementation. These types of cases where all values along the dimension in the attention mask are false leading to the undefined outputs in softmax occur with left padded batches for generation in HF transformers according to the original issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157727 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
bcf50636ba
commit
194539e9c3
@ -114,8 +114,22 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
|
||||
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
|
||||
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
|
||||
}
|
||||
|
||||
// Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
|
||||
// Overwrites expected NANs in sm with zeros.
|
||||
auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
|
||||
auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
|
||||
auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
|
||||
auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
|
||||
|
||||
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
|
||||
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil];
|
||||
MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
|
||||
truePredicateTensor:zeroTensor
|
||||
falsePredicateTensor:sm
|
||||
name:nil];
|
||||
|
||||
auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
|
||||
graph->qTensor = qTensor;
|
||||
graph->kTensor = kTensor;
|
||||
graph->vTensor = vTensor;
|
||||
|
@ -9257,6 +9257,18 @@ class TestSDPA(TestCaseMPS):
|
||||
def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
|
||||
self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
|
||||
|
||||
# Regression test from: https://github.com/pytorch/pytorch/issues/156707
|
||||
@parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_sdpa_full_mask(self, dtype):
|
||||
q = torch.randn(1, 1, 2, 4, dtype=dtype)
|
||||
k = torch.randn(1, 1, 2, 4, dtype=dtype)
|
||||
v = torch.randn(1, 1, 2, 4, dtype=dtype)
|
||||
mask = torch.tensor([[[[False, False], [True, True]]]], dtype=torch.bool)
|
||||
|
||||
out_cpu = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
out_mps = F.scaled_dot_product_attention(q.to('mps'), k.to('mps'), v.to('mps'), attn_mask=mask.to('mps'))
|
||||
self._compare_tensors(out_mps.cpu(), out_cpu)
|
||||
|
||||
@parametrize("dtype", [torch.float16, torch.float32])
|
||||
def test_sdpa_3d_input(self, dtype):
|
||||
head_num, seq_len, embed_dim = 16, 16, 80
|
||||
|
Reference in New Issue
Block a user