mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Co-authored-by: Jeff Rasley <jerasley@microsoft.com> Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com> Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com> Co-authored-by: Masahiro Tanaka <mtanaka@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
75 lines
1.9 KiB
C++
75 lines
1.9 KiB
C++
// Copyright (c) Microsoft Corporation.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
// DeepSpeed Team
|
|
|
|
/******************************************************************************
|
|
Copyright (c) 2023, Tri Dao.
|
|
******************************************************************************/
|
|
|
|
#pragma once
|
|
|
|
#include <cuda.h>
|
|
#include <vector>
|
|
|
|
#include "attention_atom.h"
|
|
|
|
// Positional indices used to address tensor dimensions by name.
// NOTE(review): the names suggest a (total, heads, head-dim) axis ordering for
// the Q/K/V tensors; the consumers are not visible in this header, so confirm
// against the call sites before relying on that interpretation.
constexpr int TOTAL_DIM = 0;  // index 0
constexpr int H_DIM = 1;      // index 1
constexpr int D_DIM = 2;      // index 2
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Parameter bundle describing the Q, K and V inputs of an attention call.
//
// This is a plain aggregate: it owns nothing and only carries raw pointers,
// strides and head counts. Member order is part of the layout and must not be
// changed.
struct Qkv_params {
    // Type used for every stride in this struct (32-bit, unsigned).
    // NOTE(review): any offset arithmetic carried out in index_t wraps at
    // 2^32; confirm the tensors addressed through these strides stay below
    // that limit.
    using index_t = uint32_t;

    // The QKV matrices. The pointers are untyped; the element type is not
    // recorded in this struct.
    void* __restrict__ q_ptr;
    void* __restrict__ k_ptr;
    void* __restrict__ v_ptr;

    // The stride between rows of the Q, K and V matrices.
    // NOTE(review): the unit (elements vs. bytes) is not established in this
    // header; confirm against the code that fills these fields.
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;

    // Per-head strides of the Q, K and V matrices (presumably the distance
    // between consecutive heads within a row; confirm against the caller).
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;

    // The number of heads: `h` for the query, `h_k` for key/value.
    // In the case of multi-query and grouped-query attention (MQA/GQA),
    // nheads_k (`h_k`) could be different from nheads (`h`, query).
    int h;
    int h_k;

    // Precomputed h / h_k.
    int h_h_k_ratio;
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
struct Flash_fwd_params : public Qkv_params {
|
|
// The O matrix (output).
|
|
void* __restrict__ o_ptr;
|
|
|
|
// The attention metadata
|
|
AttentionAtom* __restrict__ atoms;
|
|
|
|
// Total attention atoms
|
|
int num_atoms;
|
|
|
|
// The stride between rows of O.
|
|
index_t o_row_stride;
|
|
index_t o_head_stride;
|
|
|
|
// The dimensions
|
|
int d, d_rounded;
|
|
|
|
// The scaling factors for the kernel.
|
|
float scale_softmax;
|
|
float scale_softmax_log2;
|
|
|
|
bool is_bf16;
|
|
bool is_causal;
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Entry point for the attention forward pass described by `params`.
//
// Only the declaration lives in this header; the definition is elsewhere.
//
// params: forward-pass parameter bundle, taken by non-const reference
//         (so the callee is permitted to modify it).
// stream: CUDA stream handle passed through to the implementation.
//         NOTE(review): presumably the stream the kernel is launched on, with
//         no host-side synchronization; confirm against the definition.
void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream);
|