[cuDNN][SDPA] Handle noncontig nested tensors in cuDNN SDPA (#164958)

Previously we hardcoded the assumption in cuDNN that the inputs would be dense which breaks when e.g., the user is chunking tensors yielding noncontig inputs

A new test was added to check this when `TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED=1` is set, in `test/test_transformers.py`

One issue I noticed: the old gating of nested tensors in `sdp_utils.cpp` appears to be a no-op — all of the inputs are already reported as "dense" by the time that function is called in the nested tensor tests (`test/test_nestedtensor.py -k sdpa`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164958
Approved by: https://github.com/Skylion007, https://github.com/drisspg
This commit is contained in:
Eddie Yan
2025-10-09 21:58:54 +00:00
committed by PyTorch MergeBot
parent 4d7f9f3aed
commit cd62a73dcb
3 changed files with 75 additions and 79 deletions

View File

@ -1418,30 +1418,30 @@ void run_cudnn_SDP_fprop(
}
const fe::graph::Graph& mha_graph = *cache_it->second;
std::unordered_map<int64_t, void*> variant_pack = {
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{SCALE, &scaling_factor},
{O, o.data_ptr()}};
{O, o.mutable_data_ptr()}};
if (return_softmaxstats) {
variant_pack[LSE] = softmaxstats.data_ptr();
variant_pack[LSE] = softmaxstats.mutable_data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[BIAS] = attn_bias.value().data_ptr();
variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
}
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
if (return_softmaxstats) {
variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
}
}
auto workspace_size = mha_graph.get_workspace_size();
@ -1538,29 +1538,30 @@ void run_cudnn_SDP_fprop_nestedtensor(
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
auto rag_stats_off = cum_seqlen_q.mul(h_q);
std::unordered_map<int64_t, void*> variant_pack = {
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{SCALE, &scaling_factor},
{O, o.data_ptr()},
{RAG_Q_OFF, rag_q_off.data_ptr()},
{RAG_O_OFF, rag_q_off.data_ptr()},
{RAG_K_OFF, rag_k_off.data_ptr()},
{RAG_V_OFF, rag_v_off.data_ptr()},
{SEQ_LEN_Q, seqlen_q.data_ptr()},
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
{O, o.mutable_data_ptr()},
{RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
{RAG_O_OFF, rag_o_off.mutable_data_ptr()},
{RAG_K_OFF, rag_k_off.mutable_data_ptr()},
{RAG_V_OFF, rag_v_off.mutable_data_ptr()},
{SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
{SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
if (return_softmaxstats) {
variant_pack[LSE] = softmaxstats.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr();
variant_pack[LSE] = softmaxstats.mutable_data_ptr();
variant_pack[RAG_LSE_OFF] = rag_stats_off.mutable_data_ptr();
}
if (dropout_probability != 0.0f) {
variant_pack[SEED] = dropoutseed.data_ptr();
variant_pack[OFFSET] = dropoutoffset.data_ptr();
variant_pack[SEED] = dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = dropoutoffset.mutable_data_ptr();
}
if (attn_bias.has_value()) {
TORCH_CHECK("bias not supported with nestedtensor");
@ -1697,32 +1698,32 @@ void run_cudnn_SDP_bprop(
std::unordered_map<int64_t, void*> variant_pack = {
// inputs
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{O, o.data_ptr()},
{DO, dO_.data_ptr()},
{LSE, softmaxstats.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{O, o.mutable_data_ptr()},
{DO, dO_.mutable_data_ptr()},
{LSE, softmaxstats.mutable_data_ptr()},
// outputs
{DQ, dQ.data_ptr()},
{DK, dK.data_ptr()},
{DV, dV.data_ptr()},
{DQ, dQ.mutable_data_ptr()},
{DK, dK.mutable_data_ptr()},
{DV, dV.mutable_data_ptr()},
{SCALE, &scaling_factor}};
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
if (attn_bias.has_value()) {
variant_pack[BIAS] = attn_bias.value().data_ptr();
variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
}
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
}
auto workspace_size = mha_graph.get_workspace_size();
@ -1773,9 +1774,10 @@ void run_cudnn_SDP_bprop_nestedtensor(
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
auto rag_stats_off = cum_seqlen_q.mul(h_q);
auto dprops = at::cuda::getCurrentDeviceProperties();
@ -1842,27 +1844,27 @@ void run_cudnn_SDP_bprop_nestedtensor(
std::unordered_map<int64_t, void*> variant_pack = {
// inputs
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{O, o.data_ptr()},
{DO, dO_.data_ptr()},
{LSE, softmaxstats.data_ptr()},
{Q, q.mutable_data_ptr()},
{K, k.mutable_data_ptr()},
{V, v.mutable_data_ptr()},
{O, o.mutable_data_ptr()},
{DO, dO_.mutable_data_ptr()},
{LSE, softmaxstats.mutable_data_ptr()},
// outputs
{DQ, dQ.data_ptr()},
{DK, dK.data_ptr()},
{DV, dV.data_ptr()},
{DQ, dQ.mutable_data_ptr()},
{DK, dK.mutable_data_ptr()},
{DV, dV.mutable_data_ptr()},
{SCALE, &scaling_factor},
{RAG_Q_OFF, rag_q_off.data_ptr()},
{RAG_O_OFF, rag_q_off.data_ptr()},
{RAG_K_OFF, rag_k_off.data_ptr()},
{RAG_V_OFF, rag_v_off.data_ptr()},
{RAG_LSE_OFF, rag_stats_off.data_ptr()},
{SEQ_LEN_Q, seqlen_q.data_ptr()},
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
{RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
{RAG_O_OFF, rag_o_off.mutable_data_ptr()},
{RAG_K_OFF, rag_k_off.mutable_data_ptr()},
{RAG_V_OFF, rag_v_off.mutable_data_ptr()},
{RAG_LSE_OFF, rag_stats_off.mutable_data_ptr()},
{SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
{SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
if (dropout_probability != 0.0f) {
variant_pack[SEED] = _dropoutseed.data_ptr();
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
}
TORCH_CHECK(
!attn_bias.has_value(),

View File

@ -637,13 +637,7 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled.");
}
return false;
} else if (has_for_nested_inputs(params) && (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad())) {
if (debug) {
TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward.");
return false;
}
}
const auto dprop = at::cuda::getCurrentDeviceProperties();
// Check that the input is nested
if (!(dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {

View File

@ -2973,7 +2973,7 @@ class TestSDPACudaOnly(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "Fused SDPA was not built for this system")
@unittest.skipIf("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED" not in os.environ, "cuDNN Nested Tensor support not enabled")
@parametrize("type", ["nested"])
@parametrize("is_contiguous", [True])
@parametrize("is_contiguous", [True, False])
def test_scaled_dot_product_attention_cudnn_nested(self, device, type: str, is_contiguous: bool):
if TEST_WITH_ROCM and type == 'nested':
self.skipTest("ROCM does not support efficient attention on nested tensors, for now")