Make onnx export SDPA match aten behavior (#159973)
This PR makes the ONNX SDPA export match the behavior of aten SDPA when a boolean mask is used. @justinchuby

```python
import onnxruntime as ort
import torch


class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, query, key, value, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask
        )


model = ScaledDotProductAttention()
attn_mask = torch.ones(2, 4, 8, 8).bool()  # boolean mask for attention
attn_mask[0, 0, 0, :] = False  # masking an entire row (padding token)
query = key = value = torch.randn(2, 4, 8, 16)
output = model(query, key, value, attn_mask)

torch.onnx.export(
    model,
    (query, key, value, attn_mask),
    "scaled_dot_product_attention.onnx",
    input_names=["query", "key", "value", "attn_mask"],
    output_names=["output"],
    dynamo=False,  # or True
)

ort_session = ort.InferenceSession("scaled_dot_product_attention.onnx")
np_inputs = {
    "query": query.numpy(),
    "key": key.numpy(),
    "value": value.numpy(),
    "attn_mask": attn_mask.numpy(),
}
onnx_outputs = ort_session.run(None, np_inputs)[0]
torch.testing.assert_close(output, torch.tensor(onnx_outputs), equal_nan=True)
```

Without this change, the script fails the assertion because the ORT model outputs NaNs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159973
Approved by: https://github.com/xadupre, https://github.com/titaiwangms
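For context, the NaNs come from taking a softmax over a row whose logits are all `-inf`: every `exp()` is 0 and the normalization is 0/0. Below is a minimal NumPy sketch (not the exporter's code; the function and variable names are illustrative) of the `Where -> Add -> Softmax -> Where(IsNaN)` sequence the export now emits:

```python
import numpy as np


def masked_softmax_safe(scores: np.ndarray, attn_mask: np.ndarray) -> np.ndarray:
    # Where(attn_mask, 0, -inf): turn the boolean mask into an additive float mask.
    additive = np.where(attn_mask, 0.0, -np.inf)
    logits = scores + additive
    # Softmax over the last axis. A row whose mask is all False is all -inf,
    # so every exp() is 0 and the normalization is 0/0 = NaN.
    with np.errstate(invalid="ignore", divide="ignore"):
        e = np.exp(logits)
        probs = e / e.sum(axis=-1, keepdims=True)
    # Where(IsNaN, 0, probs): replace the NaN rows with zeros, as the new export does.
    return np.where(np.isnan(probs), 0.0, probs)


scores = np.random.randn(2, 4, 8, 8).astype(np.float32)
mask = np.ones((2, 4, 8, 8), dtype=bool)
mask[0, 0, 0, :] = False  # fully masked row, as in the repro above
out = masked_softmax_safe(scores, mask)
assert not np.isnan(out).any() and np.allclose(out[0, 0, 0], 0.0)
```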
Committed by: PyTorch MergeBot
Parent: d4c1a08c89
Commit: c859ba7114
```diff
@@ -177,6 +177,7 @@ def scaled_dot_product_attention(
     if symbolic_helper._is_none(attn_mask):
         mul_qk_add = mul_qk
+        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
     elif (
         _type_utils.JitScalarType.from_value(attn_mask)
         == _type_utils.JitScalarType.BOOL
@@ -186,19 +187,24 @@ def scaled_dot_product_attention(
         const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
         attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
+        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
+        # When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values
+        # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output.
+        # This is because there is no safe softmax implementation in ONNX, so we need to handle NaN values explicitly
+        # to match the behavior of PyTorch with boolean masks.
+        attn_weight = g.op("Where", g.op("IsNaN", attn_weight), const_zero, attn_weight)
     elif _type_utils.JitScalarType.from_value(attn_mask) in (
         _type_utils.JitScalarType.FLOAT,
         _type_utils.JitScalarType.HALF,
         _type_utils.JitScalarType.BFLOAT16,
     ):
         mul_qk_add = g.op("Add", mul_qk, attn_mask)
+        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
     else:
         raise ValueError(
             f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}"
         )
 
-    attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
-
     if dropout_p != 0:
         attn_weight = g.op(
             "Dropout",
```
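As a sanity check outside the exporter (not part of this PR), the boolean-mask branch above can be mirrored in eager PyTorch. `sdpa_like_export` is a hypothetical helper name, and this assumes the active SDPA backend zeroes fully masked rows, which is the behavior the export now matches:

```python
import torch
import torch.nn.functional as F


def sdpa_like_export(query, key, value, attn_mask):
    # Mirrors the emitted ops: Where -> Add -> Softmax -> Where(IsNaN).
    scale = query.shape[-1] ** -0.5
    mul_qk = torch.matmul(query, key.transpose(-2, -1)) * scale
    additive = torch.zeros_like(mul_qk).masked_fill(~attn_mask, float("-inf"))
    attn_weight = torch.softmax(mul_qk + additive, dim=-1)
    # A fully masked row is all -inf, so softmax yields NaN; zero it out like the export.
    attn_weight = torch.where(
        torch.isnan(attn_weight), torch.zeros_like(attn_weight), attn_weight
    )
    return torch.matmul(attn_weight, value)


query = key = value = torch.randn(2, 4, 8, 16)
attn_mask = torch.ones(2, 4, 8, 8, dtype=torch.bool)
attn_mask[0, 0, 0, :] = False  # fully masked row (padding token)
expected = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
torch.testing.assert_close(sdpa_like_export(query, key, value, attn_mask), expected)
```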