Compare commits

...

1 Commits

Author SHA1 Message Date
e3e06d9e4d support batch size=0 for sdpa 2025-10-28 07:49:52 -07:00
2 changed files with 40 additions and 5 deletions

View File

@ -22,6 +22,7 @@
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/reshape.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sum.h>
@ -42,7 +43,6 @@ C10_DIAGNOSTIC_POP()
#include <static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
#include <c10/util/Exception.h>
namespace FLASH_NAMESPACE {
@ -417,6 +417,26 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
if (batch_size == 0) {
auto opts = q.options();
at::Tensor out = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
at::Tensor q_padded = at::empty({0, seqlen_q, num_heads, head_size_og}, opts);
at::Tensor k_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
at::Tensor v_padded = at::empty({0, seqlen_k, num_heads_k, head_size_og}, opts);
at::Tensor softmax_lse = at::empty({0, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor p = at::empty({0}, opts);
if (return_softmax) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
p = at::empty({0, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
}
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), _unused, std::move(p)};
}
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@ -547,7 +567,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
return {std::move(out), std::move(q_padded), std::move(k_padded), std::move(v_padded), std::move(softmax_lse), std::move(rng_state), std::move(_unused), std::move(p)};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@ -852,7 +872,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
const auto sizes = q.sizes();
@ -863,6 +882,20 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const int head_size = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
if (batch_size == 0) {
auto opts = q.options();
at::Tensor dq = at::empty_like(q);
at::Tensor dk = at::empty_like(k);
at::Tensor dv = at::empty_like(v);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
at::Tensor softmax_d = at::empty({0, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
return {dq, dk, dv, softmax_d};
}
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");

View File

@ -1107,6 +1107,7 @@ class TestTransformers(NNTestCase):
)[0]
@tf32_on_and_off(0.003)
@parametrize("batch_size", [0, 5])
@parametrize("input_dim,attn_mask_dim,is_causal",
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
@ -1116,7 +1117,7 @@ class TestTransformers(NNTestCase):
if attn_dim is not None else "no_attn_mask")))
@parametrize("dropout_p", [0.0, 0.2, 0.5])
@sdpa_kernel(backends=[SDPBackend.MATH])
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
def test_scaled_dot_product_attention(self, device, batch_size, input_dim, attn_mask_dim, is_causal, dropout_p):
def sdp_ref(
q,
k,
@ -1140,12 +1141,13 @@ class TestTransformers(NNTestCase):
# TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
dtypes = [torch.double, torch.float]
for dtype in dtypes:
N = batch_size
def rand_tensor(*shape):
return torch.randn(shape, device=device, dtype=dtype)
# This test compares python and C++ implementations of SDP.
N, N_prime, L, S, E = 5, 2, 4, 3, 6
N_prime, L, S, E = 2, 4, 3, 6
if input_dim == 3:
query = rand_tensor(N, L, E)
key = rand_tensor(N, S, E)