[Intel GPU] Support SDPA backend selection and priority setting on XPU (#159464)

Currently, SDPA on XPU uses its own `priority_order` instead of the one from the global context, so it does not support `with sdpa_kernel(order, set_priority=True)`.

This PR enables that feature. To make the default `priority_order` from the global context work for XPU, I also move the MATH backend to the lowest priority; otherwise `cudnn attention` and `overrideable attention` would never be selected.
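For illustration, a minimal sketch of the Python-facing behavior this enables (assuming an XPU-enabled build and device; the shapes, dtype, and the particular order are arbitrary):

```python
# Hedged sketch: pass a priority order to sdpa_kernel with set_priority=True
# so XPU honors it through the global context (the change in this PR).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.rand(1, 2, 8, 16, device="xpu", dtype=torch.bfloat16)

order = [SDPBackend.MATH, SDPBackend.OVERRIDEABLE, SDPBackend.FLASH_ATTENTION]
with sdpa_kernel(order, set_priority=True):
    # With set_priority=True the list is treated as a priority order, so the
    # math reference kernel is tried before the overrideable (oneDNN) kernel.
    out = F.scaled_dot_product_attention(q, k, v)
```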

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159464
Approved by: https://github.com/guangyey, https://github.com/drisspg

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Co-authored-by: mayuyuace <qiming1.zhang@intel.com>
fengqing.lu
2025-08-14 08:55:28 +00:00
committed by PyTorch MergeBot
parent 089c4a1ba0
commit db763b1717
3 changed files with 77 additions and 31 deletions


@@ -1,3 +1,4 @@
#include <ATen/Context.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils.h>
@@ -49,7 +50,7 @@ bool check_no_grad(sdp::sdp_params const& params, bool debug) {
return !any_inputs_require_grad || !gradmode_enabled;
}
bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
constexpr auto supported_dtypes = c10::array_of<at::ScalarType>(
at::kFloat, at::kBFloat16, at::kHalf); // double is not supported
@@ -73,6 +74,42 @@ bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
return sdp::check_tensor_dtype(params, supported_dtypes, debug);
}
bool can_use_flash_attention(sdp::sdp_params const& params, bool debug) {
// Currently, XPU falls back from flash attention to the overrideable backend
return can_use_overrideable_attention(params, debug);
}
bool can_use_cudnn_attention(sdp::sdp_params const& params, bool debug) {
if (debug) {
TORCH_WARN("XPU don't support SDPA cudnn attention backend.");
}
return false;
}
bool can_use_mem_efficient_attention(sdp::sdp_params const& params, bool debug) {
if (debug) {
TORCH_WARN("XPU doesn't support the SDPA mem efficient attention backend.");
}
return false;
}
bool priority_order_init = false;
std::array<sdp::SDPBackend, sdp::num_backends> priority_order(
sdp::sdp_params const& params) {
if (!priority_order_init) {
priority_order_init = true;
const std::vector<int64_t> priority_order = {
static_cast<int64_t>(at::SDPBackend::overrideable),
static_cast<int64_t>(at::SDPBackend::math),
static_cast<int64_t>(at::SDPBackend::flash_attention),
static_cast<int64_t>(at::SDPBackend::efficient_attention),
static_cast<int64_t>(at::SDPBackend::cudnn_attention)};
at::globalContext().setSDPPriorityOrder(priority_order);
}
return at::globalContext().sDPPriorityOrder();
}
sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
// This function defines the priority order of the different sdp backends
// 1. Flash Attention
@@ -85,20 +122,16 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
}
// Get ideal kernel ordering
const std::array<sdp::SDPBackend, 3> priority_order{
sdp::SDPBackend::overrideable,
sdp::SDPBackend::math,
sdp::SDPBackend::flash_attention,
};
const auto ordering = priority_order(kernel_params);
// Because TORCH_CHECK checks if the condition is true, we negate debug so that
// the statements will be printed when debug is true
bool print_debug = false;
for (auto& backend : priority_order) {
for (auto& backend : ordering) {
switch (backend) {
case sdp::SDPBackend::overrideable:
if (ctx.userEnabledOverrideableSDP() &&
use_overrideable_xpu(kernel_params, print_debug)) {
can_use_overrideable_attention(kernel_params, print_debug)) {
return sdp::SDPBackend::overrideable;
}
break;
@@ -109,25 +142,43 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
break;
case sdp::SDPBackend::flash_attention:
if (ctx.userEnabledFlashSDP() &&
use_overrideable_xpu(kernel_params, print_debug)) {
TORCH_WARN(
"Flash Attention is not supported on XPU, falling back to overrideable kernel.");
can_use_flash_attention(kernel_params, print_debug)) {
TORCH_WARN_ONCE(
"SDPA Flash Attention backend is not supported on XPU, falling back to OVERRIDEABLE backend.");
return sdp::SDPBackend::overrideable;
}
break;
case sdp::SDPBackend::cudnn_attention:
if (ctx.userEnabledCuDNNSDP() &&
can_use_cudnn_attention(kernel_params, print_debug)) {
TORCH_CHECK(false, "Invalid backend");
}
break;
case sdp::SDPBackend::efficient_attention:
if (ctx.userEnabledMemEfficientSDP() &&
can_use_mem_efficient_attention(kernel_params, print_debug)) {
TORCH_CHECK(false, "Invalid backend");
}
break;
default:
TORCH_CHECK(false, "Invalid backend");
}
}
// If we have gotten to this point then two things have happened:
// 1. use_overrideable_xpu did not satisfy the constraints to be ran
// 1. can_use_overrideable_attention did not satisfy the constraints to be run
// 2. The user has explicitly disabled the math kernel
// We then re-run the kernel checks with debug enabled to print out the
// reason why the kernel was not selected
print_debug = true;
TORCH_WARN("OneDNN kernel not used because:");
use_overrideable_xpu(kernel_params, print_debug);
TORCH_WARN("Flash attention kernel not used because:");
can_use_flash_attention(kernel_params, print_debug);
TORCH_WARN("Overrideable attention kernel not used because:");
can_use_overrideable_attention(kernel_params, print_debug);
TORCH_WARN("CuDNN attention kernel not used because:");
can_use_cudnn_attention(kernel_params, print_debug);
TORCH_WARN("Memory Efficient attention kernel not used because:");
can_use_mem_efficient_attention(kernel_params, print_debug);
TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
return sdp::SDPBackend::error;
}
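The flash branch above keeps `scaled_dot_product_attention` usable when a user pins the FLASH_ATTENTION backend on XPU: the call still runs, served by the OVERRIDEABLE (oneDNN) kernel after a one-time warning. A hedged sketch of that behavior, assuming an XPU device:

```python
# Sketch: requesting FLASH_ATTENTION on XPU still executes, dispatching to the
# OVERRIDEABLE kernel and emitting the TORCH_WARN_ONCE message shown above.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.rand(1, 1, 8, 16, device="xpu", dtype=torch.bfloat16)
k = v = torch.rand(1, 1, 12, 16, device="xpu", dtype=torch.bfloat16)

with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)  # runs on the overrideable path
```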


@@ -4118,9 +4118,6 @@ class TestSDPACudaOnly(NNTestCase):
class TestSDPAXpuOnly(NNTestCase):
""" Used to test XPU only functionality of scaled_dot_product_attention
Mostly migrated from TestSDPACudaOnly in test/test_transformers.py
Note that as SDPBackend.OVERRIDEABLE is not managed by sdpa_kernel so that
math ref has to be called explicitly via torch.ops.aten._scaled_dot_product_attention_math.
"""
@parametrize("type", ["dense"])
@@ -4146,7 +4143,6 @@ class TestSDPAXpuOnly(NNTestCase):
v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
# test that we do not dispatch to onednn for an unsupported case
actual = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
@@ -4184,7 +4180,6 @@ class TestSDPAXpuOnly(NNTestCase):
v_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim)
query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
# test that we do not dispatch to onednn for an unsupported case
actual = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True)
@@ -4254,18 +4249,6 @@ class TestSDPAXpuOnly(NNTestCase):
for permute_order in permute_orders:
test_attention(list(permute_order) + [3])
def test_backends_flash_fallback_to_overrideable(self, device):
dtype = torch.bfloat16
q_shape = SdpaShape(1, 1, 8, 16)
kv_shape = SdpaShape(1, 1, 12, 16)
make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
q, k, v = make_q(), make_kv(), make_kv()
warning_str = "Flash Attention is not supported on XPU, falling back to overrideable kernel."
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, warning_str):
_ = F.scaled_dot_product_attention(q, k, v)
def test_backends_set_to_math(self, device):
dtype = torch.bfloat16
q_shape = SdpaShape(1, 1, 8, 16)
@@ -4278,6 +4261,17 @@ class TestSDPAXpuOnly(NNTestCase):
self.assertFalse(torch._C._get_overrideable_sdp_enabled())
_ = F.scaled_dot_product_attention(q, k, v)
def test_default_priority_order(self, device):
# The default priority order of xpu is overrideable, math, flash, efficient, cudnn
# For xpu backend, we need to make sure that overrideable > math > flash
from torch.nn.attention import _cur_sdpa_kernel_backends
default_priority = _cur_sdpa_kernel_backends(with_priority=True)
flash_index = default_priority.index(SDPBackend.FLASH_ATTENTION)
overrideable_index = default_priority.index(SDPBackend.OVERRIDEABLE)
math_index = default_priority.index(SDPBackend.MATH)
self.assertTrue(overrideable_index < math_index < flash_index,
f"Expected overrideable < math < flash, got {overrideable_index}, {math_index}, {flash_index}")
def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device):
dtype = torch.bfloat16
make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)


@@ -39,6 +39,7 @@ r"""An enum-like class that contains the different backends for scaled dot produ
- FLASH_ATTENTION: The flash attention backend for scaled dot product attention.
- EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention.
- CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention.
- OVERRIDEABLE: The overridable backend for extensions.
See :func:`torch.nn.attention.sdpa_kernel` for more details.
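As a usage note (not part of this diff), pinning SDPA to the new OVERRIDEABLE entry works like the other backends; a sketch assuming an XPU device, where the overrideable backend maps to the oneDNN kernel:

```python
# Sketch: restrict SDPA to the OVERRIDEABLE backend, with MATH as an explicit
# fallback in case the overrideable kernel rejects the inputs.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.rand(1, 2, 8, 16, device="xpu", dtype=torch.half)
with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)
```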