SDPA: fix memory-efficient attention for large batch dim (#154029)

Fixes #146704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154029
Approved by: https://github.com/ngimel
Isalia20
2025-05-28 16:53:53 +00:00
committed by PyTorch MergeBot
parent 3b38989b5f
commit e313152a33
2 changed files with 110 additions and 24 deletions
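
For context, the case this commit addresses can be reproduced with a small script along the following lines (an illustrative sketch only, not part of the commit; it assumes a CUDA device, and the shapes mirror the new test added below). Batches larger than 65535 previously failed on the memory-efficient backend (#146704) and are now processed in chunks of at most 65535.

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    # Batch dimension of 2**16 exceeds MAX_BATCH_SIZE (65535) for a single
    # kernel launch; after this change it is handled by batch chunking.
    query = torch.rand([2**16, 2, 2, 8], device="cuda", dtype=torch.float16)
    key = torch.rand([2**16, 2, 2, 8], device="cuda", dtype=torch.float16)
    value = torch.rand([2**16, 2, 2, 8], device="cuda", dtype=torch.float16)

    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(query, key, value)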


@@ -46,6 +46,7 @@
 #include <ATen/ops/_triton_multi_head_attention_native.h>
 #include <ATen/ops/_triton_scaled_dot_attention.h>
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/linear.h>
 #include <ATen/ops/narrow_native.h>
@@ -963,33 +964,98 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     std::optional<double> scale) {
   // Used for tracking usage statistics
   C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
-  // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
-  // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
-  // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
-  Tensor q_t = query.transpose(1, 2);
-  Tensor k_t = key.transpose(1, 2);
-  Tensor v_t = value.transpose(1, 2);
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);
-  sdp::CustomMaskType custom_mask_type = is_causal
-      ? sdp::CustomMaskType::CausalFromTopLeft
-      : sdp::CustomMaskType::NoCustomMask;
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
+        "Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
+        "the batch size exceeds (", MAX_BATCH_SIZE, ").");
+  }
+  auto process_chunk = [&](const Tensor& q_chunk,
+                           const Tensor& k_chunk,
+                           const Tensor& v_chunk,
+                           const std::optional<Tensor>& bias_chunk)
+      -> std::tuple<Tensor, Tensor, Tensor, Tensor> {
+    Tensor q_t = q_chunk.transpose(1, 2);
+    Tensor k_t = k_chunk.transpose(1, 2);
+    Tensor v_t = v_chunk.transpose(1, 2);
-  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
-      q_t,
-      k_t,
-      v_t,
-      attn_bias,
-      std::nullopt,
-      std::nullopt,
-      std::nullopt,
-      std::nullopt,
-      dropout_p,
-      static_cast<int64_t>(custom_mask_type),
-      compute_log_sumexp,
-      scale);
+    sdp::CustomMaskType custom_mask_type = is_causal
+        ? sdp::CustomMaskType::CausalFromTopLeft
+        : sdp::CustomMaskType::NoCustomMask;
-  attention = attention.transpose(1, 2);
-  return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
+    auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] =
+        at::_efficient_attention_forward(
+            q_t,
+            k_t,
+            v_t,
+            bias_chunk,
+            std::nullopt,
+            std::nullopt,
+            std::nullopt,
+            std::nullopt,
+            dropout_p,
+            static_cast<int64_t>(custom_mask_type),
+            compute_log_sumexp,
+            scale);
+    attention = attention.transpose(1, 2);
+    return std::make_tuple(std::move(attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  };
+  // when bs is larger than allowed maximum, process in chunks
+  if (batch_size > MAX_BATCH_SIZE) {
+    int64_t start = 0;
+    int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
+    Tensor query_chunk = query.slice(0, start, end);
+    Tensor key_chunk = key.slice(0, start, end);
+    Tensor value_chunk = value.slice(0, start, end);
+    std::optional<Tensor> bias_chunk;
+    if (attn_bias.has_value()) {
+      bias_chunk = attn_bias.value().slice(0, start, end);
+    }
+    auto [attn, log_sumexp, seed, offset] =
+        process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+    int dim = attn.dim();
+    std::vector<int64_t> sizes;
+    sizes.reserve(dim);
+    sizes.push_back(batch_size);
+    for (int i = 1; i < dim; i++) {
+      sizes.push_back(attn.size(i));
+    }
+    Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
+    final_attention.slice(0, start, end).copy_(attn);
+    for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
+      end = std::min(start + MAX_BATCH_SIZE, batch_size);
+      query_chunk = query.slice(0, start, end);
+      key_chunk = key.slice(0, start, end);
+      value_chunk = value.slice(0, start, end);
+      if (attn_bias.has_value()) {
+        bias_chunk = attn_bias.value().slice(0, start, end);
+      } else {
+        bias_chunk.reset();
+      }
+      auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
+          process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+      final_attention.slice(0, start, end).copy_(chunk_attn);
+    }
+    return std::make_tuple(std::move(final_attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  }
+  // when bs is within the allowed size, no need to chunk it
+  else {
+    return process_chunk(query, key, value, attn_bias);
+  }
 }
int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
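
For readers skimming the diff, the batch-chunking logic added above can be summarized with the following Python-level sketch (illustrative only, not the actual implementation; attention_chunk is a hypothetical stand-in for the per-chunk _efficient_attention_forward call, and the log_sumexp/seed/offset outputs are omitted because the chunked path requires dropout_p == 0.0 and compute_log_sumexp == False):

    import torch

    MAX_BATCH_SIZE = 2**16 - 1  # 65535, the largest batch a single launch handles

    def chunked_attention(query, key, value, attn_bias, attention_chunk):
        # attention_chunk(q, k, v, bias) -> attention output for one batch chunk
        batch_size = query.size(0)
        if batch_size <= MAX_BATCH_SIZE:
            # Within the limit: a single call, no chunking needed.
            return attention_chunk(query, key, value, attn_bias)
        final_attention = None
        for start in range(0, batch_size, MAX_BATCH_SIZE):
            end = min(start + MAX_BATCH_SIZE, batch_size)
            bias = attn_bias[start:end] if attn_bias is not None else None
            chunk = attention_chunk(query[start:end], key[start:end], value[start:end], bias)
            if final_attention is None:
                # Allocate the full-batch output once, reusing the first chunk's strides.
                sizes = (batch_size,) + tuple(chunk.shape[1:])
                final_attention = torch.empty_strided(
                    sizes, chunk.stride(), dtype=chunk.dtype, device=chunk.device)
            final_attention[start:end].copy_(chunk)
        return final_attention

The C++ version above additionally returns the first chunk's log_sumexp, seed, and offset alongside the stitched attention output.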


@@ -1898,6 +1898,26 @@ class TestSDPAFailureModes(NNTestCase):
         self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
             q, k, v, None, 0.0, is_causal=True))
 
+    @onlyCUDA
+    def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
+        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION):
+            out = F.scaled_dot_product_attention(query, key, value)
+        out_cpu = F.scaled_dot_product_attention(query.cpu(), key.cpu(), value.cpu())
+        self.assertEqual(out, out_cpu, atol=1e-3, rtol=1e-4)
+
+    @onlyCUDA
+    def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
+        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        error_str = (r"Efficient attention cannot produce valid seed, "
+                     r"logsumexp and offset outputs when the batch size exceeds \(65535\)\.")
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            torch._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True)
+
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256