Mirror of https://github.com/vllm-project/vllm.git

Compare commits: v0.4.3 ... fix-hashin (4 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 1936d7bab0 |  |
|  | 996cf2de5c |  |
|  | 260d119e86 |  |
|  | a360ff80bb |  |
| @ -33,20 +33,27 @@ |  |
| // | // |
| // This file is a modified excerpt of | // This file is a modified excerpt of |
| // include/cutlass/epilogue/fusion/visitor_load.hpp from | // include/cutlass/epilogue/fusion/visitor_load.hpp from |
| // https://github.com/NVIDIA/cutlass It's beem modified to support either | // https://github.com/NVIDIA/cutlass v3.5.0 |
| // row/column or scalar broadcasting, like is already supported in CUTLASS 3.x. | // It has been modified to support either |
| // Important because this saves us a factor 4x on the number of kernels | // row/column or scalar broadcasting where the tensor being loaded from is |
| // compiled. | // always passed in via a device pointer. This lets one compiled kernel handle |
|  | // all cases of per-tensor or per-channel/per-token quantization. |
|  | // |
|  | // This interface also allows the scales to be passed in as tensors that |
|  | // consistently reside on the device, which avoids an issue with a previous |
|  | // implementation where scalars needed to be on the CPU since they |
|  | // were passed in via float values. This created a potential performance hazard |
|  | // if scales were initially on the device, and caused torch.compile graph |
|  | // breaks when moving scales to the CPU. |
| // | // |
| #pragma once | #pragma once |
|  |  |
|  | // Turn off clang-format for the entire file to keep it close to upstream |
| // clang-format off | // clang-format off |
|  |  |
| #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" | #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" |
| #include "cute/tensor.hpp" | #include "cute/tensor.hpp" |
|  |  |
| // clang-format on |  |
|  |  |
| namespace cutlass::epilogue::threadblock { | namespace cutlass::epilogue::threadblock { |
|  |  |
| using namespace cute; | using namespace cute; |
| @ -59,9 +66,11 @@ template< |  |
| > | > |
| struct VisitorRowOrScalarBroadcast { | struct VisitorRowOrScalarBroadcast { |
|  |  |
|  |   // This struct has been modified to have a bool indicating that ptr_row is a |
|  |   // scalar that must be broadcast. |
|   struct Arguments { |   struct Arguments { |
|     Element const* ptr_row = nullptr; |     Element const* ptr_row = nullptr; |
|     Element null_default = Element(0); |     bool row_broadcast = true; |
|     StrideMNL dRow = {}; |     StrideMNL dRow = {}; |
|   }; |   }; |
|  |  |
| @ -125,25 +134,25 @@ struct VisitorRowOrScalarBroadcast { |  |
|       auto coord_v = filter(tC_cRow); |       auto coord_v = filter(tC_cRow); |
|       auto dst_v = filter(tC_rRow); |       auto dst_v = filter(tC_rRow); |
|  |  |
|       if (params_ptr->ptr_row) { |       if (params_ptr->row_broadcast) { |
|         // In this case we are loading from a row vector and broadcasting |         // In this case we are loading from a row vector and broadcasting |
|         CUTLASS_PRAGMA_UNROLL |         CUTLASS_PRAGMA_UNROLL |
|         for (int i = 0; i < size(src_v); ++i) { |         for (int i = 0; i < size(src_v); ++i) { |
|           bool guard = get<1>(coord_v(i)) < n; |           bool guard = get<1>(coord_v(i)) < n; |
|           cutlass::arch::global_load<VecType, sizeof(VecType)>(dst_v(i), (void const*)&src_v(i), guard); |           cutlass::arch::global_load<VecType, sizeof(VecType)>( |
|  |               dst_v(i), (void const*)&src_v(i), guard); |
|         } |         } |
|       } else { |       } else { |
|         // In this case we are loading from a scalar and broadcasting |         // In this case we are loading from a scalar and broadcasting |
|         VecType filled_vec; |         VecType filled_vec; |
|         CUTLASS_PRAGMA_UNROLL |         CUTLASS_PRAGMA_UNROLL |
|         for (int i = 0; i < VecLength; i++) { |         for (int i = 0; i < VecLength; i++) { |
|           reinterpret_cast<Element*>(&filled_vec)[i] = params_ptr->null_default; |           reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row); |
|         } |         } |
|  |  |
|         CUTLASS_PRAGMA_UNROLL |         CUTLASS_PRAGMA_UNROLL |
|         for (int i = 0; i < size(src_v); ++i) { |         for (int i = 0; i < size(src_v); ++i) { |
|           if(get<1>(coord_v(i)) < n) |           if (get<1>(coord_v(i)) < n) { |
|           { |  |
|             dst_v(i) = filled_vec; |             dst_v(i) = filled_vec; |
|           } |           } |
|         } |         } |
| @ -208,9 +217,11 @@ template< |  |
| > | > |
| struct VisitorColOrScalarBroadcast { | struct VisitorColOrScalarBroadcast { |
|  |  |
|  |   // This struct has been modified to have a bool indicating that ptr_col is a |
|  |   // scalar that must be broadcast. |
|   struct Arguments { |   struct Arguments { |
|     Element const* ptr_col = nullptr; |     Element const* ptr_col = nullptr; |
|     Element null_default = Element(0); |     bool col_broadcast = true; |
|     StrideMNL dCol = {}; |     StrideMNL dCol = {}; |
|   }; |   }; |
|  |  |
| @ -230,11 +241,6 @@ struct VisitorColOrScalarBroadcast { |  |
|  |  |
|   struct SharedStorage { }; |   struct SharedStorage { }; |
|  |  |
|   // Global load type |  |
|   static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value; |  |
|   using VecType = uint_bit_t<cute::min(128, vec_bits)>; |  |
|   static int constexpr VecLength = sizeof(VecType) / sizeof(Element); |  |
|  |  |
|   CUTLASS_HOST_DEVICE |   CUTLASS_HOST_DEVICE |
|   VisitorColOrScalarBroadcast() { } |   VisitorColOrScalarBroadcast() { } |
|  |  |
| @ -267,7 +273,7 @@ struct VisitorColOrScalarBroadcast { |  |
|     int m; |     int m; |
|  |  |
|     // This function is modified from VisitorColBroadcast |     // This function is modified from VisitorColBroadcast |
|     CUTLASS_DEVICE void |     CUTLASS_DEVICE void |
|     begin_epilogue() { |     begin_epilogue() { |
|       clear(tC_rCol); |       clear(tC_rCol); |
|  |  |
| @ -277,7 +283,7 @@ struct VisitorColOrScalarBroadcast { |  |
|         pred(i) = get<0>(tC_cCol(i)) < m; |         pred(i) = get<0>(tC_cCol(i)) < m; |
|       } |       } |
|  |  |
|       if (params_ptr->ptr_col) { |       if (params_ptr->col_broadcast) { |
|         // In this case we are loading from a column vector and broadcasting |         // In this case we are loading from a column vector and broadcasting |
|         copy_if(pred, tC_gCol, tC_rCol); |         copy_if(pred, tC_gCol, tC_rCol); |
|       } else { |       } else { |
| @ -286,8 +292,8 @@ struct VisitorColOrScalarBroadcast { |  |
|  |  |
|         CUTLASS_PRAGMA_UNROLL |         CUTLASS_PRAGMA_UNROLL |
|         for (int i = 0; i < size(dst_v); ++i) { |         for (int i = 0; i < size(dst_v); ++i) { |
|           if(pred(i)){ |           if (pred(i)) { |
|              dst_v(i) = params_ptr->null_default; |             dst_v(i) = *(params_ptr->ptr_col); |
|           } |           } |
|         } |         } |
|       } |       } |
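The practical upshot of the new `Arguments` layout (a device pointer plus a broadcast flag, replacing the pointer-or-`null_default` pair) is how callers build it. The sketch below is not part of the diff: it assumes a `torch::Tensor` of scales as used by the dispatchers further down, and the helper name `make_scale_args` is hypothetical.

```cpp
#include <torch/torch.h>

// Hedged sketch: populate the revised Arguments (device pointer + flag) for
// either per-tensor or per-channel/per-token scales. ScaleArgs stands for the
// Arguments type of VisitorRowOrScalarBroadcast / VisitorColOrScalarBroadcast.
template <typename ScaleArgs>
ScaleArgs make_scale_args(torch::Tensor const& scales) {
  // numel == 1: a single per-tensor scale. The pointer still refers to device
  //             memory; with the flag false, the visitor fills its fragment
  //             from *ptr instead of loading a row/column vector.
  // numel  > 1: per-channel/per-token scales; with the flag true, the visitor
  //             loads the vector from global memory and broadcasts it.
  return ScaleArgs{scales.data_ptr<float>(), scales.numel() != 1, {}};
}
```

The old interface forced the scalar case to be passed by value (`scales.item<float>()`), a device-to-host copy; keeping everything behind a device pointer is what removes the performance hazard and the torch.compile graph break described in the header comment.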
							
								
								
									
csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp (new file, 389 lines)

| @ -0,0 +1,389 @@ |  |
|  | /*************************************************************************************************** | ||||||
|  |  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights | ||||||
|  |  *reserved. SPDX-License-Identifier: BSD-3-Clause | ||||||
|  |  * | ||||||
|  |  * Redistribution and use in source and binary forms, with or without | ||||||
|  |  * modification, are permitted provided that the following conditions are met: | ||||||
|  |  * | ||||||
|  |  * 1. Redistributions of source code must retain the above copyright notice, | ||||||
|  |  *this list of conditions and the following disclaimer. | ||||||
|  |  * | ||||||
|  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | ||||||
|  |  * this list of conditions and the following disclaimer in the documentation | ||||||
|  |  * and/or other materials provided with the distribution. | ||||||
|  |  * | ||||||
|  |  * 3. Neither the name of the copyright holder nor the names of its | ||||||
|  |  * contributors may be used to endorse or promote products derived from | ||||||
|  |  * this software without specific prior written permission. | ||||||
|  |  * | ||||||
|  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||||
|  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||||
|  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||||||
|  |  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||||||
|  |  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||||||
|  |  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||||||
|  |  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||||||
|  |  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||||||
|  |  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||||||
|  |  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||||||
|  |  *POSSIBILITY OF SUCH DAMAGE. | ||||||
|  |  * | ||||||
|  |  **************************************************************************************************/ | ||||||
|  |  | ||||||
|  | // | ||||||
|  | // This file is a modified excerpt of | ||||||
|  | // include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp | ||||||
|  | // from https://github.com/NVIDIA/cutlass v3.5.0 | ||||||
|  | // It has been modified to support either row/column or scalar broadcasting | ||||||
|  | // where the tensor being loaded from is always passed in via a device pointer. | ||||||
|  | // This lets one compiled kernel handle all cases of per-tensor or | ||||||
|  | // per-channel/per-token quantization. | ||||||
|  | // | ||||||
|  | // This interface also allows the scales to be passed in as tensors that | ||||||
|  | // consistently reside on the device, which avoids an issue with a previous | ||||||
|  | // implementation where scalars needed to be on the CPU since they | ||||||
|  | // were passed in via float values. This created a potential performance hazard | ||||||
|  | // if scales were initially on the device, and caused torch.compile graph | ||||||
|  | // breaks when moving scales to the CPU. | ||||||
|  | // | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | // Turn off clang-format for the entire file to keep it close to upstream | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | #include "cutlass/cutlass.h" | ||||||
|  | #include "cutlass/arch/barrier.h" | ||||||
|  |  | ||||||
|  | #include "cute/tensor.hpp" | ||||||
|  | #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" | ||||||
|  |  | ||||||
|  | namespace cutlass::epilogue::fusion { | ||||||
|  |  | ||||||
|  | using namespace cute; | ||||||
|  | using namespace detail; | ||||||
|  |  | ||||||
|  | // Row vector broadcast | ||||||
|  | template< | ||||||
|  |   // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least | ||||||
|  |   // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races | ||||||
|  |   int Stages, | ||||||
|  |   class CtaTileShapeMNK, | ||||||
|  |   class Element, | ||||||
|  |   class StrideMNL = Stride<_0,_1,_0>, | ||||||
|  |   int Alignment = 128 / sizeof_bits_v<Element> | ||||||
|  | > | ||||||
|  | struct Sm90RowOrScalarBroadcast { | ||||||
|  |   static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet"); | ||||||
|  |   static_assert( | ||||||
|  |     (cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias | ||||||
|  |     (cute::is_same_v<StrideMNL, Stride<_0,_1,int>>));  // batched row vector broadcast | ||||||
|  |  | ||||||
|  |   // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem | ||||||
|  |   struct SharedStorage { | ||||||
|  |     alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   // This struct has been modified to have a bool indicating that ptr_row is a  | ||||||
|  |   // scalar that must be broadcast, instead of containing a scalar that is  | ||||||
|  |   // valid if ptr_row is null. | ||||||
|  |   struct Arguments { | ||||||
|  |     Element const* ptr_row = nullptr; | ||||||
|  |     bool row_broadcast = true; | ||||||
|  |     StrideMNL dRow = {}; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   using Params = Arguments; | ||||||
|  |  | ||||||
|  |   template <class ProblemShape> | ||||||
|  |   static constexpr Params | ||||||
|  |   to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { | ||||||
|  |     return args; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <class ProblemShape> | ||||||
|  |   static size_t | ||||||
|  |   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { | ||||||
|  |     return 0; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <class ProblemShape> | ||||||
|  |   static cutlass::Status | ||||||
|  |   initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, | ||||||
|  |     CudaHostAdapter* cuda_adapter = nullptr) { | ||||||
|  |     return cutlass::Status::kSuccess; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   CUTLASS_HOST_DEVICE | ||||||
|  |   Sm90RowOrScalarBroadcast() { } | ||||||
|  |  | ||||||
|  |   CUTLASS_HOST_DEVICE | ||||||
|  |   Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) | ||||||
|  |       : params(params), | ||||||
|  |         smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { } | ||||||
|  |  | ||||||
|  |   Params params; | ||||||
|  |   Element* smem_row; | ||||||
|  |  | ||||||
|  |   CUTLASS_DEVICE bool | ||||||
|  |   is_producer_load_needed() const { | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   CUTLASS_DEVICE bool | ||||||
|  |   is_C_load_needed() const { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   CUTLASS_DEVICE bool | ||||||
|  |   is_zero() const { | ||||||
|  |     return (!params.row_broadcast && *(params.ptr_row) == Element(0)); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <int EpiTiles, class GTensor, class STensor> | ||||||
|  |   struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { | ||||||
|  |     CUTLASS_DEVICE | ||||||
|  |     ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) | ||||||
|  |       : gRow(cute::forward<GTensor>(gRow)), | ||||||
|  |         sRow(cute::forward<STensor>(sRow)), | ||||||
|  |         params(params) {} | ||||||
|  |  | ||||||
|  |     GTensor gRow;                                                                                 // (CTA_M,CTA_N) | ||||||
|  |     STensor sRow;                                                                                 // (CTA_M,CTA_N,PIPE) | ||||||
|  |     Params const& params; | ||||||
|  |  | ||||||
|  |     CUTLASS_DEVICE void | ||||||
|  |     begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { | ||||||
|  |       if (params.ptr_row == nullptr) { | ||||||
|  |         return; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       if (issue_tma_load) { | ||||||
|  |         // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size | ||||||
|  |         constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8; | ||||||
|  |         cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); | ||||||
|  |         // Issue the TMA bulk copy | ||||||
|  |         auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr); | ||||||
|  |         // Filter so we don't issue redundant copies over stride-0 modes | ||||||
|  |         int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; | ||||||
|  |         copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   template <class... Args> | ||||||
|  |   CUTLASS_DEVICE auto | ||||||
|  |   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) { | ||||||
|  |  | ||||||
|  |     auto [M, N, K, L] = args.problem_shape_mnkl; | ||||||
|  |     auto [m, n, k, l] = args.tile_coord_mnkl; | ||||||
|  |     Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); | ||||||
|  |     Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l));            // (CTA_M,CTA_N) | ||||||
|  |     Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                            // (CTA_M,CTA_N,PIPE) | ||||||
|  |                     make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), | ||||||
|  |                     make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); | ||||||
|  |  | ||||||
|  |     constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; | ||||||
|  |     return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>( | ||||||
|  |       cute::move(gRow), cute::move(sRow), params); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <int EpiTiles, class RTensor, class STensor> | ||||||
|  |   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { | ||||||
|  |     CUTLASS_DEVICE | ||||||
|  |     ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) | ||||||
|  |       : tCrRow(cute::forward<RTensor>(tCrRow)), | ||||||
|  |         tCsRow(cute::forward<STensor>(tCsRow)), | ||||||
|  |         params(params) {} | ||||||
|  |  | ||||||
|  |     RTensor tCrRow;                                                               // (CPY,CPY_M,CPY_N) | ||||||
|  |     STensor tCsRow;                                                               // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) | ||||||
|  |     Params const& params; | ||||||
|  |  | ||||||
|  |     CUTLASS_DEVICE void | ||||||
|  |     previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { | ||||||
|  |       if (!params.row_broadcast) { | ||||||
|  |         fill(tCrRow, *(params.ptr_row)); | ||||||
|  |         return; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       if (epi_m == 0) { // Assumes M-major subtile loop | ||||||
|  |         // Filter so we don't issue redundant copies over stride-0 modes | ||||||
|  |         // (only works if 0-strides are in same location, which is by construction) | ||||||
|  |         int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; | ||||||
|  |         copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     template <typename ElementAccumulator, int FragmentSize> | ||||||
|  |     CUTLASS_DEVICE Array<Element, FragmentSize> | ||||||
|  |     visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) { | ||||||
|  |       Array<Element, FragmentSize> frg_row; | ||||||
|  |  | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for (int i = 0; i < FragmentSize; ++i) { | ||||||
|  |         frg_row[i] = tCrRow(epi_v * FragmentSize + i); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       return frg_row; | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   template < | ||||||
|  |     bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy | ||||||
|  |     class... Args | ||||||
|  |   > | ||||||
|  |   CUTLASS_DEVICE auto | ||||||
|  |   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) { | ||||||
|  |  | ||||||
|  |     Tensor sRow = make_tensor(make_smem_ptr(smem_row),                                            // (CTA_M,CTA_N,PIPE) | ||||||
|  |                     make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), | ||||||
|  |                     make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); | ||||||
|  |     Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>(                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) | ||||||
|  |                       sRow, args.epi_tile, args.tiled_copy, args.thread_idx); | ||||||
|  |     Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow));                                           // (CPY,CPY_M,CPY_N) | ||||||
|  |  | ||||||
|  |     constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; | ||||||
|  |     return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>( | ||||||
|  |       cute::move(tCrRow), cute::move(tCsRow), params); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | ///////////////////////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|  | // Column vector broadcast | ||||||
|  | template< | ||||||
|  |   int Stages, | ||||||
|  |   class CtaTileShapeMNK, | ||||||
|  |   class Element, | ||||||
|  |   class StrideMNL = Stride<_1,_0,_0>, | ||||||
|  |   int Alignment = 128 / sizeof_bits_v<Element> | ||||||
|  | > | ||||||
|  | struct Sm90ColOrScalarBroadcast { | ||||||
|  |   static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); | ||||||
|  |   static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet"); | ||||||
|  |   static_assert( | ||||||
|  |     (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias | ||||||
|  |     (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));  // batched col vector broadcast, e.g. batched per-row bias | ||||||
|  |  | ||||||
|  |   // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem | ||||||
|  |   struct SharedStorage { }; | ||||||
|  |  | ||||||
|  |   // This struct has been modified to have a bool indicating that ptr_col is a  | ||||||
|  |   // scalar that must be broadcast, instead of containing a scalar that is  | ||||||
|  |   // valid if ptr_col is null. | ||||||
|  |   struct Arguments { | ||||||
|  |     Element const* ptr_col = nullptr; | ||||||
|  |     bool col_broadcast = true; | ||||||
|  |     StrideMNL dCol = {}; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   using Params = Arguments; | ||||||
|  |  | ||||||
|  |   template <class ProblemShape> | ||||||
|  |   static constexpr Params | ||||||
|  |   to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { | ||||||
|  |     return args; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <class ProblemShape> | ||||||
|  |   static size_t | ||||||
|  |   get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { | ||||||
|  |     return 0; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <class ProblemShape> | ||||||
|  |   static cutlass::Status | ||||||
|  |   initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, | ||||||
|  |     CudaHostAdapter* cuda_adapter = nullptr) { | ||||||
|  |     return cutlass::Status::kSuccess; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   CUTLASS_DEVICE bool | ||||||
|  |   is_producer_load_needed() const { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   CUTLASS_DEVICE bool | ||||||
|  |   is_C_load_needed() const { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   CUTLASS_DEVICE bool | ||||||
|  |   is_zero() const { | ||||||
|  |     return (!params.col_broadcast && *(params.ptr_col) == Element(0)); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   CUTLASS_HOST_DEVICE | ||||||
|  |   Sm90ColOrScalarBroadcast() { } | ||||||
|  |  | ||||||
|  |   CUTLASS_HOST_DEVICE | ||||||
|  |   Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) | ||||||
|  |       : params(params) { } | ||||||
|  |  | ||||||
|  |   Params params; | ||||||
|  |  | ||||||
|  |   template <class... Args> | ||||||
|  |   CUTLASS_DEVICE auto | ||||||
|  |   get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) { | ||||||
|  |     return EmptyProducerLoadCallbacks{}; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template<class GTensor, class RTensor> | ||||||
|  |   struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { | ||||||
|  |     CUTLASS_DEVICE | ||||||
|  |     ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) | ||||||
|  |       : tCgCol(cute::forward<GTensor>(tCgCol)), | ||||||
|  |         tCrCol(cute::forward<RTensor>(tCrCol)), | ||||||
|  |         params(params) {} | ||||||
|  |  | ||||||
|  |     GTensor tCgCol;                                                                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) | ||||||
|  |     RTensor tCrCol;                                                                    // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) | ||||||
|  |     Params const& params; | ||||||
|  |  | ||||||
|  |     CUTLASS_DEVICE void | ||||||
|  |     begin() { | ||||||
|  |       if (!params.col_broadcast) { | ||||||
|  |         fill(tCrCol, *(params.ptr_col)); | ||||||
|  |         return; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Filter so we don't issue redundant copies over stride-0 modes | ||||||
|  |       // (only works if 0-strides are in same location, which is by construction) | ||||||
|  |       copy_aligned(filter(tCgCol), filter(tCrCol)); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     template <typename ElementAccumulator, int FragmentSize> | ||||||
|  |     CUTLASS_DEVICE Array<Element, FragmentSize> | ||||||
|  |     visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) { | ||||||
|  |       Array<Element, FragmentSize> frg_col; | ||||||
|  |       Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); | ||||||
|  |  | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for (int i = 0; i < FragmentSize; ++i) { | ||||||
|  |         frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       return frg_col; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   template < | ||||||
|  |     bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy | ||||||
|  |     class... Args | ||||||
|  |   > | ||||||
|  |   CUTLASS_DEVICE auto | ||||||
|  |   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) { | ||||||
|  |  | ||||||
|  |     auto [M, N, K, L] = args.problem_shape_mnkl; | ||||||
|  |     Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); | ||||||
|  |     Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(                         // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) | ||||||
|  |       mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); | ||||||
|  |     Tensor tCrCol = make_tensor_like(tCgCol);                                          // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) | ||||||
|  |  | ||||||
|  |     return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>( | ||||||
|  |       cute::move(tCgCol), cute::move(tCrCol), params); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | } | ||||||
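As a usage sketch of the new file (not part of the diff, and mirroring the `cutlass_3x_gemm` changes further down): the OrScalar nodes slot into the epilogue visitor tree in place of the stock `Sm90ColBroadcast`/`Sm90RowBroadcast`. The wrapper struct name is hypothetical, and the `EpilogueDescriptor`/`ScaleBDescriptor` parameters stand in for the descriptors defined in the GEMM source shown later in this diff.

```cpp
#include "broadcast_load_epilogue_c3x.hpp"  // the new file above

using namespace cute;  // for Stride, Int

// Hedged sketch: drop-in replacement of the stock SM90 broadcast fusion nodes
// with the OrScalar variants; template parameters are assumed placeholders.
template <typename EpilogueDescriptor, typename ScaleBDescriptor>
struct ScaledEpilogueSketch {
  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
      Stride<Int<1>, Int<0>, Int<0>>>;

  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
      typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
};
```

Both nodes take the same pointer-plus-flag `Arguments` as the 2.x visitors, so per-tensor and per-channel/per-token scales share one compiled kernel on SM90 as well.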
| @ -22,7 +22,7 @@ | |||||||
| #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" | #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" | ||||||
| #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" | #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" | ||||||
|  |  | ||||||
| #include "cutlass_visitor_2x_broadcast_epilogue.hpp" | #include "broadcast_load_epilogue_c2x.hpp" | ||||||
| #include "common.hpp" | #include "common.hpp" | ||||||
| // clang-format on | // clang-format on | ||||||
|  |  | ||||||
| @ -145,17 +145,11 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, | |||||||
|   auto a_scales_ptr = a_scales.data_ptr<float>(); |   auto a_scales_ptr = a_scales.data_ptr<float>(); | ||||||
|   auto b_scales_ptr = b_scales.data_ptr<float>(); |   auto b_scales_ptr = b_scales.data_ptr<float>(); | ||||||
|  |  | ||||||
|   // If A and B are quantized per-tensor, then these scale tensors are scalars, |  | ||||||
|   // and they are passed in via the second argument. |  | ||||||
|   using ScaleAArgs = typename Gemm::ScaleA::Arguments; |   using ScaleAArgs = typename Gemm::ScaleA::Arguments; | ||||||
|   ScaleAArgs a_args = a_scales.numel() == 1 |  | ||||||
|                           ? ScaleAArgs{nullptr, a_scales.item<float>(), {}} |  | ||||||
|                           : ScaleAArgs{a_scales.data_ptr<float>(), {}, {}}; |  | ||||||
|  |  | ||||||
|   using ScaleBArgs = typename Gemm::ScaleB::Arguments; |   using ScaleBArgs = typename Gemm::ScaleB::Arguments; | ||||||
|   ScaleBArgs b_args = b_scales.numel() == 1 |  | ||||||
|                           ? ScaleBArgs{nullptr, b_scales.item<float>(), {}} |   ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}}; | ||||||
|                           : ScaleBArgs{b_scales.data_ptr<float>(), {}, {}}; |   ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}}; | ||||||
|  |  | ||||||
|   typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args}; |   typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args}; | ||||||
|  |  | ||||||
|  | |||||||
| @ -18,11 +18,14 @@ | |||||||
| #include "cute/atom/mma_atom.hpp" | #include "cute/atom/mma_atom.hpp" | ||||||
| #include "cutlass/numeric_types.h" | #include "cutlass/numeric_types.h" | ||||||
|  |  | ||||||
|  | #include "cutlass/util/device_memory.h" | ||||||
|  |  | ||||||
| #include "cutlass/gemm/device/gemm_universal_adapter.h" | #include "cutlass/gemm/device/gemm_universal_adapter.h" | ||||||
| #include "cutlass/gemm/kernel/gemm_universal.hpp" | #include "cutlass/gemm/kernel/gemm_universal.hpp" | ||||||
| #include "cutlass/epilogue/collective/collective_builder.hpp" | #include "cutlass/epilogue/collective/collective_builder.hpp" | ||||||
| #include "cutlass/gemm/collective/collective_builder.hpp" | #include "cutlass/gemm/collective/collective_builder.hpp" | ||||||
|  |  | ||||||
|  | #include "broadcast_load_epilogue_c3x.hpp" | ||||||
| #include "common.hpp" | #include "common.hpp" | ||||||
| // clang-format on | // clang-format on | ||||||
|  |  | ||||||
| @ -65,7 +68,7 @@ struct cutlass_3x_gemm { | |||||||
|  |  | ||||||
|   using Accum = cutlass::epilogue::fusion::Sm90AccFetch; |   using Accum = cutlass::epilogue::fusion::Sm90AccFetch; | ||||||
|  |  | ||||||
|   using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< |   using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< | ||||||
|       0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, |       0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, | ||||||
|       Stride<Int<1>, Int<0>, Int<0>>>; |       Stride<Int<1>, Int<0>, Int<0>>>; | ||||||
|  |  | ||||||
| @ -73,7 +76,7 @@ struct cutlass_3x_gemm { | |||||||
|       cutlass::epilogue::collective::detail::RowBroadcastDescriptor< |       cutlass::epilogue::collective::detail::RowBroadcastDescriptor< | ||||||
|           EpilogueDescriptor, float>; |           EpilogueDescriptor, float>; | ||||||
|  |  | ||||||
|   using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< |   using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< | ||||||
|       ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, |       ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, | ||||||
|       typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>; |       typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>; | ||||||
|  |  | ||||||
| @ -166,13 +169,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, | |||||||
|  |  | ||||||
|   using ScaleA_Args = typename Gemm::ScaleA::Arguments; |   using ScaleA_Args = typename Gemm::ScaleA::Arguments; | ||||||
|   using ScaleB_Args = typename Gemm::ScaleB::Arguments; |   using ScaleB_Args = typename Gemm::ScaleB::Arguments; | ||||||
|   ScaleA_Args a_args = a_scales.numel() == 1 |  | ||||||
|                            ? ScaleA_Args{nullptr, a_scales.item<float>(), {}} |  | ||||||
|                            : ScaleA_Args{a_scales.data_ptr<float>(), {}, {}}; |  | ||||||
|  |  | ||||||
|   ScaleB_Args b_args = b_scales.numel() == 1 |   ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}}; | ||||||
|                            ? ScaleB_Args{nullptr, b_scales.item<float>(), {}} |   ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}}; | ||||||
|                            : ScaleB_Args{b_scales.data_ptr<float>(), {}, {}}; |  | ||||||
|  |  | ||||||
|   args.epilogue.thread = {a_args, {b_args}}; |   args.epilogue.thread = {a_args, {b_args}}; | ||||||
|  |  | ||||||
| @ -182,10 +181,11 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, | |||||||
|   CUTLASS_CHECK(gemm_op.can_implement(args)); |   CUTLASS_CHECK(gemm_op.can_implement(args)); | ||||||
|  |  | ||||||
|   size_t workspace_size = gemm_op.get_workspace_size(args); |   size_t workspace_size = gemm_op.get_workspace_size(args); | ||||||
|   TORCH_CHECK(workspace_size == 0); |   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); | ||||||
|  |  | ||||||
|   auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); |   auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); | ||||||
|   cutlass::Status status = gemm_op.run(args, stream); |  | ||||||
|  |   cutlass::Status status = gemm_op.run(args, workspace.get(), stream); | ||||||
|   CUTLASS_CHECK(status); |   CUTLASS_CHECK(status); | ||||||
| } | } | ||||||
| }  // namespace | }  // namespace | ||||||
|  | |||||||
| @ -59,7 +59,7 @@ exclude = [ | |||||||
| ] | ] | ||||||
|  |  | ||||||
| [tool.codespell] | [tool.codespell] | ||||||
| ignore-words-list = "dout, te, indicies" | ignore-words-list = "dout, te, indicies, subtile" | ||||||
| skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" | skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" | ||||||
|  |  | ||||||
| [tool.isort] | [tool.isort] | ||||||
|  | |||||||
							
								
								
									
setup.py (17 lines changed)
							| @ -187,19 +187,22 @@ class cmake_build_ext(build_ext): | |||||||
|         if not os.path.exists(self.build_temp): |         if not os.path.exists(self.build_temp): | ||||||
|             os.makedirs(self.build_temp) |             os.makedirs(self.build_temp) | ||||||
|  |  | ||||||
|  |         targets = [] | ||||||
|         # Build all the extensions |         # Build all the extensions | ||||||
|         for ext in self.extensions: |         for ext in self.extensions: | ||||||
|             self.configure(ext) |             self.configure(ext) | ||||||
|  |             targets.append(remove_prefix(ext.name, "vllm.")) | ||||||
|  |  | ||||||
|             ext_target_name = remove_prefix(ext.name, "vllm.") |         num_jobs, _ = self.compute_num_jobs() | ||||||
|             num_jobs, _ = self.compute_num_jobs() |  | ||||||
|  |  | ||||||
|             build_args = [ |         build_args = [ | ||||||
|                 '--build', '.', '--target', ext_target_name, '-j', |             "--build", | ||||||
|                 str(num_jobs) |             ".", | ||||||
|             ] |             f"-j={num_jobs}", | ||||||
|  |             *[f"--target={name}" for name in targets], | ||||||
|  |         ] | ||||||
|  |  | ||||||
|             subprocess.check_call(['cmake', *build_args], cwd=self.build_temp) |         subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _is_cuda() -> bool: | def _is_cuda() -> bool: | ||||||
|  | |||||||
| @ -207,14 +207,21 @@ class CutlassLayer(torch.nn.Module): | |||||||
|                                         self.out_dtype) |                                         self.out_dtype) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_cutlass_cuda_graph(): | @pytest.mark.parametrize("per_act_token", [True, False]) | ||||||
|  | @pytest.mark.parametrize("per_out_ch", [True, False]) | ||||||
|  | def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): | ||||||
|     m, n, k = 512, 512, 512 |     m, n, k = 512, 512, 512 | ||||||
|  |  | ||||||
|     a = to_int8(torch.randn((m, k), device="cuda")) |     a = to_int8(torch.randn((m, k), device="cuda")) | ||||||
|     b = to_int8(torch.randn((n, k), device="cuda").t()) |     b = to_int8(torch.randn((n, k), device="cuda").t()) | ||||||
|  |  | ||||||
|     scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10) |     m_a_scales = m if per_act_token else 1 | ||||||
|     scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10) |     n_b_scales = n if per_out_ch else 1 | ||||||
|  |  | ||||||
|  |     scale_a = (torch.randn( | ||||||
|  |         (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10) | ||||||
|  |     scale_b = (torch.randn( | ||||||
|  |         (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) | ||||||
|  |  | ||||||
|     # Construct a trivial model with a single layer that calls a CUTLASS kernel |     # Construct a trivial model with a single layer that calls a CUTLASS kernel | ||||||
|     model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16) |     model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16) | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| """Token blocks.""" | """Token blocks.""" | ||||||
| from typing import List | from typing import List, Optional | ||||||
|  |  | ||||||
| from vllm.utils import Device | from vllm.utils import Device | ||||||
|  |  | ||||||
| @ -25,6 +25,7 @@ class LogicalTokenBlock: | |||||||
|  |  | ||||||
|         self.token_ids = [_BLANK_TOKEN_ID] * block_size |         self.token_ids = [_BLANK_TOKEN_ID] * block_size | ||||||
|         self.num_tokens = 0 |         self.num_tokens = 0 | ||||||
|  |         self.block_hash: Optional[int] = None | ||||||
|  |  | ||||||
|     def is_empty(self) -> bool: |     def is_empty(self) -> bool: | ||||||
|         return self.num_tokens == 0 |         return self.num_tokens == 0 | ||||||
|  | |||||||
| @ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): | |||||||
|         self.cross_block_tables: Dict[str, BlockTable] = {} |         self.cross_block_tables: Dict[str, BlockTable] = {} | ||||||
|  |  | ||||||
|     def _get_seq_num_required_blocks(self, seq: Sequence) -> int: |     def _get_seq_num_required_blocks(self, seq: Sequence) -> int: | ||||||
|         return 0 if seq is None \ |         return 0 if seq is None else len(seq.logical_token_blocks) | ||||||
|             else len(seq.logical_token_blocks) |  | ||||||
|  |  | ||||||
|     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: |     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: | ||||||
|         # FIXME(woosuk): Here we assume that all sequences in the group share |         # FIXME(woosuk): Here we assume that all sequences in the group share | ||||||
| @ -275,8 +274,8 @@ class BlockSpaceManagerV1(BlockSpaceManager): | |||||||
|             seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) |             seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) | ||||||
|         cross_num_required_blocks = self._get_seq_num_required_blocks( |         cross_num_required_blocks = self._get_seq_num_required_blocks( | ||||||
|             seq_group.get_encoder_seq()) |             seq_group.get_encoder_seq()) | ||||||
|         num_required_blocks = self_num_required_blocks + \ |         num_required_blocks = (self_num_required_blocks + | ||||||
|                               cross_num_required_blocks |                                cross_num_required_blocks) | ||||||
|  |  | ||||||
|         if self.block_sliding_window is not None: |         if self.block_sliding_window is not None: | ||||||
|  |  | ||||||
| @ -293,9 +292,9 @@ class BlockSpaceManagerV1(BlockSpaceManager): | |||||||
|         else: |         else: | ||||||
|             return AllocStatus.LATER |             return AllocStatus.LATER | ||||||
|  |  | ||||||
|     def _allocate_sequence(self, \ |     def _allocate_sequence(self, | ||||||
|                            seq: Sequence, \ |                            seq: Sequence, | ||||||
|                            ref_count: int, \ |                            ref_count: int, | ||||||
|                            is_encoder_decoder: bool = True) -> BlockTable: |                            is_encoder_decoder: bool = True) -> BlockTable: | ||||||
|         # Allocate new physical token blocks that will store the prompt tokens. |         # Allocate new physical token blocks that will store the prompt tokens. | ||||||
|         num_prompt_blocks = len(seq.logical_token_blocks) |         num_prompt_blocks = len(seq.logical_token_blocks) | ||||||
| @ -328,10 +327,8 @@ class BlockSpaceManagerV1(BlockSpaceManager): | |||||||
|         # NOTE: Here we assume that all sequences in the group have the same |         # NOTE: Here we assume that all sequences in the group have the same | ||||||
|         # decoder prompt. |         # decoder prompt. | ||||||
|         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] |         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] | ||||||
|         block_table: BlockTable = \ |         block_table: BlockTable = self._allocate_sequence( | ||||||
|             self._allocate_sequence(seq, |             seq, seq_group.num_seqs(), is_encoder_decoder) | ||||||
|                                     seq_group.num_seqs(), |  | ||||||
|                                     is_encoder_decoder) |  | ||||||
|  |  | ||||||
|         # Assign the self-attention block tables for each sequence. |         # Assign the self-attention block tables for each sequence. | ||||||
|         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): |         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): | ||||||
| @ -368,6 +365,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): | |||||||
|         # Compute a new hash for the block so that it can be shared by other |         # Compute a new hash for the block so that it can be shared by other | ||||||
|         # Sequences |         # Sequences | ||||||
|         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) |         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) | ||||||
|  |         assert new_hash is not None, "Last block is not full." | ||||||
|  |  | ||||||
|         # if new_hash is already in the cached table, then free last_block |         # if new_hash is already in the cached table, then free last_block | ||||||
|         # and return the cached version |         # and return the cached version | ||||||
| @ -406,9 +404,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): | |||||||
|         # content hash. |         # content hash. | ||||||
|         if not self.enable_caching: |         if not self.enable_caching: | ||||||
|             return self.gpu_allocator.allocate() |             return self.gpu_allocator.allocate() | ||||||
|         block_hash: Optional[int] = None |         block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) | ||||||
|         if (self._is_last_block_full(seq)): |  | ||||||
|             block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) |  | ||||||
|         num_hashed_tokens = seq.num_hashed_tokens_of_block( |         num_hashed_tokens = seq.num_hashed_tokens_of_block( | ||||||
|             len(seq.logical_token_blocks) - 1) |             len(seq.logical_token_blocks) - 1) | ||||||
|  |  | ||||||
| @ -553,18 +549,14 @@ class BlockSpaceManagerV1(BlockSpaceManager): | |||||||
|         # dict is efficient in lookup `if cpu_block in mapping` |         # dict is efficient in lookup `if cpu_block in mapping` | ||||||
|         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} |         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} | ||||||
|         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): |         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): | ||||||
|             self.block_tables[seq.seq_id] = \ |             self.block_tables[seq.seq_id] = self._swap_block_table( | ||||||
|                 self._swap_block_table(self.block_tables[seq.seq_id], |                 self.block_tables[seq.seq_id], self.cpu_allocator, | ||||||
|                                        self.cpu_allocator, |                 self.gpu_allocator, mapping) | ||||||
|                                        self.gpu_allocator, |  | ||||||
|                                        mapping) |  | ||||||
|  |  | ||||||
|         if seq_group.is_encoder_decoder(): |         if seq_group.is_encoder_decoder(): | ||||||
|             self.cross_block_tables[request_id] = \ |             self.cross_block_tables[request_id] = self._swap_block_table( | ||||||
|                 self._swap_block_table(self.cross_block_tables[request_id], |                 self.cross_block_tables[request_id], self.cpu_allocator, | ||||||
|                                        self.cpu_allocator, |                 self.gpu_allocator, mapping) | ||||||
|                                        self.gpu_allocator, |  | ||||||
|                                        mapping) |  | ||||||
|  |  | ||||||
|         return [(cpu_block.block_number, gpu_block.block_number) |         return [(cpu_block.block_number, gpu_block.block_number) | ||||||
|                 for cpu_block, gpu_block in mapping.items()] |                 for cpu_block, gpu_block in mapping.items()] | ||||||
| @ -580,18 +572,14 @@ class BlockSpaceManagerV1(BlockSpaceManager): | |||||||
|         # dict is efficient in lookup `if gpu_block in mapping` |         # dict is efficient in lookup `if gpu_block in mapping` | ||||||
|         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} |         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} | ||||||
|         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): |         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): | ||||||
|             self.block_tables[seq.seq_id] = \ |             self.block_tables[seq.seq_id] = self._swap_block_table( | ||||||
|                 self._swap_block_table(self.block_tables[seq.seq_id], |                 self.block_tables[seq.seq_id], self.gpu_allocator, | ||||||
|                                        self.gpu_allocator, |                 self.cpu_allocator, mapping) | ||||||
|                                        self.cpu_allocator, |  | ||||||
|                                        mapping) |  | ||||||
|  |  | ||||||
|         if seq_group.is_encoder_decoder(): |         if seq_group.is_encoder_decoder(): | ||||||
|             self.cross_block_tables[request_id] = \ |             self.cross_block_tables[request_id] = self._swap_block_table( | ||||||
|                 self._swap_block_table(self.cross_block_tables[request_id], |                 self.cross_block_tables[request_id], self.gpu_allocator, | ||||||
|                                        self.gpu_allocator, |                 self.cpu_allocator, mapping) | ||||||
|                                        self.cpu_allocator, |  | ||||||
|                                        mapping) |  | ||||||
|  |  | ||||||
|         return [(cpu_block.block_number, gpu_block.block_number) |         return [(cpu_block.block_number, gpu_block.block_number) | ||||||
|                 for cpu_block, gpu_block in mapping.items()] |                 for cpu_block, gpu_block in mapping.items()] | ||||||
|  | |||||||
| @ -41,46 +41,19 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): | |||||||
|  |  | ||||||
|         # TODO: remove zero_point parameters once the configs given remove them |         # TODO: remove zero_point parameters once the configs given remove them | ||||||
|  |  | ||||||
|         # Note on input/weight scales and zero_points |  | ||||||
|         # |  | ||||||
|         # When the scales have a single value, it is required that they be |  | ||||||
|         # on the CPU for 2 reasons, |  | ||||||
|         # 1. Performance: |  | ||||||
|         #   When the scales (input_scale/weight_scales) have only a single |  | ||||||
|         #   value, we perform a scalar broadcast of that value during the |  | ||||||
|         #   quant/dequant operations. The "quant" and the "gemm+dequant" |  | ||||||
|         #   kernels accept the Scalar by-value. These tensors are allocated |  | ||||||
|         #   on the CPU in order to avoid the GPU-to-CPU copy when passing |  | ||||||
|         #   by-value. |  | ||||||
|         # |  | ||||||
|         # 2. CUDA Graphs: |  | ||||||
|         #   CUDA Graphs don't support GPU-to-CPU copy operations during |  | ||||||
|         #   stream capture. |  | ||||||
|         # |  | ||||||
|         # TODO: zero-points are not supported yet. But we expect a similar |  | ||||||
|         # pattern. |  | ||||||
|  |  | ||||||
|         is_tensor_partitioned = len(output_partition_sizes) != 1 |         is_tensor_partitioned = len(output_partition_sizes) != 1 | ||||||
|         weight_scale_dim = sum( |         weight_scale_dim = sum( | ||||||
|             output_partition_sizes) if is_tensor_partitioned else 1 |             output_partition_sizes) if is_tensor_partitioned else 1 | ||||||
|         weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda" |  | ||||||
|  |  | ||||||
|         input_scale = Parameter(torch.empty(1, |         input_scale = Parameter(torch.empty(1, dtype=torch.float32), | ||||||
|                                             device="cpu", |  | ||||||
|                                             dtype=torch.float32), |  | ||||||
|                                 requires_grad=False) |                                 requires_grad=False) | ||||||
|         input_zero_point = Parameter(torch.empty(1, |         input_zero_point = Parameter(torch.empty(1, dtype=torch.int8), | ||||||
|                                                  device="cpu", |  | ||||||
|                                                  dtype=torch.int8), |  | ||||||
|                                      requires_grad=False) |                                      requires_grad=False) | ||||||
|  |  | ||||||
|         weight_scale = Parameter(torch.empty(weight_scale_dim, |         weight_scale = Parameter(torch.empty(weight_scale_dim, | ||||||
|                                              device=weight_scale_device, |  | ||||||
|                                              dtype=torch.float32), |                                              dtype=torch.float32), | ||||||
|                                  requires_grad=False) |                                  requires_grad=False) | ||||||
|         weight_zero_point = Parameter(torch.empty(1, |         weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8), | ||||||
|                                                   device="cpu", |  | ||||||
|                                                   dtype=torch.int8), |  | ||||||
|                                       requires_grad=False) |                                       requires_grad=False) | ||||||
|  |  | ||||||
|         weight = Parameter(torch.empty(sum(output_partition_sizes), |         weight = Parameter(torch.empty(sum(output_partition_sizes), | ||||||
|  | |||||||
| @ -269,15 +269,24 @@ class Sequence: | |||||||
|         return self.output_text[:-buffer_length] if truncate else ( |         return self.output_text[:-buffer_length] if truncate else ( | ||||||
|             self.output_text) |             self.output_text) | ||||||
|  |  | ||||||
|     def hash_of_block(self, logical_idx: int) -> int: |     def hash_of_block(self, logical_idx: int) -> Optional[int]: | ||||||
|         # TODO This can produce incorrect hash when block size > prompt size |         """Return the hash of the block if it is full.""" | ||||||
|  |  | ||||||
|         # Compute the number of tokens in the sequence |  | ||||||
|         # TODO: The current hashing function is O(L^2). We should optimize |         # TODO: The current hashing function is O(L^2). We should optimize | ||||||
|         # this in the future. |         # this in the future. | ||||||
|         num_tokens = self.num_hashed_tokens_of_block(logical_idx) |         assert logical_idx < len(self.logical_token_blocks), ( | ||||||
|         hashed_tokens = self.data.get_prefix_token_ids(num_tokens) |             f"logical_idx={logical_idx} is out of range for " | ||||||
|         return hash((hashed_tokens, self.lora_int_id)) |             f"logical_token_blocks={len(self.logical_token_blocks)}") | ||||||
|  |         block = self.logical_token_blocks[logical_idx] | ||||||
|  |         if block.block_hash is not None: | ||||||
|  |             return block.block_hash | ||||||
|  |         if not block.is_full(): | ||||||
|  |             return None | ||||||
|  |         num_hashed_tokens = self.num_hashed_tokens_of_block(logical_idx) | ||||||
|  |         hashed_tokens = self.data.get_prefix_token_ids(num_hashed_tokens) | ||||||
|  |         block_hash = hash((hashed_tokens, self.lora_int_id)) | ||||||
|  |         # Cache the block hash for future use. | ||||||
|  |         block.block_hash = block_hash | ||||||
|  |         return block_hash | ||||||
|  |  | ||||||
|     def num_hashed_tokens_of_block(self, logical_idx: int): |     def num_hashed_tokens_of_block(self, logical_idx: int): | ||||||
|         return logical_idx * self.block_size + self.block_size |         return logical_idx * self.block_size + self.block_size | ||||||
| @ -632,7 +641,7 @@ class SequenceGroupMetadata: | |||||||
|         state: Internal state tied to this sequence group. |         state: Internal state tied to this sequence group. | ||||||
|         multi_modal_data: Multi modal data. |         multi_modal_data: Multi modal data. | ||||||
|         encoder_seq_data: Optional sequence data for encoder prompt |         encoder_seq_data: Optional sequence data for encoder prompt | ||||||
|                           (SequenceGroup.encoder_seq). Should be None  |                           (SequenceGroup.encoder_seq). Should be None | ||||||
|                           unless you are working with an encoder/decoder |                           unless you are working with an encoder/decoder | ||||||
|                           model. |                           model. | ||||||
|         cross_block_table: Optional cross-attention block table associated |         cross_block_table: Optional cross-attention block table associated | ||||||
|  | |||||||