Compare commits

...

1 Commit

Author SHA1 Message Date
54792a0100 debug 2025-04-10 17:41:31 -07:00


@@ -9,6 +9,7 @@
#include <tuple>
#ifdef USE_FLASH_ATTENTION
#include <ATen/core/Tensor.h>
@@ -39,7 +40,7 @@
#include <namespace_config.h>
#include <static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
#include <torch/serialize.h>
#include <c10/util/Exception.h>
@@ -537,6 +538,18 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
torch::save(q, "/home/danvm/pytorch/fwd_Q.pt");
torch::save(k, "/home/danvm/pytorch/fwd_K.pt");
torch::save(v, "/home/danvm/pytorch/fwd_V.pt");
// std::cout << "Q fwd: " << q << std::endl;
// std::cout << "K fwd: " << k << std::endl;
// std::cout << "V fwd: " << v << std::endl;
TORCH_CHECK(!q.isnan().any().item<bool>(), "fwd Q has NaN");
TORCH_CHECK(!k.isnan().any().item<bool>(), "fwd K has NaN");
TORCH_CHECK(!v.isnan().any().item<bool>(), "fwd V has NaN");
if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -934,6 +947,18 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
dv = at::empty_like(v);
}
torch::save(q, "/home/danvm/pytorch/bwd_Q.pt");
torch::save(k, "/home/danvm/pytorch/bwd_K.pt");
torch::save(v, "/home/danvm/pytorch/bwd_V.pt");
// std::cout << "Q bwd: " << q << std::endl;
// std::cout << "K bwd: " << k << std::endl;
// std::cout << "V bwd: " << v << std::endl;
TORCH_CHECK(!q.isnan().any().item<bool>(), "q contains NaN at start of backward");
TORCH_CHECK(!k.isnan().any().item<bool>(), "k contains NaN at start of backward");
TORCH_CHECK(!v.isnan().any().item<bool>(), "v contains NaN at start of backward");
// bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
bool loop = true;
@@ -1015,6 +1040,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
}
// For MQA/GQA we need to sum dK and dV across the groups
TORCH_CHECK(!dq.isnan().any().item<bool>(), "dq has NaN after backward");
TORCH_CHECK(!dk.isnan().any().item<bool>(), "dk has NaN after backward");
TORCH_CHECK(!dv.isnan().any().item<bool>(), "dv has NaN after backward");
if (num_heads_k != num_heads) {
at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
@@ -1542,3 +1571,4 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
} // namespace pytorch_fmha
#endif
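
The commit dumps q, k, and v to disk with torch::save before the forward and backward flash-attention kernels run, so the tensors can be re-inspected outside the CUDA extension. Below is a minimal sketch (not part of the commit) of a standalone LibTorch program that reloads one of those dumps and repeats the NaN check offline; the file path is the one hard-coded in the patch, and it is assumed the file was written by torch::save on a single tensor.

// Sketch: reload a dumped tensor with LibTorch and re-run the NaN check.
#include <torch/torch.h>
#include <torch/serialize.h>
#include <iostream>

int main() {
  torch::Tensor q;
  // torch::load is the counterpart of the torch::save calls added in the patch.
  torch::load(q, "/home/danvm/pytorch/fwd_Q.pt");
  std::cout << "shape: " << q.sizes() << "\n"
            << "has NaN: " << std::boolalpha
            << q.isnan().any().item<bool>() << std::endl;
  return 0;
}

Because the dump is a plain LibTorch serialization archive, the same check can be repeated for fwd_K.pt, fwd_V.pt, and the bwd_*.pt files by swapping the path.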