Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
SDPA fix memory efficient attention for large batch dim (#154029)
Fixes #146704
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154029
Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 3b38989b5f
Commit: e313152a33
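The first two hunks below change the CUDA C++ implementation of memory-efficient scaled dot product attention; the final hunk adds regression tests. As a quick illustration of the scenario being fixed, here is a minimal Python sketch (not part of the commit) that drives a batch dimension of 2**16 through the memory-efficient backend, mirroring the new test at the end of this diff; the CPU reference comparison and tolerances follow that test, and a CUDA device is assumed.

# Minimal sketch (not part of the commit): batch dim of 2**16 routed through
# the memory-efficient SDPA backend, mirroring the regression test added below.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)
k = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)
v = torch.rand(2**16, 2, 2, 8, device="cuda", dtype=torch.float16)

# Restrict SDPA to the memory-efficient backend; after this fix the oversized
# batch is split into chunks of at most 65535 along dim 0 before reaching the kernel.
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)

# Compare against a CPU reference (tolerances follow the new test).
out_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu())
torch.testing.assert_close(out.cpu(), out_ref, atol=1e-3, rtol=1e-4)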
@@ -46,6 +46,7 @@
 #include <ATen/ops/_triton_multi_head_attention_native.h>
 #include <ATen/ops/_triton_scaled_dot_attention.h>
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/linear.h>
 #include <ATen/ops/narrow_native.h>
@@ -963,33 +964,98 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     std::optional<double> scale) {
   // Used for tracking usage statistics
   C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
   // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
   // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
   // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
-  Tensor q_t = query.transpose(1, 2);
-  Tensor k_t = key.transpose(1, 2);
-  Tensor v_t = value.transpose(1, 2);
-
-  sdp::CustomMaskType custom_mask_type = is_causal
-      ? sdp::CustomMaskType::CausalFromTopLeft
-      : sdp::CustomMaskType::NoCustomMask;
-
-  auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficient_attention_forward(
-      q_t,
-      k_t,
-      v_t,
-      attn_bias,
-      std::nullopt,
-      std::nullopt,
-      std::nullopt,
-      std::nullopt,
-      dropout_p,
-      static_cast<int64_t>(custom_mask_type),
-      compute_log_sumexp,
-      scale);
-
-  attention = attention.transpose(1, 2);
-  return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);
+
+  sdp::CustomMaskType custom_mask_type = is_causal
+      ? sdp::CustomMaskType::CausalFromTopLeft
+      : sdp::CustomMaskType::NoCustomMask;
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(!compute_log_sumexp && (dropout_p == 0.0),
+                "Efficient attention cannot produce valid seed, logsumexp and offset outputs when "
+                "the batch size exceeds (", MAX_BATCH_SIZE, ").");
+  }
+  auto process_chunk = [&](const Tensor& q_chunk,
+                           const Tensor& k_chunk,
+                           const Tensor& v_chunk,
+                           const std::optional<Tensor>& bias_chunk)
+      -> std::tuple<Tensor, Tensor, Tensor, Tensor> {
+    Tensor q_t = q_chunk.transpose(1, 2);
+    Tensor k_t = k_chunk.transpose(1, 2);
+    Tensor v_t = v_chunk.transpose(1, 2);
+
+    auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] =
+        at::_efficient_attention_forward(
+            q_t,
+            k_t,
+            v_t,
+            bias_chunk,
+            std::nullopt,
+            std::nullopt,
+            std::nullopt,
+            std::nullopt,
+            dropout_p,
+            static_cast<int64_t>(custom_mask_type),
+            compute_log_sumexp,
+            scale);
+    attention = attention.transpose(1, 2);
+
+    return std::make_tuple(std::move(attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  };
+
+  // when bs is larger than allowed maximum, process in chunks
+  if (batch_size > MAX_BATCH_SIZE) {
+    int64_t start = 0;
+    int64_t end = std::min(start + MAX_BATCH_SIZE, batch_size);
+
+    Tensor query_chunk = query.slice(0, start, end);
+    Tensor key_chunk = key.slice(0, start, end);
+    Tensor value_chunk = value.slice(0, start, end);
+    std::optional<Tensor> bias_chunk;
+    if (attn_bias.has_value()) {
+      bias_chunk = attn_bias.value().slice(0, start, end);
+    }
+    auto [attn, log_sumexp, seed, offset] =
+        process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+    int dim = attn.dim();
+    std::vector<int64_t> sizes;
+    sizes.reserve(dim);
+    sizes.push_back(batch_size);
+    for (int i = 1; i < dim; i++) {
+      sizes.push_back(attn.size(i));
+    }
+    Tensor final_attention = at::empty_strided(sizes, attn.strides(), attn.options());
+    final_attention.slice(0, start, end).copy_(attn);
+
+    for (start = end; start < batch_size; start += MAX_BATCH_SIZE) {
+      end = std::min(start + MAX_BATCH_SIZE, batch_size);
+      query_chunk = query.slice(0, start, end);
+      key_chunk = key.slice(0, start, end);
+      value_chunk = value.slice(0, start, end);
+      if (attn_bias.has_value()) {
+        bias_chunk = attn_bias.value().slice(0, start, end);
+      } else {
+        bias_chunk.reset();
+      }
+
+      auto [chunk_attn, chunk_log_sumexp, chunk_seed, chunk_offset] =
+          process_chunk(query_chunk, key_chunk, value_chunk, bias_chunk);
+      final_attention.slice(0, start, end).copy_(chunk_attn);
+    }
+
+    return std::make_tuple(std::move(final_attention),
+                           std::move(log_sumexp),
+                           std::move(seed),
+                           std::move(offset));
+  }
+  // when bs is within the allowed size, no need to chunk it
+  else {
+    return process_chunk(query, key, value, attn_bias);
+  }
 }
 
 int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
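The chunking scheme above can be summarized independently of the ATen internals: slice the batch into pieces of at most 65535, run the existing single-launch path on each piece, and copy each result into a preallocated full-size output. The 65535 cap presumably reflects the CUDA limit on the grid y/z dimensions a single kernel launch can address. The following Python sketch restates the idea; chunked_attention and run_single_chunk are hypothetical names, and scaled_dot_product_attention stands in for the real at::_efficient_attention_forward call.

# Illustrative Python restatement of the batch-chunking scheme above.
# chunked_attention and run_single_chunk are hypothetical stand-ins, not PyTorch APIs.
import torch
import torch.nn.functional as F

MAX_BATCH_SIZE = (1 << 16) - 1  # 65535, the same cap as the C++ code above

def run_single_chunk(q, k, v, bias):
    # Stand-in for the existing single-launch kernel (at::_efficient_attention_forward).
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

def chunked_attention(query, key, value, bias=None):
    batch = query.size(0)
    if batch <= MAX_BATCH_SIZE:
        # Small enough: no need to chunk.
        return run_single_chunk(query, key, value, bias)
    out = None
    for start in range(0, batch, MAX_BATCH_SIZE):
        end = min(start + MAX_BATCH_SIZE, batch)
        bias_chunk = bias[start:end] if bias is not None else None
        chunk = run_single_chunk(query[start:end], key[start:end], value[start:end], bias_chunk)
        if out is None:
            # Allocate the full-size output once, based on the first chunk's shape.
            out = chunk.new_empty((batch,) + tuple(chunk.shape[1:]))
        out[start:end].copy_(chunk)
    return out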
@@ -1898,6 +1898,26 @@ class TestSDPAFailureModes(NNTestCase):
         self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
             q, k, v, None, 0.0, is_causal=True))
 
+    @onlyCUDA
+    def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
+        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        with sdpa_kernel(backends=SDPBackend.EFFICIENT_ATTENTION):
+            out = F.scaled_dot_product_attention(query, key, value)
+        out_cpu = F.scaled_dot_product_attention(query.cpu(), key.cpu(), value.cpu())
+        self.assertEqual(out, out_cpu, atol=1e-3, rtol=1e-4)
+
+    @onlyCUDA
+    def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
+        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        value = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
+        error_str = (r"Efficient attention cannot produce valid seed, "
+                     r"logsumexp and offset outputs when the batch size exceeds \(65535\)\.")
+        with self.assertRaisesRegex(RuntimeError, error_str):
+            torch._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True)
+
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256