Compare commits

...

2 Commits

Author SHA1 Message Date
387beb48d5 fallback to math 2025-11-14 11:11:44 +00:00
a8f3556849 fallbacks memory efficient attention to overridable attention on XPU 2025-11-14 06:15:01 +00:00

View File

@ -86,11 +86,8 @@ bool can_use_cudnn_attention(sdp::sdp_params const& params, bool debug) {
return false;
}
bool can_use_mem_efficien_attention(sdp::sdp_params const& params, bool debug) {
if (debug) {
TORCH_WARN("XPU don't support SDPA mem efficient attention backend.");
}
return false;
bool can_use_mem_efficient_attention(sdp::sdp_params const& params, bool debug) {
return true;
}
bool priority_order_init = false;
@ -117,7 +114,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
auto& ctx = at::globalContext();
// use overridable linked to onednn as overridable implementation
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledOverrideableSDP() &&
!ctx.userEnabledFlashSDP()) {
!ctx.userEnabledFlashSDP() && !ctx.userEnabledMemEfficientSDP()) {
return sdp::SDPBackend::error;
}
@ -156,8 +153,10 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
break;
case sdp::SDPBackend::efficient_attention:
if (ctx.userEnabledMemEfficientSDP() &&
can_use_mem_efficien_attention(kernel_params, print_debug)) {
TORCH_CHECK(false, "Invalid backend");
can_use_mem_efficient_attention(kernel_params, print_debug)) {
TORCH_WARN_ONCE(
"SDPA Memory Efficient Attention backend is not supported on XPU, falling back to math backend.");
return sdp::SDPBackend::math;
}
break;
default:
@ -178,7 +177,7 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
TORCH_WARN("CuDNN attention kernel not used because:");
can_use_cudnn_attention(kernel_params, print_debug);
TORCH_WARN("Memory Efficient attention kernel not used because:");
can_use_mem_efficien_attention(kernel_params, print_debug);
can_use_mem_efficient_attention(kernel_params, print_debug);
TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
return sdp::SDPBackend::error;
}