[CUDA][SDPA] Compute reference in test_triton_scaled_dot_product_attention_block_size_16_cuda_float32 in float64 (#146461)
The test currently seems to fail with mismatches in the 1e-4 range, presumably because the reference SDPA call dispatches to the `MATH` backend here, which is less fused than the Triton kernel and accumulates additional rounding error. Computing the reference in `float64` and casting back appears to fix it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146461
Approved by: https://github.com/drisspg
@@ -3763,9 +3763,11 @@ class TestSparseCompressedTritonKernels(TestCase):
         for scale in (None, 1. / 16):
             if scale is None and query.size(-1) == 0:
                 scale = 1
+            # We cast to double here as this dispatches to the MATH backend which
+            # introduces additional rounding steps over the fused implementations
             expected = torch.nn.functional.scaled_dot_product_attention(
-                *broadcast_input(query, key, value, attn_mask), scale=scale
-            )
+                *broadcast_input(query.double(), key.double(), value.double(), attn_mask), scale=scale
+            ).to(dtype)

             for mask_dtype in (torch.bool, dtype):
                 res = _scaled_dot_product_attention(query, key, value, attn_mask_bsr.to(mask_dtype), scale=scale)
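For reference, a minimal standalone sketch of the same pattern outside the test suite (tensor shapes, dtypes, and tolerances below are illustrative assumptions, not taken from this PR): compute the reference with the same scaled_dot_product_attention call in float64, cast back to the test dtype, and compare against the lower-precision result.

import torch
import torch.nn.functional as F

dtype = torch.float32
# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 4, 32, 16, dtype=dtype)
k = torch.randn(2, 4, 32, 16, dtype=dtype)
v = torch.randn(2, 4, 32, 16, dtype=dtype)

# Result under test, e.g. from a fused/Triton kernel; here the default
# backend simply stands in for it.
res = F.scaled_dot_product_attention(q, k, v)

# Reference computed in float64 (the unfused MATH path takes extra rounding
# steps), then cast back to the test dtype for comparison.
expected = F.scaled_dot_product_attention(q.double(), k.double(), v.double()).to(dtype)

# Illustrative tolerances; doing the reference in float64 keeps its own
# rounding error from eating into the comparison budget.
torch.testing.assert_close(res, expected, atol=1e-4, rtol=1e-4)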