mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Intel GPU] undo broadcast on zero stride tensor for SDPA (#151976)
Fixes https://github.com/pytorch/pytorch/issues/152290. The **hubert** model uses aten::expand to build its attention mask by broadcasting. PyTorch represents a broadcast dimension with strides[d]=0, which oneDNN does not support. This PR handles that scenario by undoing the broadcast on such tensors before their sizes and strides are passed to oneDNN for SDPA. Pull Request resolved: https://github.com/pytorch/pytorch/pull/151976 Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
1f48bab377
commit
de92296bbb
@ -49,16 +49,39 @@ struct SDPALogicalParams {
|
|||||||
"Only FP16/BF16/FP32 datatypes are currently supported");
|
"Only FP16/BF16/FP32 datatypes are currently supported");
|
||||||
const dims scalar_shape = {1};
|
const dims scalar_shape = {1};
|
||||||
std::vector<logical_tensor> inputLogicalTensors;
|
std::vector<logical_tensor> inputLogicalTensors;
|
||||||
|
|
||||||
|
at::Tensor reshaped_query = query_;
|
||||||
|
at::Tensor reshaped_key = key_;
|
||||||
|
at::Tensor reshaped_value = value_;
|
||||||
|
at::Tensor reshaped_output = output_;
|
||||||
|
at::Tensor reshaped_attn_mask = attn_mask_.value_or(at::Tensor());
|
||||||
|
if (at::native::onednn::is_broadcast(reshaped_query)) {
|
||||||
|
at::native::onednn::undo_broadcast(reshaped_query);
|
||||||
|
}
|
||||||
|
if (at::native::onednn::is_broadcast(reshaped_key)) {
|
||||||
|
at::native::onednn::undo_broadcast(reshaped_key);
|
||||||
|
}
|
||||||
|
if (at::native::onednn::is_broadcast(reshaped_value)) {
|
||||||
|
at::native::onednn::undo_broadcast(reshaped_value);
|
||||||
|
}
|
||||||
|
if (at::native::onednn::is_broadcast(reshaped_output)) {
|
||||||
|
at::native::onednn::undo_broadcast(reshaped_output);
|
||||||
|
}
|
||||||
|
if (attn_mask_.has_value() &&
|
||||||
|
at::native::onednn::is_broadcast(reshaped_attn_mask)) {
|
||||||
|
at::native::onednn::undo_broadcast(reshaped_attn_mask);
|
||||||
|
}
|
||||||
|
|
||||||
query = {
|
query = {
|
||||||
static_cast<size_t>(TensorID::query),
|
static_cast<size_t>(TensorID::query),
|
||||||
dtype,
|
dtype,
|
||||||
query_.sizes().vec(),
|
reshaped_query.sizes().vec(),
|
||||||
query_.strides().vec()};
|
reshaped_query.strides().vec()};
|
||||||
key = {
|
key = {
|
||||||
static_cast<size_t>(TensorID::key),
|
static_cast<size_t>(TensorID::key),
|
||||||
dtype,
|
dtype,
|
||||||
key_.sizes().vec(),
|
reshaped_key.sizes().vec(),
|
||||||
key_.strides().vec()};
|
reshaped_key.strides().vec()};
|
||||||
scale = {
|
scale = {
|
||||||
static_cast<size_t>(TensorID::scale),
|
static_cast<size_t>(TensorID::scale),
|
||||||
dtype,
|
dtype,
|
||||||
@ -77,19 +100,19 @@ struct SDPALogicalParams {
|
|||||||
attn_mask = {
|
attn_mask = {
|
||||||
static_cast<size_t>(TensorID::attn_mask),
|
static_cast<size_t>(TensorID::attn_mask),
|
||||||
dtype,
|
dtype,
|
||||||
attn_mask_->sizes().vec(),
|
reshaped_attn_mask.sizes().vec(),
|
||||||
attn_mask_->strides().vec()};
|
reshaped_attn_mask.strides().vec()};
|
||||||
}
|
}
|
||||||
value = {
|
value = {
|
||||||
static_cast<size_t>(TensorID::value),
|
static_cast<size_t>(TensorID::value),
|
||||||
dtype,
|
dtype,
|
||||||
value_.sizes().vec(),
|
reshaped_value.sizes().vec(),
|
||||||
value_.strides().vec()};
|
reshaped_value.strides().vec()};
|
||||||
output = {
|
output = {
|
||||||
static_cast<size_t>(TensorID::output),
|
static_cast<size_t>(TensorID::output),
|
||||||
dtype,
|
dtype,
|
||||||
output_.sizes().vec(),
|
reshaped_output.sizes().vec(),
|
||||||
output_.strides().vec()};
|
reshaped_output.strides().vec()};
|
||||||
}
|
}
|
||||||
std::vector<logical_tensor> get_input() const {
|
std::vector<logical_tensor> get_input() const {
|
||||||
std::vector<logical_tensor> input = {query, key, scale};
|
std::vector<logical_tensor> input = {query, key, scale};
|
||||||
|
@ -4006,6 +4006,28 @@ class TestSDPAXpuOnly(NNTestCase):
|
|||||||
with self.assertRaisesRegex(RuntimeError, "No available kernel."):
|
with self.assertRaisesRegex(RuntimeError, "No available kernel."):
|
||||||
_ = F.scaled_dot_product_attention(q, k, v)
|
_ = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
|
||||||
|
def test_fused_attention_broadcasted_input(self, device):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)
|
||||||
|
batch, num_heads, seqlen, head_dim = 32, 16, 128, 32
|
||||||
|
q_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
|
||||||
|
k_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
|
||||||
|
v_shape = SdpaShape(batch, num_heads, seqlen, head_dim)
|
||||||
|
query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
|
||||||
|
|
||||||
|
attn_mask_shape = (1, seqlen)
|
||||||
|
attn_mask = make_tensor(attn_mask_shape)
|
||||||
|
attn_mask = attn_mask.expand(1, 1, seqlen, seqlen)
|
||||||
|
|
||||||
|
# test that an attn_mask broadcast via expand() still matches the math reference
|
||||||
|
actual = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||||
|
|
||||||
|
math_ref = torch.ops.aten._scaled_dot_product_attention_math(
|
||||||
|
query.float(), key.float(), value.float(), attn_mask=attn_mask, dropout_p=0.0, is_causal=False)[0]
|
||||||
|
|
||||||
|
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
|
||||||
|
|
||||||
@parametrize("type", ["dense"])
|
@parametrize("type", ["dense"])
|
||||||
@parametrize("is_contiguous", [True, False])
|
@parametrize("is_contiguous", [True, False])
|
||||||
def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
|
def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
|
||||||
|
Reference in New Issue
Block a user