mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-21 15:43:52 +08:00

Compare commits: v0.4.3...fix-hashin (4 commits)

SHA1:
1936d7bab0
996cf2de5c
260d119e86
a360ff80bb
@ -33,20 +33,27 @@
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass It's been modified to support either
// row/column or scalar broadcasting, as is already supported in CUTLASS 3.x.
// Important because this saves us a factor 4x on the number of kernels
// compiled.
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"

// clang-format on

namespace cutlass::epilogue::threadblock {

using namespace cute;
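The header comment above describes the key interface change: the scale is always read through a device pointer, and a boolean decides whether it refers to a full row/column vector or to a single value that must be broadcast. The following is a minimal, self-contained CPU sketch of that pointer-plus-flag pattern (plain C++, not the actual CUTLASS visitor; names and values are illustrative only):

#include <cstdio>
#include <vector>

// Stand-in for the pointer-plus-flag Arguments used by the visitors in this diff.
struct RowOrScalarArgs {
  const float* ptr_row;  // a device pointer in the real kernel; a host pointer here
  bool row_broadcast;    // true: ptr_row holds one scale per column
                         // false: ptr_row holds a single per-tensor scale
};

// What the epilogue conceptually does for output column j.
float scale_for_column(const RowOrScalarArgs& args, int j) {
  return args.row_broadcast ? args.ptr_row[j] : *args.ptr_row;
}

int main() {
  std::vector<float> per_channel{0.5f, 1.0f, 2.0f, 4.0f};
  float per_tensor = 0.25f;

  RowOrScalarArgs vec_args{per_channel.data(), /*row_broadcast=*/true};
  RowOrScalarArgs scalar_args{&per_tensor, /*row_broadcast=*/false};

  for (int j = 0; j < 4; ++j)
    std::printf("col %d: per-channel %.2f, per-tensor %.2f\n", j,
                scale_for_column(vec_args, j), scale_for_column(scalar_args, j));
}

Both granularities go through the same code path, which is why a single compiled kernel can serve per-tensor and per-channel/per-token quantization.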
@ -59,9 +66,11 @@ template<
>
struct VisitorRowOrScalarBroadcast {

// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_row = nullptr;
Element null_default = Element(0);
bool row_broadcast = true;
StrideMNL dRow = {};
};

@ -125,25 +134,25 @@ struct VisitorRowOrScalarBroadcast {
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);

if (params_ptr->ptr_row) {
if (params_ptr->row_broadcast) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard);
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = params_ptr->null_default;
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
}

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if(get<1>(coord_v(i)) < n)
{
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
@ -208,9 +217,11 @@ template<
>
struct VisitorColOrScalarBroadcast {

// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_col = nullptr;
Element null_default = Element(0);
bool col_broadcast = true;
StrideMNL dCol = {};
};

@ -230,11 +241,6 @@ struct VisitorColOrScalarBroadcast {

struct SharedStorage { };

// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast() { }

@ -267,7 +273,7 @@ struct VisitorColOrScalarBroadcast {
int m;

// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE void
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rCol);

@ -277,7 +283,7 @@ struct VisitorColOrScalarBroadcast {
pred(i) = get<0>(tC_cCol(i)) < m;
}

if (params_ptr->ptr_col) {
if (params_ptr->col_broadcast) {
// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
} else {
@ -286,8 +292,8 @@ struct VisitorColOrScalarBroadcast {

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if(pred(i)){
dst_v(i) = params_ptr->null_default;
if (pred(i)) {
dst_v(i) = *(params_ptr->ptr_col);
}
}
}

389
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp
Normal file
@ -0,0 +1,389 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
(cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast

// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct SharedStorage {
alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
};

// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};

using Params = Arguments;

template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}

template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}

template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}

CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast() { }

CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params),
smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }

Params params;
Element* smem_row;

CUTLASS_DEVICE bool
is_producer_load_needed() const {
return true;
}

CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}

CUTLASS_DEVICE bool
is_zero() const {
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
}

template <int EpiTiles, class GTensor, class STensor>
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
CUTLASS_DEVICE
ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
: gRow(cute::forward<GTensor>(gRow)),
sRow(cute::forward<STensor>(sRow)),
params(params) {}

GTensor gRow; // (CTA_M,CTA_N)
STensor sRow; // (CTA_M,CTA_N,PIPE)
Params const& params;

CUTLASS_DEVICE void
begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
if (params.ptr_row == nullptr) {
return;
}

if (issue_tma_load) {
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
// Issue the TMA bulk copy
auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
// Filter so we don't issue redundant copies over stride-0 modes
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
}
}
};

template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {

auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));

constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
cute::move(gRow), cute::move(sRow), params);
}

template <int EpiTiles, class RTensor, class STensor>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
: tCrRow(cute::forward<RTensor>(tCrRow)),
tCsRow(cute::forward<STensor>(tCsRow)),
params(params) {}

RTensor tCrRow; // (CPY,CPY_M,CPY_N)
STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
Params const& params;

CUTLASS_DEVICE void
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
if (!params.row_broadcast) {
fill(tCrRow, *(params.ptr_row));
return;
}

if (epi_m == 0) { // Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
}
}

template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_row;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tCrRow(epi_v * FragmentSize + i);
}

return frg_row;
}
};

template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)

constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
cute::move(tCrRow), cute::move(tCsRow), params);
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias

// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };

// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};

using Params = Arguments;

template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}

template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}

template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}

CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}

CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}

CUTLASS_DEVICE bool
is_zero() const {
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
}

CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast() { }

CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params) { }

Params params;

template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}

template<class GTensor, class RTensor>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
: tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
params(params) {}

GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;

CUTLASS_DEVICE void
begin() {
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col));
return;
}

// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_aligned(filter(tCgCol), filter(tCrCol));
}

template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}

return frg_col;
}

};

template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
cute::move(tCgCol), cute::move(tCrCol), params);
}
};

}
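To make the Stages comment on Sm90RowOrScalarBroadcast above concrete (the smem stage count must be at least ceil_div(StagesC, epi tiles per CTA tile) + 1 to avoid data races), here is a small compile-time illustration. The numbers are made up for the example and are not taken from this diff:

// Hypothetical configuration, for illustration only.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

constexpr int StagesC = 4;             // epilogue subtile load pipeline stages (assumed)
constexpr int EpiTilesPerCtaTile = 4;  // epilogue subtiles per CTA tile (assumed)

// Lower bound on the row-broadcast smem Stages needed to avoid data races.
constexpr int MinRowBcastStages = ceil_div(StagesC, EpiTilesPerCtaTile) + 1;
static_assert(MinRowBcastStages == 2, "ceil_div(4, 4) + 1 == 2");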
@ -22,7 +22,7 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "cutlass_visitor_2x_broadcast_epilogue.hpp"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on

@ -145,17 +145,11 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>();

// If A and B are quantized per-tensor, then these scale tensors are scalars,
// and they are passed in via the second argument.
using ScaleAArgs = typename Gemm::ScaleA::Arguments;
ScaleAArgs a_args = a_scales.numel() == 1
? ScaleAArgs{nullptr, a_scales.item<float>(), {}}
: ScaleAArgs{a_scales.data_ptr<float>(), {}, {}};

using ScaleBArgs = typename Gemm::ScaleB::Arguments;
ScaleBArgs b_args = b_scales.numel() == 1
? ScaleBArgs{nullptr, b_scales.item<float>(), {}}
: ScaleBArgs{b_scales.data_ptr<float>(), {}, {}};

ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};

typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args};
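This dispatcher hunk is where the device-resident scales matter in practice: the old code called a_scales.item<float>() in the per-tensor case, which copies the scale to the CPU (a synchronization point and a problem under CUDA graph capture), while the new code always passes the device pointer plus a broadcast flag. A hedged sketch of the new pattern follows; Args is a hypothetical stand-in for Gemm::ScaleA::Arguments / Gemm::ScaleB::Arguments, not the real CUTLASS type:

#include <torch/torch.h>

// Illustrative stand-in for the epilogue Arguments built above (assumption,
// simplified: the stride member is omitted).
struct Args {
  const float* ptr;
  bool broadcast;  // true for per-channel/per-token scales, false for per-tensor
};

Args make_scale_args(const torch::Tensor& scales) {
  // The scale tensor stays on the device: no scales.item<float>() and thus no
  // GPU-to-CPU copy, so this also works during CUDA graph capture.
  return Args{scales.data_ptr<float>(), scales.numel() != 1};
}

Either granularity produces a valid Args value from the same code path, which is what lets a single compiled kernel cover both cases.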
@ -18,11 +18,14 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/util/device_memory.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on

@ -65,7 +68,7 @@ struct cutlass_3x_gemm {

using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<1>, Int<0>, Int<0>>>;

@ -73,7 +76,7 @@ struct cutlass_3x_gemm {
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, float>;

using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;

@ -166,13 +169,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,

using ScaleA_Args = typename Gemm::ScaleA::Arguments;
using ScaleB_Args = typename Gemm::ScaleB::Arguments;
ScaleA_Args a_args = a_scales.numel() == 1
? ScaleA_Args{nullptr, a_scales.item<float>(), {}}
: ScaleA_Args{a_scales.data_ptr<float>(), {}, {}};

ScaleB_Args b_args = b_scales.numel() == 1
? ScaleB_Args{nullptr, b_scales.item<float>(), {}}
: ScaleB_Args{b_scales.data_ptr<float>(), {}, {}};
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};

args.epilogue.thread = {a_args, {b_args}};

@ -182,10 +181,11 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(gemm_op.can_implement(args));

size_t workspace_size = gemm_op.get_workspace_size(args);
TORCH_CHECK(workspace_size == 0);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, stream);

cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
CUTLASS_CHECK(status);
}

} // namespace
@ -59,7 +59,7 @@ exclude = [
]

[tool.codespell]
ignore-words-list = "dout, te, indicies"
ignore-words-list = "dout, te, indicies, subtile"
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"

[tool.isort]
17
setup.py
@ -187,19 +187,22 @@ class cmake_build_ext(build_ext):
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

targets = []
# Build all the extensions
for ext in self.extensions:
self.configure(ext)
targets.append(remove_prefix(ext.name, "vllm."))

ext_target_name = remove_prefix(ext.name, "vllm.")
num_jobs, _ = self.compute_num_jobs()
num_jobs, _ = self.compute_num_jobs()

build_args = [
'--build', '.', '--target', ext_target_name, '-j',
str(num_jobs)
]
build_args = [
"--build",
".",
f"-j={num_jobs}",
*[f"--target={name}" for name in targets],
]

subprocess.check_call(['cmake', *build_args], cwd=self.build_temp)
subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)


def _is_cuda() -> bool:
@ -207,14 +207,21 @@ class CutlassLayer(torch.nn.Module):
self.out_dtype)


def test_cutlass_cuda_graph():
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
m, n, k = 512, 512, 512

a = to_int8(torch.randn((m, k), device="cuda"))
b = to_int8(torch.randn((n, k), device="cuda").t())

scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10)
scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10)
m_a_scales = m if per_act_token else 1
n_b_scales = n if per_out_ch else 1

scale_a = (torch.randn(
(m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
scale_b = (torch.randn(
(1, n_b_scales), device="cuda", dtype=torch.float32) / 10)

# Construct a trivial model with a single layer that calls a CUTLASS kernel
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
@ -1,5 +1,5 @@
"""Token blocks."""
from typing import List
from typing import List, Optional

from vllm.utils import Device

@ -25,6 +25,7 @@ class LogicalTokenBlock:

self.token_ids = [_BLANK_TOKEN_ID] * block_size
self.num_tokens = 0
self.block_hash: Optional[int] = None

def is_empty(self) -> bool:
return self.num_tokens == 0
@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self.cross_block_tables: Dict[str, BlockTable] = {}

def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
return 0 if seq is None \
else len(seq.logical_token_blocks)
return 0 if seq is None else len(seq.logical_token_blocks)

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
@ -275,8 +274,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
cross_num_required_blocks = self._get_seq_num_required_blocks(
seq_group.get_encoder_seq())
num_required_blocks = self_num_required_blocks + \
cross_num_required_blocks
num_required_blocks = (self_num_required_blocks +
cross_num_required_blocks)

if self.block_sliding_window is not None:

@ -293,9 +292,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
else:
return AllocStatus.LATER

def _allocate_sequence(self, \
seq: Sequence, \
ref_count: int, \
def _allocate_sequence(self,
seq: Sequence,
ref_count: int,
is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
@ -328,10 +327,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# NOTE: Here we assume that all sequences in the group have the same
# decoder prompt.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
block_table: BlockTable = \
self._allocate_sequence(seq,
seq_group.num_seqs(),
is_encoder_decoder)
block_table: BlockTable = self._allocate_sequence(
seq, seq_group.num_seqs(), is_encoder_decoder)

# Assign the self-attention block tables for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
@ -368,6 +365,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
assert new_hash is not None, "Last block is not full."

# if new_hash is already in the cached table, then free last_block
# and return the cached version
@ -406,9 +404,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# content hash.
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
@ -553,18 +549,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# dict is efficient in lookup `if cpu_block in mapping`
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.cpu_allocator,
self.gpu_allocator,
mapping)
self.block_tables[seq.seq_id] = self._swap_block_table(
self.block_tables[seq.seq_id], self.cpu_allocator,
self.gpu_allocator, mapping)

if seq_group.is_encoder_decoder():
self.cross_block_tables[request_id] = \
self._swap_block_table(self.cross_block_tables[request_id],
self.cpu_allocator,
self.gpu_allocator,
mapping)
self.cross_block_tables[request_id] = self._swap_block_table(
self.cross_block_tables[request_id], self.cpu_allocator,
self.gpu_allocator, mapping)

return [(cpu_block.block_number, gpu_block.block_number)
for cpu_block, gpu_block in mapping.items()]
@ -580,18 +572,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# dict is efficient in lookup `if gpu_block in mapping`
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.gpu_allocator,
self.cpu_allocator,
mapping)
self.block_tables[seq.seq_id] = self._swap_block_table(
self.block_tables[seq.seq_id], self.gpu_allocator,
self.cpu_allocator, mapping)

if seq_group.is_encoder_decoder():
self.cross_block_tables[request_id] = \
self._swap_block_table(self.cross_block_tables[request_id],
self.gpu_allocator,
self.cpu_allocator,
mapping)
self.cross_block_tables[request_id] = self._swap_block_table(
self.cross_block_tables[request_id], self.gpu_allocator,
self.cpu_allocator, mapping)

return [(cpu_block.block_number, gpu_block.block_number)
for cpu_block, gpu_block in mapping.items()]
@ -41,46 +41,19 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):

# TODO: remove zero_point parameters once the configs given remove them

# Note on input/weight scales and zero_points
#
# When the scales have a single value, it is required that they be
# on the CPU for 2 reasons,
# 1. Performance:
# When the scales (input_scale/weight_scales) have only a single
# value, we perform a scalar broadcast of that value during the
# quant/dequant operations. The "quant" and the "gemm+dequant"
# kernels accept the Scalar by-value. These tensors are allocated
# on the CPU in order to avoid the GPU-to-CPU copy when passing
# by-value.
#
# 2. CUDA Graphs:
# CUDA Graphs don't support GPU-to-CPU copy operations during
# stream capture.
#
# TODO: zero-points are not supported yet. But we expect a similar
# pattern.

is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda"

input_scale = Parameter(torch.empty(1,
device="cpu",
dtype=torch.float32),
input_scale = Parameter(torch.empty(1, dtype=torch.float32),
requires_grad=False)
input_zero_point = Parameter(torch.empty(1,
device="cpu",
dtype=torch.int8),
input_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)

weight_scale = Parameter(torch.empty(weight_scale_dim,
device=weight_scale_device,
dtype=torch.float32),
requires_grad=False)
weight_zero_point = Parameter(torch.empty(1,
device="cpu",
dtype=torch.int8),
weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)

weight = Parameter(torch.empty(sum(output_partition_sizes),
@ -269,15 +269,24 @@ class Sequence:
return self.output_text[:-buffer_length] if truncate else (
self.output_text)

def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size

# Compute the number of tokens in the sequence
def hash_of_block(self, logical_idx: int) -> Optional[int]:
"""Return the hash of the block if it is full."""
# TODO: The current hashing function is O(L^2). We should optimize
# this in the future.
num_tokens = self.num_hashed_tokens_of_block(logical_idx)
hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
return hash((hashed_tokens, self.lora_int_id))
assert logical_idx < len(self.logical_token_blocks), (
f"logical_idx={logical_idx} is out of range for "
f"logical_token_blocks={len(self.logical_token_blocks)}")
block = self.logical_token_blocks[logical_idx]
if block.block_hash is not None:
return block.block_hash
if not block.is_full():
return None
num_hashed_tokens = self.num_hashed_tokens_of_block(logical_idx)
hashed_tokens = self.data.get_prefix_token_ids(num_hashed_tokens)
block_hash = hash((hashed_tokens, self.lora_int_id))
# Cache the block hash for future use.
block.block_hash = block_hash
return block_hash

def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size
@ -632,7 +641,7 @@ class SequenceGroupMetadata:
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
cross_block_table: Optional cross-attention block table associated