[CUDA][SDPA] Compute reference in test_triton_scaled_dot_product_attention_block_size_16_cuda_float32 in float64 (#146461)

The test currently fails with mismatches in the 1e-4 range, presumably because SDPA dispatches to the `MATH` backend here, which is less fused than a Triton kernel and therefore accumulates additional rounding error. Computing the reference in `float64` appears to fix it.
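
For context, the pattern behind the fix can be sketched roughly as follows; the helper name, tensor shapes, and tolerances are illustrative rather than taken from the test:

```python
import torch
import torch.nn.functional as F

# Illustrative sketch (not the actual test code): compute the SDPA reference
# in float64, which dispatches to the unfused MATH backend, then cast back to
# the working dtype before comparing against the lower-precision result.
def fp64_reference_sdpa(query, key, value, attn_mask=None, scale=None):
    ref = F.scaled_dot_product_attention(
        query.double(), key.double(), value.double(),
        attn_mask=attn_mask, scale=scale,
    )
    return ref.to(query.dtype)

device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(2, 4, 32, 16, device=device) for _ in range(3))

expected = fp64_reference_sdpa(q, k, v)
actual = F.scaled_dot_product_attention(q, k, v)
# Tolerances here are illustrative; the point is that the float64 reference
# no longer contributes its own rounding error to the comparison.
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)
```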

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146461
Approved by: https://github.com/drisspg
Author: eqy
Date: 2025-02-06 23:28:56 +00:00
Committed by: PyTorch MergeBot
Parent: 2834fe5e93
Commit: 7bd7f735d4


@@ -3763,9 +3763,11 @@ class TestSparseCompressedTritonKernels(TestCase):
                 for scale in (None, 1. / 16):
                     if scale is None and query.size(-1) == 0:
                         scale = 1
+                    # We cast to double here as this dispatches to the MATH backend which
+                    # introduces additional rounding steps over the fused implementations
                     expected = torch.nn.functional.scaled_dot_product_attention(
-                        *broadcast_input(query, key, value, attn_mask), scale=scale
-                    )
+                        *broadcast_input(query.double(), key.double(), value.double(), attn_mask), scale=scale
+                    ).to(dtype)
 
                     for mask_dtype in (torch.bool, dtype):
                         res = _scaled_dot_product_attention(query, key, value, attn_mask_bsr.to(mask_dtype), scale=scale)