Make onnx export SDPA match aten behavior (#159973)

This PR makes the ONNX export of SDPA match the behavior of aten SDPA when a boolean attention mask is used.
@justinchuby

```python
import onnxruntime as ort
import torch

class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, query, key, value, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)
output = model(query, key, value, attn_mask)

torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    "scaled_dot_product_attention.onnx",
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    dynamo=False,  # or True
)
ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx")

np_inputs = {"query": query.numpy(), "key": key.numpy(), "value": value.numpy(), "attn_mask": attn_mask.numpy()}
onnx_outputs = ort_session.run(None, np_inputs)[0]

torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True)
```
This script fails the assertion because the ORT model outputs NaNs.
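The NaNs come from the softmax itself: the boolean mask is converted to an additive bias of `0`/`-inf`, so a fully masked row becomes all `-inf`, and softmax over such a row is 0/0. A quick demonstration in PyTorch:

```python
import torch

# A fully masked (padding) row turns into all -inf logits once the boolean
# mask is converted to an additive 0 / -inf bias.
row = torch.full((4,), float("-inf"))

# exp(-inf) = 0 for every entry, so the normalizer is 0 and softmax is 0/0 = NaN.
print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan])
```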
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159973
Approved by: https://github.com/xadupre, https://github.com/titaiwangms
Author: IlyasMoutawwakil
Date: 2025-08-07 04:06:04 +00:00
Committed by: PyTorch MergeBot
Parent: d4c1a08c89
Commit: c859ba7114

The fix moves the `Softmax` into each branch and, for boolean masks, replaces the NaNs produced by fully masked rows with zeros:

```diff
@@ -177,6 +177,7 @@ def scaled_dot_product_attention(
     if symbolic_helper._is_none(attn_mask):
         mul_qk_add = mul_qk
+        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
     elif (
         _type_utils.JitScalarType.from_value(attn_mask)
         == _type_utils.JitScalarType.BOOL
@@ -186,19 +187,24 @@ def scaled_dot_product_attention(
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
+        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
+        # When using scaled dot product attention with a boolean mask, the softmax can return
+        # NaN values due to the presence of -inf in an entire row (padding tokens), resulting
+        # in 0/0 (NaN) in the softmax output. There is no safe softmax implementation in ONNX,
+        # so we need to handle NaN values explicitly to match PyTorch's boolean-mask behavior.
+        attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight)
     elif _type_utils.JitScalarType.from_value(attn_mask) in (
         _type_utils.JitScalarType.FLOAT,
         _type_utils.JitScalarType.HALF,
         _type_utils.JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
+        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
     else:
         raise ValueError(
             f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
         )
-    attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
     if dropout_p != 0:
         attn_weight = g.op(
             "Dropout",
```