Mirror of https://github.com/pytorch/pytorch.git
[Intel GPU] Support SDPA backend selection and priority setting on XPU (#159464)
Currently, SDPA on XPU uses its own `priority_order` instead of the one from the global context, so it does not honor `with sdpa_kernel(order, set_priority=True)`. This PR enables that feature. To make the default `priority_order` from the global context work for XPU, the MATH backend is also moved to the lowest priority; otherwise the `cudnn attention` and `overrideable attention` backends would never be selected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159464
Approved by: https://github.com/guangyey, https://github.com/drisspg

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Co-authored-by: mayuyuace <qiming1.zhang@intel.com>
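For context, a minimal usage sketch of the feature this commit enables (not taken from the PR itself; the device, shapes, and dtype are illustrative and assume a PyTorch build with XPU support):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative tensors; requires an XPU device (e.g. an Intel GPU build).
q = torch.randn(1, 2, 8, 16, device="xpu", dtype=torch.half)
k = torch.randn(1, 2, 8, 16, device="xpu", dtype=torch.half)
v = torch.randn(1, 2, 8, 16, device="xpu", dtype=torch.half)

# With set_priority=True the listed backends act as a priority order, which the
# XPU backend selection now reads from the global context instead of its own list.
with sdpa_kernel([SDPBackend.OVERRIDEABLE, SDPBackend.MATH], set_priority=True):
    out = F.scaled_dot_product_attention(q, k, v)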
Commit db763b1717 (parent 089c4a1ba0), committed by PyTorch MergeBot.
@@ -1,3 +1,4 @@
+#include <ATen/Context.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
 #include <ATen/native/transformers/attention.h>
 #include <ATen/native/transformers/sdp_utils.h>
@@ -49,7 +50,7 @@ bool check_no_grad(sdp::sdp_params const& params, bool debug) {
   return !any_inputs_require_grad || !gradmode_enabled;
 }
 
-bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
+bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
   constexpr auto supported_dtypes = c10::array_of<at::ScalarType>(
       at::kFloat, at::kBFloat16, at::kHalf); // double is not supported
 
@@ -73,6 +74,42 @@ bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
   return sdp::check_tensor_dtype(params, supported_dtypes, debug);
 }
 
+bool can_use_flash_attention(sdp::sdp_params const& params, bool debug) {
+  // Currently, XPU fallbacks flash attention to overrideable
+  return can_use_overrideable_attention(params, debug);
+}
+
+bool can_use_cudnn_attention(sdp::sdp_params const& params, bool debug) {
+  if (debug) {
+    TORCH_WARN("XPU don't support SDPA cudnn attention backend.");
+  }
+  return false;
+}
+
+bool can_use_mem_efficien_attention(sdp::sdp_params const& params, bool debug) {
+  if (debug) {
+    TORCH_WARN("XPU don't support SDPA mem efficient attention backend.");
+  }
+  return false;
+}
+
+bool priority_order_init = false;
+
+std::array<sdp::SDPBackend, sdp::num_backends> priority_order(
+    sdp::sdp_params const& params) {
+  if (!priority_order_init) {
+    priority_order_init = true;
+    const std::vector<int64_t> priority_order = {
+        static_cast<int64_t>(at::SDPBackend::overrideable),
+        static_cast<int64_t>(at::SDPBackend::math),
+        static_cast<int64_t>(at::SDPBackend::flash_attention),
+        static_cast<int64_t>(at::SDPBackend::efficient_attention),
+        static_cast<int64_t>(at::SDPBackend::cudnn_attention)};
+    at::globalContext().setSDPPriorityOrder(priority_order);
+  }
+  return at::globalContext().sDPPriorityOrder();
+}
+
 sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
   // This function defines the priority order of the different sdp backends
   // 1. Flash Attention
@@ -85,20 +122,16 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
   }
 
   // Get ideal kernel ordering
-  const std::array<sdp::SDPBackend, 3> priority_order{
-      sdp::SDPBackend::overrideable,
-      sdp::SDPBackend::math,
-      sdp::SDPBackend::flash_attention,
-  };
+  const auto ordering = priority_order(kernel_params);
 
   // Because TORCHCHECK checks if condition is true we negate debug so that
   // The statements will be printed when debug is true
   bool print_debug = false;
-  for (auto& backend : priority_order) {
+  for (auto& backend : ordering) {
     switch (backend) {
       case sdp::SDPBackend::overrideable:
         if (ctx.userEnabledOverrideableSDP() &&
-            use_overrideable_xpu(kernel_params, print_debug)) {
+            can_use_overrideable_attention(kernel_params, print_debug)) {
           return sdp::SDPBackend::overrideable;
         }
         break;
@@ -109,25 +142,43 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
         break;
       case sdp::SDPBackend::flash_attention:
         if (ctx.userEnabledFlashSDP() &&
-            use_overrideable_xpu(kernel_params, print_debug)) {
-          TORCH_WARN(
-              "Flash Attention is not supported on XPU, falling back to overrideable kernel.");
+            can_use_flash_attention(kernel_params, print_debug)) {
+          TORCH_WARN_ONCE(
+              "SDPA Flash Attention backend is not supported on XPU, falling back to OVERRIDEABLE backend.");
           return sdp::SDPBackend::overrideable;
         }
         break;
+      case sdp::SDPBackend::cudnn_attention:
+        if (ctx.userEnabledCuDNNSDP() &&
+            can_use_cudnn_attention(kernel_params, print_debug)) {
+          TORCH_CHECK(false, "Invalid backend");
+        }
+        break;
+      case sdp::SDPBackend::efficient_attention:
+        if (ctx.userEnabledMemEfficientSDP() &&
+            can_use_mem_efficien_attention(kernel_params, print_debug)) {
+          TORCH_CHECK(false, "Invalid backend");
+        }
+        break;
       default:
         TORCH_CHECK(false, "Invalid backend");
     }
   }
   // If we have gotten to this point then two things have happened:
-  // 1. use_overrideable_xpu did not satisfy the constraints to be ran
+  // 1. can_use_overrideable_attention did not satisfy the constraints to be ran
   // 2. The user has explicitly disabled the math kernel
   // We then re-run the kernel checks with debug enabled to print out the
   // reason why the kernel was not selected
 
   print_debug = true;
-  TORCH_WARN("OneDNN kernel not used because:");
-  use_overrideable_xpu(kernel_params, print_debug);
+  TORCH_WARN("Flash attention kernel not used because:");
+  can_use_flash_attention(kernel_params, print_debug);
+  TORCH_WARN("Overrideable attention kernel not used because:");
+  can_use_overrideable_attention(kernel_params, print_debug);
+  TORCH_WARN("CuDNN attention kernel not used because:");
+  can_use_cudnn_attention(kernel_params, print_debug);
+  TORCH_WARN("Memory Efficient attention kernel not used because:");
+  can_use_mem_efficien_attention(kernel_params, print_debug);
   TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
   return sdp::SDPBackend::error;
 }
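Before the test changes, a short sketch of how the resulting default order can be checked from Python; it mirrors the new test below and uses the private helper `_cur_sdpa_kernel_backends`, so treat it as illustrative rather than a stable API (the "XPU build" expectation is an assumption from this PR's default order):

import torch
from torch.nn.attention import SDPBackend, _cur_sdpa_kernel_backends

# with_priority=True returns the enabled backends in their current priority order.
order = _cur_sdpa_kernel_backends(with_priority=True)
# On an XPU build the default is expected to rank OVERRIDEABLE before MATH before FLASH_ATTENTION.
assert order.index(SDPBackend.OVERRIDEABLE) < order.index(SDPBackend.MATH) < order.index(SDPBackend.FLASH_ATTENTION)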
@@ -4118,9 +4118,6 @@ class TestSDPACudaOnly(NNTestCase):
 class TestSDPAXpuOnly(NNTestCase):
     """ Used to test XPU only functionality of scaled_dot_product_attention
     Mostly migrate from TestSDPACudaOnly in test/test_transformers.py
-
-    Note that as SDPBackend.OVERRIDEABLE is not managed by sdpa_kernel so that
-    math ref has to be called explicitly via torch.ops.aten._scaled_dot_product_attention_math.
     """
 
     @parametrize("type", ["dense"])
@@ -4146,7 +4143,6 @@
         v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
         query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
 
-        # test that we do not dispatch to onednn for an unsupported case
         actual = F.scaled_dot_product_attention(
             query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
 
@@ -4184,7 +4180,6 @@
         v_shape = SdpaShape(batch_size, n_head_kv, kv_size, head_dim)
         query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
 
-        # test that we do not dispatch to onednn for an unsupported case
         actual = F.scaled_dot_product_attention(
             query, key, value, attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True)
 
@@ -4254,18 +4249,6 @@
         for permute_order in permute_orders:
             test_attention(list(permute_order) + [3])
 
-    def test_backends_flash_fallback_to_overrideable(self, device):
-        dtype = torch.bfloat16
-        q_shape = SdpaShape(1, 1, 8, 16)
-        kv_shape = SdpaShape(1, 1, 12, 16)
-        make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
-        make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
-        q, k, v = make_q(), make_kv(), make_kv()
-        warning_str = "Flash Attention is not supported on XPU, falling back to overrideable kernel."
-        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
-            with self.assertWarnsRegex(UserWarning, warning_str):
-                _ = F.scaled_dot_product_attention(q, k, v)
-
     def test_backends_set_to_math(self, device):
         dtype = torch.bfloat16
         q_shape = SdpaShape(1, 1, 8, 16)
@@ -4278,6 +4261,17 @@
         self.assertFalse(torch._C._get_overrideable_sdp_enabled())
         _ = F.scaled_dot_product_attention(q, k, v)
 
+    def test_default_priority_order(self, device):
+        # The default priority order of xpu is overrideable, math, flash, efficient, cudnn
+        # For xpu backend, we need to make sure that overrideable > math > flash
+        from torch.nn.attention import _cur_sdpa_kernel_backends
+        default_priority = _cur_sdpa_kernel_backends(with_priority=True)
+        flash_index = default_priority.index(SDPBackend.FLASH_ATTENTION)
+        overrideable_index = default_priority.index(SDPBackend.OVERRIDEABLE)
+        math_index = default_priority.index(SDPBackend.MATH)
+        self.assertTrue(overrideable_index < math_index < flash_index,
+                        f"Expected overrideable < math < flash, got {overrideable_index}, {math_index}, {flash_index}")
+
     def test_scaled_dot_product_attention_fused_kernels_safe_softmax(self, device):
         dtype = torch.bfloat16
         make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False)
@@ -39,6 +39,7 @@ r"""An enum-like class that contains the different backends for scaled dot produ
     - FLASH_ATTENTION: The flash attention backend for scaled dot product attention.
     - EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention.
     - CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention.
+    - OVERRIDEABLE: The overridable backend for extension.
 
     See :func:`torch.nn.attention.sdpa_kernel` for more details.
 
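As a usage note for the newly documented OVERRIDEABLE entry, a minimal sketch of pinning SDPA to that backend (illustrative shapes; assumes an XPU device, where this selects the oneDNN-backed path):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Restrict scaled_dot_product_attention to the OVERRIDEABLE backend only.
q = k = v = torch.randn(1, 1, 8, 16, device="xpu", dtype=torch.bfloat16)
with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
    out = F.scaled_dot_product_attention(q, k, v)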