[Intel GPU] undo broadcast on zero stride tensor for SDPA (#151976)

Fixes https://github.com/pytorch/pytorch/issues/152290.

The **hubert** model uses `aten::expand` to build its attention mask by broadcasting. PyTorch represents a broadcast dimension with `strides[d] = 0`, which oneDNN does not support. This PR undoes the broadcast on such zero-stride tensors before handing them to oneDNN.
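For illustration, here is a minimal Python sketch (not part of this PR) of how `aten::expand` produces zero strides and how a zero-stride view can be collapsed back to size-1 dims. `undo_broadcast_sketch` below is a hypothetical analogue of the C++ helper `at::native::onednn::undo_broadcast`, not the actual implementation:

```python
import torch

# aten::expand returns a view, not a copy: each broadcast dim is
# represented with size > 1 but stride 0.
mask = torch.rand(1, 128)
expanded = mask.expand(1, 1, 128, 128)
print(expanded.stride())  # e.g. (0, 0, 0, 1) -- the zeros mark broadcast dims

def undo_broadcast_sketch(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical analogue of undo_broadcast: collapse every stride-0
    dim back to size 1 so the strides seen downstream contain no zeros;
    size-1 dims can then be broadcast by the library itself."""
    if not any(s == 0 for s in t.stride()):
        return t
    sizes = [1 if s == 0 else n for n, s in zip(t.shape, t.stride())]
    strides = [1 if s == 0 else s for s in t.stride()]  # stride of a size-1 dim is arbitrary
    return t.as_strided(sizes, strides)

print(undo_broadcast_sketch(expanded).shape)  # torch.Size([1, 1, 1, 128])
```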

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151976
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/drisspg
Author: fengqing.lu
Date: 2025-05-14 16:09:03 +00:00
Committed by: PyTorch MergeBot
Parent commit: 1f48bab377
Commit: de92296bbb
2 changed files with 55 additions and 10 deletions


@@ -49,16 +49,39 @@ struct SDPALogicalParams {
         "Only FP16/BF16/FP32 datatypes are currently supported");
     const dims scalar_shape = {1};
     std::vector<logical_tensor> inputLogicalTensors;
+    at::Tensor reshaped_query = query_;
+    at::Tensor reshaped_key = key_;
+    at::Tensor reshaped_value = value_;
+    at::Tensor reshaped_output = output_;
+    at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor());
+    if (at::native::onednn::is_broadcast(reshaped_query)) {
+      at::native::onednn::undo_broadcast(reshaped_query);
+    }
+    if (at::native::onednn::is_broadcast(reshaped_key)) {
+      at::native::onednn::undo_broadcast(reshaped_key);
+    }
+    if (at::native::onednn::is_broadcast(reshaped_value)) {
+      at::native::onednn::undo_broadcast(reshaped_value);
+    }
+    if (at::native::onednn::is_broadcast(reshaped_output)) {
+      at::native::onednn::undo_broadcast(reshaped_output);
+    }
+    if (attn_mask_.has_value() &&
+        at::native::onednn::is_broadcast(reshaped_attn_mask)) {
+      at::native::onednn::undo_broadcast(reshaped_attn_mask);
+    }
     query = {
         static_cast<size_t>(TensorID::query),
         dtype,
-        query_.sizes().vec(),
-        query_.strides().vec()};
+        reshaped_query.sizes().vec(),
+        reshaped_query.strides().vec()};
     key = {
         static_cast<size_t>(TensorID::key),
         dtype,
-        key_.sizes().vec(),
-        key_.strides().vec()};
+        reshaped_key.sizes().vec(),
+        reshaped_key.strides().vec()};
     scale = {
         static_cast<size_t>(TensorID::scale),
         dtype,
@@ -77,19 +100,19 @@ struct SDPALogicalParams {
       attn_mask = {
           static_cast<size_t>(TensorID::attn_mask),
           dtype,
-          attn_mask_->sizes().vec(),
-          attn_mask_->strides().vec()};
+          reshaped_attn_mask.sizes().vec(),
+          reshaped_attn_mask.strides().vec()};
     }
     value = {
         static_cast<size_t>(TensorID::value),
         dtype,
-        value_.sizes().vec(),
-        value_.strides().vec()};
+        reshaped_value.sizes().vec(),
+        reshaped_value.strides().vec()};
     output = {
         static_cast<size_t>(TensorID::output),
         dtype,
-        output_.sizes().vec(),
-        output_.strides().vec()};
+        reshaped_output.sizes().vec(),
+        reshaped_output.strides().vec()};
   }
   std::vector<logical_tensor> get_input() const {
     std::vector<logical_tensor> input = {query, key, scale};


@@ -4006,6 +4006,28 @@ class TestSDPAXpuOnly(NNTestCase):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             _ = F.scaled_dot_product_attention(q, k, v)
 
+    def test_fused_attention_broadcasted_input(self, device):
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)
+        batch, num_heads, seqlen, head_dim = 32, 16, 128, 32
+        q_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
+        k_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
+        v_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+        attn_mask_shape = (1, seqlen)
+        attn_mask = make_tensor(attn_mask_shape)
+        attn_mask = attn_mask.expand(1, 1, seqlen, seqlen)
+        # test that we do not dispatch to onednn for an unsupported case
+        actual = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.float(), key.float(), value.float(), attn_mask=attn_mask, dropout_p=0.0, is_causal=False)[0]
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
     @parametrize("type", ["dense"])
     @parametrize("is_contiguous", [True, False])
     def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):