Revert "Grouped Query Attention (#128898)"

This reverts commit d039b14207fe659d664c590efc06cc0a2abc96c0.

Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
Author: PyTorch MergeBot
Date: 2024-07-30 13:11:23 +00:00
Parent: bdf57da6a6
Commit: 499ead96ff
18 changed files with 169 additions and 370 deletions
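
For context, the reverted PR exposed an enable_gqa keyword on torch.nn.functional.scaled_dot_product_attention so that the query may carry more heads than the key/value tensors. A minimal usage sketch, assembled from the docstring example removed further down (the shapes and the MATH-backend context manager come from that example); it only runs on a build that still contains #128898:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Llama3-style grouped-query shapes: 32 query heads share 8 key/value heads.
query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")

# enable_gqa is the keyword this revert removes; before the revert the math
# and flash kernels repeated K/V heads internally to match the query.
with sdpa_kernel(backends=[SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
print(out.shape)  # torch.Size([32, 32, 128, 64])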

View File

@ -14709,21 +14709,21 @@
CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
autogen: _native_multi_head_attention.out
- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
python_module: nn
variants: function
autogen: scaled_dot_product_attention.out
tags: nondeterministic_seeded
# This aten function is kept so that we can test the choice function from Python
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int
dispatch:
Meta: _fused_sdp_choice_meta
CPU, NestedTensorCPU: _fused_sdp_choice_cpp
CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
tags: nondeterministic_seeded
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
variants: function
tags: nondeterministic_seeded
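
Because enable_gqa was a keyword-only argument with a default of False in all three schemas above, removing it only breaks call sites that pass it explicitly. A hedged sketch (the helper name is mine, not part of PyTorch) that probes whether the installed build still accepts the keyword:

import torch
import torch.nn.functional as F

def sdpa_accepts_enable_gqa() -> bool:
    """Return True if this build still exposes the reverted enable_gqa kwarg."""
    q = k = v = torch.zeros(1, 1, 1, 8)
    try:
        F.scaled_dot_product_attention(q, k, v, enable_gqa=False)
        return True
    except TypeError:
        # Post-revert builds reject the unknown keyword argument.
        return False

print(sdpa_accepts_enable_gqa())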

View File

@ -431,8 +431,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
}
int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
auto backend = sdp::select_sdp_backend_cpp(kernel_params);
if (backend == sdp::SDPBackend::error) {
TORCH_CHECK(
@ -456,13 +456,12 @@ int64_t _fused_sdp_choice_meta(
const std::optional<Tensor>& attn_mask_,
double dropout_p,
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
std::optional<double> scale) {
auto query_key_set = query_.key_set();
#if defined(USE_ROCM)
bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
if (has_rocm) {
auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa);
auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
return choice_int;
}
#else
@ -476,8 +475,7 @@ int64_t _fused_sdp_choice_meta(
attn_mask_,
dropout_p,
is_causal,
scale,
enable_gqa);
scale);
return choice_int;
}
#endif
@ -610,36 +608,6 @@ bool should_compute_logsumexp(const Tensor& query, const Tensor& key, const Tens
return any_inputs_require_grad && gradmode_enabled;
}
std::tuple<at::Tensor, at::Tensor> pre_process_group_query_attention_input(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const bool enable_gqa) {
if (!enable_gqa) {
return std::make_tuple(key, value);
}
const auto q_num_heads = query.sym_size(-3);
const auto k_num_heads = key.sym_size(-3);
const auto v_num_heads = value.sym_size(-3);
bool all_equal = q_num_heads == k_num_heads && k_num_heads == v_num_heads;
bool key_divisible = q_num_heads % k_num_heads == 0;
bool value_divisible = q_num_heads % v_num_heads == 0;
TORCH_CHECK(all_equal || (key_divisible && value_divisible),
"Number of heads in key and value must divide the number of heads in ");
if (all_equal){
return std::make_tuple(key, value);
}
auto repeat_key_shape = query.sym_size(-3) / key.sym_size(-3);
auto repeat_value_shape = query.sym_size(-3) / value.sym_size(-3);
at::Tensor key_repeated = key.repeat_interleave_symint(repeat_key_shape, -3);
at::Tensor value_repeated = value.repeat_interleave_symint(repeat_value_shape, -3);
return std::make_tuple(std::move(key_repeated), std::move(value_repeated));
}
} // namespace
// Computes scaled dot product attention on query, key and value tensors, using
@ -678,13 +646,12 @@ Tensor scaled_dot_product_attention(
const std::optional<Tensor>& attn_mask_,
double dropout_p,
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
std::optional<double> scale) {
validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) {
choice_int = _fused_sdp_choice_stub(query_.device().type(),
query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa);
query_, key, value, attn_mask_, dropout_p, is_causal, scale);
}
sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
@ -746,9 +713,8 @@ Tensor scaled_dot_product_attention(
attn_mask,
dropout_p,
is_causal,
c10::nullopt, /*dropout_mask*/
scale,
enable_gqa));
std::nullopt, /*dropout_mask*/
scale));
default:
TORCH_CHECK(
false,
@ -760,7 +726,7 @@ Tensor scaled_dot_product_attention(
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
const std::optional<Tensor>& dropout_mask, std::optional<double> scale, bool enable_gqa) {
const std::optional<Tensor>& dropout_mask, std::optional<double> scale) {
C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
if (query_.is_nested() || key.is_nested() || value.is_nested()) {
TORCH_CHECK(
@ -787,11 +753,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
}
// MQA/GQA handling
auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key, value, enable_gqa);
auto attn = at::matmul(query, key_expanded.transpose(-2, -1) * scaling_factor);
auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
if (attn_mask.has_value()) {
if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
attn = attn.add(*attn_mask);
@ -807,13 +769,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
auto dropout_scaling = 1.0 / (1 - dropout_p);
return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling), attn);
return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
} else {
attn = at::dropout(attn, dropout_p, true);
}
}
return std::make_tuple(at::matmul(attn, value_expanded), attn);
return std::make_tuple(at::matmul(attn, value), attn);
}
std::tuple<at::Tensor, at::Tensor>
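
The pre_process_group_query_attention_input helper deleted above is what let the C++ math fallback accept fewer key/value heads than query heads: it validated divisibility and then repeated K/V along the head dimension. A rough Python equivalent of the removed behaviour, for reference only (this is a sketch, not a PyTorch API):

import torch

def expand_kv_for_gqa(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                      enable_gqa: bool):
    """Mirrors the removed pre_process_group_query_attention_input helper."""
    if not enable_gqa:
        return key, value
    hq, hk, hv = query.size(-3), key.size(-3), value.size(-3)
    if hq == hk == hv:
        return key, value
    # Same constraint the removed TORCH_CHECK enforced.
    assert hq % hk == 0 and hq % hv == 0, \
        "number of K/V heads must divide the number of query heads"
    key = key.repeat_interleave(hq // hk, dim=-3)
    value = value.repeat_interleave(hq // hv, dim=-3)
    return key, value

q = torch.rand(2, 8, 16, 32)   # 8 query heads
k = torch.rand(2, 2, 16, 32)   # 2 key/value heads
v = torch.rand(2, 2, 16, 32)
k2, v2 = expand_kv_for_gqa(q, k, v, enable_gqa=True)
print(k2.shape, v2.shape)  # both torch.Size([2, 8, 16, 32])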

View File

@ -9,7 +9,7 @@ namespace at {
namespace native {
using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa);
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale);
DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);

View File

@ -868,8 +868,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
}
int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
auto backend = select_sdp_backend(kernel_params);
if (backend == sdp::SDPBackend::error) {
TORCH_CHECK(

View File

@ -598,7 +598,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
}
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense<true /*supports_grouped_query_attention=*/>,
check_batch_size_and_num_heads_dense,
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
for (auto& constraint : dense_constraints) {
@ -655,9 +655,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
}
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense,
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>,
check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention=*/>);
check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>);
for (auto& constraint : dense_constraints) {
if (!constraint(params, debug)) {
return false;
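
The two hunks above shrink the per-kernel constraint tables: flash attention loses the supports_grouped_query_attention=true template argument and mem-efficient attention goes back to the plain batch/head check. The selection logic itself is just a list of predicates evaluated in order until one fails. A minimal sketch of that pattern (names and checks are illustrative, not the real sdp_utils symbols):

import warnings
from dataclasses import dataclass
from typing import Callable, List, Optional
import torch

@dataclass
class Params:                      # stand-in for sdp::sdp_params
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    attn_mask: Optional[torch.Tensor]
    dropout: float
    is_causal: bool

def check_same_num_heads(p: Params, debug: bool) -> bool:
    ok = p.query.size(-3) == p.key.size(-3) == p.value.size(-3)
    if not ok and debug:
        warnings.warn("query, key and value must have the same num_heads")
    return ok

def check_nonzero_seq_len(p: Params, debug: bool) -> bool:
    ok = p.query.size(-2) > 0 and p.key.size(-2) > 0
    if not ok and debug:
        warnings.warn("sequence lengths must be non-zero")
    return ok

def can_use_kernel(p: Params, constraints: List[Callable[[Params, bool], bool]],
                   debug: bool = False) -> bool:
    # Same shape as the C++ loop: stop at the first failing constraint.
    return all(c(p, debug) for c in constraints)

p = Params(torch.rand(2, 4, 8, 16), torch.rand(2, 4, 8, 16),
           torch.rand(2, 4, 8, 16), None, 0.0, False)
print(can_use_kernel(p, [check_same_num_heads, check_nonzero_seq_len], debug=True))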

View File

@ -42,7 +42,7 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
check_nested_tensor,
check_for_dropout,
check_tensor_shapes,
check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention*/>,
check_batch_size_and_num_heads_dense,
check_attn_mask_shape,
check_head_dim_size_cpp,
check_nonzero_sequence_lengths_dense,

View File

@ -48,7 +48,6 @@ struct sdp_params {
std::optional<at::Tensor> attn_mask;
double dropout;
bool is_causal;
bool enable_gqa;
};
SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);
@ -354,46 +353,6 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
return true;
}
inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
const auto q_num_heads = params.query.sym_size(-3);
const auto k_num_heads = params.key.sym_size(-3);
const auto v_num_heads = params.value.sym_size(-3);
const bool same_kv_heads = k_num_heads == v_num_heads;
if (!(same_kv_heads)){
if (debug) {
TORCH_WARN(
"Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
"Key sizes: ",
params.key.sizes(),
", Value sizes: ",
params.value.sizes(),
", Query sizes: ",
params.query.sizes(),
" instead.");
}
return false;
}
// Check if grouped query attention is supported and validate the number of
// heads
if (q_num_heads % k_num_heads != 0) {
if (debug) {
TORCH_WARN(
"FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
"Got input Key sizes(): ",
params.key.sym_size(-3),
", Value sizes(): ",
params.value.sym_size(-3),
", Query sizes(): ",
params.query.sym_size(-3),
" instead.");
}
return false;
}
return true;
}
template <bool supports_gqa>
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
// This is expected to be called after check_tensor_shapes ensuring that the
// size() calls won't error since the inputs are all 4 dimensional
@ -405,36 +364,16 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
bool same_batch_size =
q_batch_size == k_batch_size && q_batch_size == v_batch_size;
auto q_num_heads = params.query.sym_size(-3);
auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(-3);
auto q_num_heads = params.query.sym_size(1);
auto k_num_heads = params.key.sym_size(1);
auto v_num_heads = params.value.sym_size(1);
bool same_num_heads =
q_num_heads == k_num_heads && q_num_heads == v_num_heads;
if (!same_batch_size){
if(debug) {
TORCH_WARN(
"For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
"Query.sizes(): ",
params.query.sizes(),
", Key.sizes(): ",
params.key.sizes(),
", Value.sizes(): ",
params.value.sizes(),
" instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
}
return false;
}
if(params.enable_gqa && supports_gqa){
return check_grouped_query_attention(params, debug);
}
if (!same_num_heads){
if (!(same_batch_size && same_num_heads)) {
if (debug) {
TORCH_WARN(
"For dense input, both fused kernels require query, key and value to have the same num_heads. ",
"For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. ",
"Query.sizes(): ",
params.query.sizes(),
", Key sizes(): ",
@ -445,7 +384,6 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
}
return false;
}
// If all checks pass, return true
return true;
}
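
With check_grouped_query_attention gone, check_batch_size_and_num_heads_dense is back to a single strict rule: dense inputs must agree on batch size and on num_heads. The removed GQA path relaxed only the head check. A small sketch of the two head-count rules side by side (plain Python, not the sdp_utils templates):

def heads_ok_strict(hq: int, hk: int, hv: int) -> bool:
    """Post-revert rule: all three tensors must have the same num_heads."""
    return hq == hk == hv

def heads_ok_gqa(hq: int, hk: int, hv: int) -> bool:
    """Removed rule: K and V must match, and their head count must divide Hq."""
    return hk == hv and hq % hk == 0

# Llama3-style 32 query heads over 8 KV heads: accepted only by the GQA rule.
print(heads_ok_strict(32, 8, 8), heads_ok_gqa(32, 8, 8))  # False True
print(heads_ok_strict(8, 8, 8), heads_ok_gqa(8, 8, 8))    # True True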

View File

@ -128,7 +128,7 @@ void quantize_tensor_per_tensor_affine_privateuse1(
}
int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale, bool enable_gqa){
const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale){
auto backend = sdp::SDPBackend::overrideable;
return static_cast<int64_t>(backend);
}

View File

@ -31,7 +31,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(fullgraph=True, backend=counter)
def fn(q, k, v, m):
return SDPAParams(q, k, v, m, 0.1, True, False)
return SDPAParams(q, k, v, m, 0.1, True)
q = torch.randn(10)
k = torch.randn(10)
@ -39,7 +39,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
m = torch.randn(10)
o = fn(q, k, v, m)
self.assertTrue(isinstance(o, SDPAParams))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True))
self.assertEqual(counter.frame_count, 1)
def test_graph_break_SDPAParams(self):
@ -48,7 +48,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(backend=counter)
def fn(q, k, v, m):
z = SDPAParams(q, k, v, m, 0.1, True, False)
z = SDPAParams(q, k, v, m, 0.1, True)
torch._dynamo.graph_break()
return z, q + 1
@ -58,7 +58,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
m = torch.randn(10)
o, _ = fn(q, k, v, m)
self.assertTrue(isinstance(o, SDPAParams))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True))
self.assertEqual(counter.frame_count, 2)
def test_input_SDPAParams(self):
@ -74,7 +74,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
k = torch.randn(10)
v = torch.randn(10)
m = torch.randn(10)
s = SDPAParams(q, k, v, m, 0.1, True, False)
s = SDPAParams(q, k, v, m, 0.1, True)
o, _ = fn(s, q)
self.assertIs(o, s)
self.assertEqual(counter.frame_count, 1)
@ -86,7 +86,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(fullgraph=True, backend=counter)
def fn(q, k, v, m):
q += 1
z = SDPAParams(q, k, v, m, 0.1, True, False)
z = SDPAParams(q, k, v, m, 0.1, True)
a = z.query
return a + 1, z, q
@ -95,7 +95,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
v = torch.randn(10)
m = torch.randn(10)
_, o, _ = fn(q, k, v, m)
expected = SDPAParams(q, k, v, m, 0.1, True, False)
expected = SDPAParams(q, k, v, m, 0.1, True)
self.assert_ref_equals_params(o, expected)
self.assertEqual(counter.frame_count, 1)
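
These dynamo tests construct SDPAParams directly; the revert drops the trailing enable_gqa field, so the constructor goes from seven arguments back to six. A hedged sketch of building a params object and asking the dispatch helpers why a fused kernel is rejected, assuming SDPAParams and can_use_flash_attention are importable from torch.backends.cuda as elsewhere in this diff:

import torch
from torch.backends.cuda import SDPAParams, can_use_flash_attention

q = torch.rand(8, 8, 64, 64, device="cuda", dtype=torch.float16)
k = torch.rand(8, 4, 64, 64, device="cuda", dtype=torch.float16)  # fewer KV heads
v = torch.rand(8, 4, 64, 64, device="cuda", dtype=torch.float16)

# Post-revert constructor: query, key, value, attn_mask, dropout, is_causal.
# (Builds that still carry #128898 take a seventh enable_gqa argument.)
params = SDPAParams(q, k, v, None, 0.0, False)

# With debug=True the head-count mismatch is reported as a warning, matching the
# "same batch_size and num_heads" message restored above.
print(can_use_flash_attention(params, True))  # False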

View File

@ -1561,36 +1561,6 @@ class TestSDPAFailureModes(NNTestCase):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
@onlyCUDA
@skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel):
rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
with sdpa_kernel(fused_kernel):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, "
"key and value to have"):
F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0,
is_causal=False, enable_gqa=True)
@onlyCPU
@skipIfRocm # Nested Tensor
def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device):
rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, "
"key and value to have"):
F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0,
is_causal=False, enable_gqa=True)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention")
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
@ -1742,8 +1712,7 @@ class TestSDPAFailureModes(NNTestCase):
seq_len_list = [2, 4, 5, 6, 7]
shape = SdpaShape(5, 8, seq_len_list, 57)
make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
q, k, v = make_tensor().transpose(1, 2), make_tensor().transpose(1, 2), make_tensor().transpose(1, 2)
q, k, v = make_tensor(), make_tensor(), make_tensor()
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
@ -1823,7 +1792,7 @@ class TestSDPAFailureModes(NNTestCase):
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
torch.nn.functional.scaled_dot_product_attention(
out = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
@onlyCUDA
@ -2980,32 +2949,23 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("dropout_p", [0.0, 0.22, 0.48])
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("scale", [None, "l1"])
@parametrize("enable_gqa", [True, False])
@parametrize("n_heads", [[16, 8], [10, 2]])
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str, enable_gqa: bool, n_heads: List[int]):
scale: str):
if isSM8XDevice and head_dim in range(193, 256 + 1):
self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
if is_causal and seq_len_q != seq_len_k:
self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1:
torch.cuda.empty_cache() # Prevent memory fragmentation
if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
unittest.skip("Reference implementation OOM")
return
scale = scale if scale is None else (1 / head_dim)
num_heads_q = num_heads_kv = 4
if enable_gqa:
num_heads_q = n_heads[0]
num_heads_kv = n_heads[1]
query = torch.rand(batch_size, num_heads_q, seq_len_q, head_dim,
n_heads = 4
query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
device=device, dtype=dtype, requires_grad=True)
key = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim, device=device,
key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
dtype=dtype, requires_grad=True)
value = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim,
value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
device=device, dtype=dtype, requires_grad=True)
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
@ -3015,15 +2975,14 @@ class TestSDPACudaOnly(NNTestCase):
if not is_dropout:
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
out = F.scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
with sdpa_kernel(backends=[SDPBackend.MATH]):
# High Precision Math Reference
out_ref = F.scaled_dot_product_attention(
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
# Low Precision Math Reference
out_lp_ref = F.scaled_dot_product_attention(
query, key, value, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
query, key, value, is_causal=is_causal, scale=scale)
else:
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the
# Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
@ -3050,12 +3009,11 @@ class TestSDPACudaOnly(NNTestCase):
dropout_mask = softmax_mask >= 0
# High Precision Math Reference
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
scale=scale, dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
# Low Precision Math Reference
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
dropout_mask=dropout_mask)[0]
upstream_grad = torch.rand_like(out, requires_grad=False)
@ -3075,7 +3033,7 @@ class TestSDPACudaOnly(NNTestCase):
'out': 1.5,
'grad_query': 13.0,
'grad_key': 2.0,
'grad_value': 1.75,
'grad_value': 1.5,
}
)
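
The test above follows the usual fused-vs-reference pattern: run the fused kernel in low precision, build a float64 math reference and a low-precision math reference, then require the fused error to stay within a small multiple of the low-precision reference error. A condensed sketch of that comparison (the multiplier echoes the fudge factors in the table above; the helper name and tolerance are mine):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def fused_close_to_math_ref(q, k, v, fudge: float = 1.5, is_causal: bool = False) -> bool:
    q64, k64, v64 = (t.to(torch.float64) for t in (q, k, v))
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        out_ref = F.scaled_dot_product_attention(q64, k64, v64, is_causal=is_causal)
        out_lp_ref = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    # Fused error may be at most `fudge` times the low-precision math error.
    fused_err = (out.to(torch.float64) - out_ref).abs().max()
    lp_err = (out_lp_ref.to(torch.float64) - out_ref).abs().max()
    return bool(fused_err <= fudge * lp_err + 1e-5)

q = torch.rand(4, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.rand(4, 4, 128, 64, device="cuda", dtype=torch.float16)
v = torch.rand(4, 4, 128, 64, device="cuda", dtype=torch.float16)
print(fused_close_to_math_ref(q, k, v))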
@ -3227,7 +3185,6 @@ class TestSDPACudaOnly(NNTestCase):
}
)
@skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if

View File

@ -466,7 +466,6 @@ def gen_nn_functional(fm: FileManager) -> None:
"dropout_p: float = 0.0",
"is_causal: bool = False",
"scale: Optional[float] = None",
"enable_gqa: bool = False",
]
)
)

View File

@ -1956,7 +1956,6 @@ class _SDPAParams:
attn_mask: Optional[Tensor]
dropout: _float
is_causal: _bool
enable_gqa: _bool
def __init__(
self,
query: Tensor,
@ -1964,8 +1963,7 @@ class _SDPAParams:
value: Tensor,
attn_mask: Optional[Tensor],
dropout: _float,
is_causal: _bool,
enable_gqa: _bool) -> None: ...
is_causal: _bool) -> None: ...
class _SDPBackend(Enum):
ERROR = -1

View File

@ -33,9 +33,6 @@ class SDPAParamsVariable(VariableTracker):
is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
value.is_causal
)
enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
value.enable_gqa
)
param_vars = [
query_var,
key_var,
@ -43,7 +40,6 @@ class SDPAParamsVariable(VariableTracker):
attn_mask_var,
dropout_var,
is_causal_var,
enable_gqa_var,
]
return TorchInGraphFunctionVariable(SDPAParams).call_function(
tx, param_vars, {}

View File

@ -1953,24 +1953,16 @@ Call this whenever a new thread is created in order to propagate values from
at::Tensor const& value,
std::optional<at::Tensor> attn_mask,
double dropout,
bool is_causal,
bool enable_gqa) {
bool is_causal) {
return sdp::sdp_params{
query,
key,
value,
std::move(attn_mask),
dropout,
is_causal,
enable_gqa};
query, key, value, std::move(attn_mask), dropout, is_causal};
}))
.def_readonly("query", &sdp::sdp_params::query)
.def_readonly("key", &sdp::sdp_params::key)
.def_readonly("value", &sdp::sdp_params::value)
.def_readonly("attn_mask", &sdp::sdp_params::attn_mask)
.def_readonly("dropout", &sdp::sdp_params::dropout)
.def_readonly("is_causal", &sdp::sdp_params::is_causal)
.def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa);
.def_readonly("is_causal", &sdp::sdp_params::is_causal);
py::enum_<sdp::SDPBackend>(
py_module,

View File

@ -261,7 +261,7 @@ def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
return True
def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
if (
not flash_sdp_enabled()
and not mem_efficient_sdp_enabled()
@ -275,7 +275,7 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable
SDPBackend.MATH,
)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)
for backend in ordering:
if backend == SDPBackend.FLASH_ATTENTION:
@ -622,7 +622,6 @@ def jagged_scaled_dot_product_attention(
dropout_p=0.0,
is_causal=False,
scale=None,
enable_gqa=False,
):
_validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
# for mypy, ugh
@ -653,7 +652,7 @@ def jagged_scaled_dot_product_attention(
compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
backend_choice = _select_sdp_backend(
query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
query, key, value, attn_mask, dropout_p, is_causal
)
if backend_choice == SDPBackend.FLASH_ATTENTION:
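
_select_sdp_backend above simply builds an SDPAParams and walks a priority ordering, returning the first backend whose checks pass (falling back to MATH, or an error value when nothing applies). A generic sketch of that selection loop under assumed names (the real ordering and check functions live in the nested-tensor SDPA module shown here):

from enum import Enum
from typing import Callable, Dict, List

class Backend(Enum):               # stand-in for torch.nn.attention.SDPBackend
    FLASH_ATTENTION = 0
    EFFICIENT_ATTENTION = 1
    MATH = 2
    ERROR = -1

def select_backend(params,
                   ordering: List[Backend],
                   checks: Dict[Backend, Callable[[object], bool]]) -> Backend:
    """Return the first backend in priority order whose constraint check passes."""
    for backend in ordering:
        if checks[backend](params):
            return backend
    return Backend.ERROR

# Priority mirrors the dispatch here: try the fused kernels first, fall back to math.
ordering = [Backend.FLASH_ATTENTION, Backend.EFFICIENT_ATTENTION, Backend.MATH]
checks = {
    Backend.FLASH_ATTENTION: lambda p: False,      # pretend flash is unsupported
    Backend.EFFICIENT_ATTENTION: lambda p: False,  # pretend mem-efficient is too
    Backend.MATH: lambda p: True,                  # math fallback always works
}
print(select_backend(None, ordering, checks))  # Backend.MATH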

View File

@ -173,7 +173,6 @@ class CausalBias(torch.Tensor):
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
r"""
Handles the logic for computing attention with the specified causal bias.
@ -190,7 +189,6 @@ class CausalBias(torch.Tensor):
are set.
scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
@ -214,13 +212,10 @@ class CausalBias(torch.Tensor):
dropout_p=dropout_p,
is_causal=True,
scale=scale,
enable_gqa=enable_gqa,
)
elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
_validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
sdpa_params = SDPAParams(
query, key, value, None, dropout_p, is_causal, enable_gqa
)
sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
if can_use_flash_attention(sdpa_params):
needs_padding = query.size(-1) % 8 != 0
og_head_size = query.size(-1)
@ -269,7 +264,6 @@ class CausalBias(torch.Tensor):
dropout_p=dropout_p,
is_causal=False,
scale=scale,
enable_gqa=enable_gqa,
)
else:
raise ValueError(

View File

@ -5606,21 +5606,20 @@ def _in_projection(
scaled_dot_product_attention = _add_docstr(
torch._C._nn.scaled_dot_product_attention,
r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> Tensor:
r"""
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor:
Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
specified as a keyword argument.
Computes scaled dot product attention on query, key and value tensors, using
an optional attention mask if passed, and applying dropout if a probability
greater than 0.0 is specified. The optional scale argument can only be specified as a keyword argument.
.. code-block:: python
.. code-block:: python
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype)
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
@ -5631,22 +5630,17 @@ scaled_dot_product_attention = _add_docstr(
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
attn_bias = attn_mask + attn_bias
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
.. warning::
This function is beta and subject to change.
.. warning:: This function is beta and subject to change.
.. warning::
.. warning::
This function always applies dropout according to the specified ``dropout_p`` argument.
To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
that makes the function call is not in training mode.
@ -5661,10 +5655,9 @@ scaled_dot_product_attention = _add_docstr(
self.p = p
def forward(self, ...):
return F.scaled_dot_product_attention(...,
dropout_p=(self.p if self.training else 0.0))
return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0))
Note:
Note:
There are currently three supported implementations of scaled dot product attention:
@ -5695,24 +5688,16 @@ scaled_dot_product_attention = _add_docstr(
The c++ implementation supports torch.float64 and can be used when higher precision is required.
For more information please see :doc:`/notes/numerical_accuracy`
Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention
and math kernel on CUDA tensor, and does not support Nested tensor.
Constraints for GQA:
- number_of_heads_query % number_of_heads_key_value == 0 and,
- number_of_heads_key == number_of_heads_value
Note:
Note:
{cudnn_reproducibility_note}
""".format(
""".format(
**reproducibility_notes
)
+ r"""
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
which is :math:`(N,..., L, S)`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element *should* take part in attention.
@ -5724,21 +5709,19 @@ scaled_dot_product_attention = _add_docstr(
An error is thrown if both attn_mask and is_causal are set.
scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
Shape legend:
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}`
- :math:`L: \text{Target sequence length}`
- :math:`E: \text{Embedding dimension of the query and key}`
- :math:`Ev: \text{Embedding dimension of the value}`
- :math:`Hq: \text{Number of heads of query}`
- :math:`H: \text{Number of heads of key and value}`
Examples:
Examples:
>>> # Optionally use the context manager to ensure one of the fused kernels is run
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
@ -5748,21 +5731,12 @@ scaled_dot_product_attention = _add_docstr(
>>> F.scaled_dot_product_attention(query,key,value)
>>> # Sample for GQA for llama3
>>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with sdpa_kernel(backends=[SDPBackend.MATH]):
>>> F.scaled_dot_product_attention(query,key,value,enable_gqa=True)
.. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
.. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
https://arxiv.org/abs/2307.08691
.. _Memory-Efficient Attention:
.. _Memory-Efficient Attention:
https://github.com/facebookresearch/xformers
.. _Grouped-Query Attention:
https://arxiv.org/pdf/2305.13245
""",
""",
)
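
The docstring hunk above carries the full reference semantics of the operator, including the enable_gqa branch that this revert deletes. Reassembled here as a runnable function (a reference sketch built from the docstring lines, lightly completed with imports and explicit devices; it is not the fused kernels):

import math
import torch

def sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
                   is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias
    if enable_gqa:
        # The branch removed by this revert: repeat K/V heads to match the query.
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

q = torch.rand(2, 8, 16, 32)
k = torch.rand(2, 2, 16, 32)
v = torch.rand(2, 2, 16, 32)
print(sdpa_reference(q, k, v, enable_gqa=True).shape)  # torch.Size([2, 8, 16, 32])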

View File

@ -8688,7 +8688,6 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
num_heads_q_gqa, num_heads_kv_gqa = 32, 8
dim_3_q_shape = (batch, seq_q, head_dim)
dim_3_kv_shape = (batch, seq_kv, head_dim)
@ -8699,8 +8698,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
samples = []
for qkv_shape, is_causal, dropout_p, enable_gqa in product(
qkv_shapes, [True, False], [0.0, 0.5], [True, False]):
for qkv_shape, is_causal, dropout_p in product(
qkv_shapes, [True, False], [0.0, 0.5]):
shape_q, shape_kv = qkv_shape
samples.append(SampleInput(
make(shape_q),
@ -8730,15 +8729,6 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
dropout_p=0.0)
)
samples.append(
SampleInput(
make((batch, num_heads_q_gqa, seq_q, head_dim)),
make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
enable_gqa=True
)
)
yield from samples
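
The removed sample above exercised the GQA path with 32 query heads over 8 key/value heads, on top of the existing (batch, seq_q, seq_kv, num_heads, head_dim) = (4, 3, 6, 4, 8) cases. A small generator in the same spirit, kept independent of the OpInfo machinery (the helper name and the is_causal/dropout product are illustrative; only the GQA shapes come from the removed sample):

from itertools import product
import torch

def gqa_sample_inputs(device="cpu", dtype=torch.float32, requires_grad=False):
    """Yield (query, key, value, kwargs) tuples covering the removed GQA sample."""
    def make(shape):
        return torch.rand(shape, device=device, dtype=dtype,
                          requires_grad=requires_grad)
    batch, seq_q, seq_kv, head_dim = 4, 3, 6, 8
    num_heads_q_gqa, num_heads_kv_gqa = 32, 8
    for is_causal, dropout_p in product([True, False], [0.0, 0.5]):
        yield (make((batch, num_heads_q_gqa, seq_q, head_dim)),
               make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
               make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
               dict(is_causal=is_causal, dropout_p=dropout_p, enable_gqa=True))

for q, k, v, kwargs in gqa_sample_inputs():
    print(q.shape, k.shape, kwargs)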