Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[cuDNN][SDPA] Handle noncontig nested tensors in cuDNN SDPA (#164958)
Previously we hardcoded the assumption in cuDNN that the inputs would be dense, which breaks when, e.g., the user is chunking tensors and thereby producing noncontiguous inputs. A new test in `test/test_transformers.py` checks this when `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1` is set. One issue I noticed is that the old gating of nested tensors in `sdp_utils.cpp` seems to be a no-op: all of the inputs are reported as "dense" by the time that function is called in the nested tensor tests in `test/test_nestedtensor.py -k sdpa`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164958
Approved by: https://github.com/Skylion007, https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: 4d7f9f3aed
Commit: cd62a73dcb
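To make the commit message concrete, here is a minimal Python sketch (illustrative only, not part of the change; the shapes and sizes are made up) of how chunking a packed projection produces noncontiguous q/k/v views, which is exactly the layout a hardcoded dense assumption mishandles:

import torch
import torch.nn.functional as F

# Packed QKV projection output: (batch, seq, heads, 3 * head_dim)
qkv = torch.randn(2, 8, 4, 3 * 64)
q, k, v = qkv.chunk(3, dim=-1)   # views into qkv's storage, no copy

print(q.is_contiguous())         # False: rows are 3 * 64 elements apart, not 64
print(q.stride())                # (6144, 768, 192, 1) instead of (2048, 256, 64, 1)

# SDPA accepts these views; a backend that assumes densely packed inputs
# would compute the wrong per-token offsets for them.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))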
@@ -1418,30 +1418,30 @@ void run_cudnn_SDP_fprop(
   }
   const fe::graph::Graph& mha_graph = *cache_it->second;
   std::unordered_map<int64_t, void*> variant_pack = {
-      {Q, q.data_ptr()},
-      {K, k.data_ptr()},
-      {V, v.data_ptr()},
+      {Q, q.mutable_data_ptr()},
+      {K, k.mutable_data_ptr()},
+      {V, v.mutable_data_ptr()},
       {SCALE, &scaling_factor},
-      {O, o.data_ptr()}};
+      {O, o.mutable_data_ptr()}};
   if (return_softmaxstats) {
-    variant_pack[LSE] = softmaxstats.data_ptr();
+    variant_pack[LSE] = softmaxstats.mutable_data_ptr();
   }
   if (attn_bias.has_value()) {
-    variant_pack[BIAS] = attn_bias.value().data_ptr();
+    variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
   }
   if (dropout_probability != 0.0f) {
-    variant_pack[SEED] = _dropoutseed.data_ptr();
-    variant_pack[OFFSET] = _dropoutoffset.data_ptr();
+    variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
+    variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
   }
   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
-    variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
-    variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
-    variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
-    variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
-    variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
-    variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
+    variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
+    variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
+    variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
+    variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
+    variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
+    variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
     if (return_softmaxstats) {
-      variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
+      variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
     }
   }
   auto workspace_size = mha_graph.get_workspace_size();
@@ -1538,29 +1538,30 @@ void run_cudnn_SDP_fprop_nestedtensor(

   auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
   auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
-  auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
-  auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
-  auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
+  auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
+  auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
+  auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
+  auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
   auto rag_stats_off = cum_seqlen_q.mul(h_q);
   std::unordered_map<int64_t, void*> variant_pack = {
-      {Q, q.data_ptr()},
-      {K, k.data_ptr()},
-      {V, v.data_ptr()},
+      {Q, q.mutable_data_ptr()},
+      {K, k.mutable_data_ptr()},
+      {V, v.mutable_data_ptr()},
       {SCALE, &scaling_factor},
-      {O, o.data_ptr()},
-      {RAG_Q_OFF, rag_q_off.data_ptr()},
-      {RAG_O_OFF, rag_q_off.data_ptr()},
-      {RAG_K_OFF, rag_k_off.data_ptr()},
-      {RAG_V_OFF, rag_v_off.data_ptr()},
-      {SEQ_LEN_Q, seqlen_q.data_ptr()},
-      {SEQ_LEN_KV, seqlen_kv.data_ptr()}};
+      {O, o.mutable_data_ptr()},
+      {RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
+      {RAG_O_OFF, rag_o_off.mutable_data_ptr()},
+      {RAG_K_OFF, rag_k_off.mutable_data_ptr()},
+      {RAG_V_OFF, rag_v_off.mutable_data_ptr()},
+      {SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
+      {SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
   if (return_softmaxstats) {
-    variant_pack[LSE] = softmaxstats.data_ptr();
-    variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr();
+    variant_pack[LSE] = softmaxstats.mutable_data_ptr();
+    variant_pack[RAG_LSE_OFF] = rag_stats_off.mutable_data_ptr();
   }
   if (dropout_probability != 0.0f) {
-    variant_pack[SEED] = dropoutseed.data_ptr();
-    variant_pack[OFFSET] = dropoutoffset.data_ptr();
+    variant_pack[SEED] = dropoutseed.mutable_data_ptr();
+    variant_pack[OFFSET] = dropoutoffset.mutable_data_ptr();
   }
   if (attn_bias.has_value()) {
     TORCH_CHECK("bias not supported with nestedtensor");
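The offset change above is the core of the noncontiguity fix: the old code derived the per-sequence ragged offsets from num_heads * head_dim, which only equals the token-to-token stride when the tensor is densely packed, while the new code reads the actual stride along the token dimension. A small Python sketch (hypothetical shapes, not code from this commit) of where the two formulas diverge:

import torch

num_heads, head_dim = 4, 64

# Densely packed (tokens, heads, head_dim): stride(-3) == num_heads * head_dim,
# so the old and new offset formulas agree.
dense = torch.randn(10, num_heads, head_dim)
assert dense.stride(-3) == num_heads * head_dim

# A chunked view of a packed QKV buffer keeps the parent's row pitch, so
# num_heads * head_dim under-counts the per-token offset.
packed = torch.randn(10, num_heads, 3 * head_dim)
q = packed.chunk(3, dim=-1)[0]                    # (10, 4, 64) noncontiguous view
print(q.stride(-3), num_heads * head_dim)         # 768 vs 256

cum_seqlen_q = torch.tensor([0, 3, 10])
rag_q_off_old = cum_seqlen_q * (num_heads * head_dim)   # wrong for the view above
rag_q_off_new = cum_seqlen_q * q.stride(-3)             # follows the real layout
print(rag_q_off_old.tolist(), rag_q_off_new.tolist())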
@@ -1697,32 +1698,32 @@ void run_cudnn_SDP_bprop(

   std::unordered_map<int64_t, void*> variant_pack = {
       // inputs
-      {Q, q.data_ptr()},
-      {K, k.data_ptr()},
-      {V, v.data_ptr()},
-      {O, o.data_ptr()},
-      {DO, dO_.data_ptr()},
-      {LSE, softmaxstats.data_ptr()},
+      {Q, q.mutable_data_ptr()},
+      {K, k.mutable_data_ptr()},
+      {V, v.mutable_data_ptr()},
+      {O, o.mutable_data_ptr()},
+      {DO, dO_.mutable_data_ptr()},
+      {LSE, softmaxstats.mutable_data_ptr()},
       // outputs
-      {DQ, dQ.data_ptr()},
-      {DK, dK.data_ptr()},
-      {DV, dV.data_ptr()},
+      {DQ, dQ.mutable_data_ptr()},
+      {DK, dK.mutable_data_ptr()},
+      {DV, dV.mutable_data_ptr()},
       {SCALE, &scaling_factor}};
   if (dropout_probability != 0.0f) {
-    variant_pack[SEED] = _dropoutseed.data_ptr();
-    variant_pack[OFFSET] = _dropoutoffset.data_ptr();
+    variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
+    variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
   }
   if (attn_bias.has_value()) {
-    variant_pack[BIAS] = attn_bias.value().data_ptr();
+    variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
   }
   if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
-    variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
-    variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
-    variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
-    variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
-    variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
-    variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
-    variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
+    variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
+    variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
+    variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
+    variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
+    variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
+    variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
+    variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
   }

   auto workspace_size = mha_graph.get_workspace_size();
@@ -1773,9 +1774,10 @@ void run_cudnn_SDP_bprop_nestedtensor(

   auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
   auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
-  auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
-  auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
-  auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
+  auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
+  auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
+  auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
+  auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
   auto rag_stats_off = cum_seqlen_q.mul(h_q);

   auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1842,27 +1844,27 @@ void run_cudnn_SDP_bprop_nestedtensor(

   std::unordered_map<int64_t, void*> variant_pack = {
       // inputs
-      {Q, q.data_ptr()},
-      {K, k.data_ptr()},
-      {V, v.data_ptr()},
-      {O, o.data_ptr()},
-      {DO, dO_.data_ptr()},
-      {LSE, softmaxstats.data_ptr()},
+      {Q, q.mutable_data_ptr()},
+      {K, k.mutable_data_ptr()},
+      {V, v.mutable_data_ptr()},
+      {O, o.mutable_data_ptr()},
+      {DO, dO_.mutable_data_ptr()},
+      {LSE, softmaxstats.mutable_data_ptr()},
       // outputs
-      {DQ, dQ.data_ptr()},
-      {DK, dK.data_ptr()},
-      {DV, dV.data_ptr()},
+      {DQ, dQ.mutable_data_ptr()},
+      {DK, dK.mutable_data_ptr()},
+      {DV, dV.mutable_data_ptr()},
       {SCALE, &scaling_factor},
-      {RAG_Q_OFF, rag_q_off.data_ptr()},
-      {RAG_O_OFF, rag_q_off.data_ptr()},
-      {RAG_K_OFF, rag_k_off.data_ptr()},
-      {RAG_V_OFF, rag_v_off.data_ptr()},
-      {RAG_LSE_OFF, rag_stats_off.data_ptr()},
-      {SEQ_LEN_Q, seqlen_q.data_ptr()},
-      {SEQ_LEN_KV, seqlen_kv.data_ptr()}};
+      {RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
+      {RAG_O_OFF, rag_o_off.mutable_data_ptr()},
+      {RAG_K_OFF, rag_k_off.mutable_data_ptr()},
+      {RAG_V_OFF, rag_v_off.mutable_data_ptr()},
+      {RAG_LSE_OFF, rag_stats_off.mutable_data_ptr()},
+      {SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
+      {SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
   if (dropout_probability != 0.0f) {
-    variant_pack[SEED] = _dropoutseed.data_ptr();
-    variant_pack[OFFSET] = _dropoutoffset.data_ptr();
+    variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
+    variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
   }
   TORCH_CHECK(
       !attn_bias.has_value(),
@@ -637,13 +637,7 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
       TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled.");
     }
     return false;
-  } else if (has_for_nested_inputs(params) && (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad())) {
-    if (debug) {
-      TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward.");
-      return false;
-    }
-  }
-
+  }
   const auto dprop = at::cuda::getCurrentDeviceProperties();
   // Check that the input is nested
   if (!(dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {
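The gate above also shows the runtime requirements around the experimental path: the environment flag and an SM 9.x/10.x device (the dprop->major check). A hedged Python sketch of how one might probe those conditions before opting in; only the env flag name and the capability check come from the source, the rest is illustrative:

import os
import torch

# The flag has to be set before the dispatch check reads it.
os.environ.setdefault("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED", "1")

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # Mirrors the dprop->major == 9 || dprop->major == 10 condition above.
    eligible = major in (9, 10)
    print(f"SM {major}.{minor}: cuDNN nested-tensor SDPA eligible = {eligible}")
else:
    print("No CUDA device; the cuDNN SDPA path does not apply.")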
@@ -2973,7 +2973,7 @@ class TestSDPACudaOnly(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "Fused SDPA was not built for this system")
     @unittest.skipIf("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED" not in os.environ, "cuDNN Nested Tensor support not enabled")
     @parametrize("type", ["nested"])
-    @parametrize("is_contiguous", [True])
+    @parametrize("is_contiguous", [True, False])
     def test_scaled_dot_product_attention_cudnn_nested(self, device, type: str, is_contiguous: bool):
         if TEST_WITH_ROCM and type == 'nested':
             self.skipTest("ROCM does not support efficient attention on nested tensors, for now")
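For the new is_contiguous=False case, the test has to build noncontiguous nested inputs; one way to do that (a sketch under assumptions, not the actual test body) is to chunk a jagged nested QKV projection, which yields nested views that report is_contiguous() == False:

import torch

# Ragged batch: three sequences of different lengths, 4 heads, packed QKV last dim.
seq_lens = [3, 7, 5]
nt_qkv = torch.nested.nested_tensor(
    [torch.randn(s, 4, 3 * 64) for s in seq_lens], layout=torch.jagged)

# Chunking along the last (regular) dim gives nested views that share storage
# with nt_qkv, so they are not densely packed.
q, k, v = nt_qkv.chunk(3, dim=-1)
print(q.shape)            # (3, j, 4, 64) with a ragged second dim
print(q.is_contiguous())  # False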