Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Core] Always use tensor cores for Flashinfer Decode Wrapper (#23214)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@@ -110,7 +110,7 @@ def benchmark_decode(
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
-        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
+        use_tensor_cores=True,
     )
     wrapper.plan(
         kv_indptr,
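The hunk above drops the GQA-based heuristic: the benchmark used to enable tensor cores only when the ratio of query heads to KV heads exceeded 4, and now requests them unconditionally. A minimal sketch of the new construction, with assumed head counts and a workspace buffer sized as in the tests in this diff (not the benchmark code itself):

import torch
import flashinfer

# Assumed example GQA configuration; every num_qo_heads/num_kv_heads combination
# now takes the same path because the heuristic is gone.
num_qo_heads, num_kv_heads = 32, 8

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")

# Before: use_tensor_cores=((num_qo_heads // num_kv_heads) > 4)
# After:  tensor cores are requested unconditionally.
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    "NHD",  # kv_layout used by the benchmark and tests in this diff
    use_tensor_cores=True,
)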
@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.\
         BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
-                                           use_tensor_cores=(
-                                               (num_query_heads//num_kv_heads) > 4)
-                                           )
+                                           use_tensor_cores=True)
     wrapper.plan(
         kv_indptr,
         kv_indices,
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     assert num_query_heads % num_kv_heads == 0
     max_kv_len = max(kv_lens)
     scale = head_size**-0.5
-    use_tensor_cores = (num_query_heads // num_kv_heads) > 4
+    use_tensor_cores = True
     kv_cache_dtype = torch.float8_e4m3fn

     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
@@ -136,9 +136,7 @@ def test_flashinfer_trtllm_decode_with_baseline(

     # Baseline Decode
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer,
-        kv_layout,
-        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4))
+        workspace_buffer, kv_layout, use_tensor_cores=True)
     wrapper.plan(kv_indptr,
                  kv_indices,
                  kv_last_page_lens,
@@ -42,7 +42,6 @@ if TYPE_CHECKING:
    VLLM_TRACE_FUNCTION: int = 0
    VLLM_ATTENTION_BACKEND: Optional[str] = None
    VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
-   VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
    VLLM_PP_LAYER_PARTITION: Optional[str] = None
    VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
    VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -465,11 +464,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
     if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,

-    # If set, vllm will force flashinfer to use tensor cores;
-    # otherwise will use heuristic based on model architecture.
-    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
-    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
-
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
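For context, the deleted dict entry followed the usual envs.py pattern of mapping a variable name to a parser lambda that turns "0"/"1" into a bool. A hedged sketch of that pattern is below; the helper name is hypothetical, not vLLM code. After this commit, setting VLLM_FLASHINFER_FORCE_TENSOR_CORES has no effect, because the tensor-core path is always taken.

import os

# Hypothetical helper illustrating the removed bool(int(os.getenv(...))) idiom.
def parse_bool_flag(name: str, default: str = "0") -> bool:
    return bool(int(os.getenv(name, default)))

# Previously: VLLM_FLASHINFER_FORCE_TENSOR_CORES=1 forced the tensor-core decode path.
# Now the flag is gone and the tensor-core path is unconditional.
force_tensor_cores = parse_bool_flag("VLLM_FLASHINFER_FORCE_TENSOR_CORES")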
@@ -1221,7 +1215,6 @@ def compute_hash() -> str:
        "VLLM_USE_AITER_UNIFIED_ATTENTION",
        "VLLM_ATTENTION_BACKEND",
        "VLLM_USE_FLASHINFER_SAMPLER",
-       "VLLM_FLASHINFER_FORCE_TENSOR_CORES",
        "VLLM_DISABLED_KERNELS",
        "VLLM_USE_DEEP_GEMM",
        "VLLM_USE_TRTLLM_FP4_GEMM",
@@ -13,7 +13,6 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache

-import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
@@ -228,8 +227,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.q_data_type = self.kv_cache_dtype
         else:
             self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or
-                                 (self.num_qo_heads // self.num_kv_heads > 4))

         self._cascade_wrapper = None  # Wrapper for cascade attention

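The two deleted lines were where the heuristic lived in the v1 FlashInfer backend. An illustrative restatement of that rule follows; the standalone function is hypothetical, but the decision mirrors the deleted expression.

def old_use_tensor_cores(num_qo_heads: int, num_kv_heads: int,
                         force_tensor_cores: bool) -> bool:
    # GQA group size: how many query heads share a single KV head.
    group_size = num_qo_heads // num_kv_heads
    return force_tensor_cores or group_size > 4

# A group size of 4 used to stay on CUDA cores; a group size of 8 already used
# tensor cores. After this commit both simply get use_tensor_cores=True.
assert old_use_tensor_cores(32, 8, force_tensor_cores=False) is False
assert old_use_tensor_cores(32, 4, force_tensor_cores=False) is True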
@@ -308,7 +305,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 paged_kv_indptr_buffer=paged_kv_indptr,
                 paged_kv_indices_buffer=paged_kv_indices,
                 paged_kv_last_page_len_buffer=paged_kv_last_page_len,
-                use_tensor_cores=self.use_tensor_cores)
+                # Tensor cores are enabled by default because the perf would be
+                # at least as good as cuda cores for all attention ops in latest
+                # gpus.
+                use_tensor_cores=True,
+            )

             # save the decode wrapper
             if use_cudagraph:
@@ -984,52 +985,29 @@ def fast_plan_decode(
     self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                            non_blocking=True)

-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

-        try:
-            # Make sure we pass exactly 15 arguments for tensor core version
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_host,
-                indptr_cpu,
-                seq_lens_cpu,
-                batch_size,  # total_num_rows
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                head_dim,
-                head_dim,
-                False,  # causal
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in tensor core plan: {e}") from e
-    else:
-        try:
-            # Make sure we pass exactly 15 arguments for standard version
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                indptr_cpu,
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                window_left,
-                logits_soft_cap,
-                head_dim,
-                head_dim,
-                torch.empty(0, dtype=q_data_type),
-                torch.empty(0, dtype=kv_data_type),
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in standard plan: {e}") from e
+    try:
+        # Make sure we pass exactly 15 arguments for tensor core version
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_cpu,
+            seq_lens_cpu,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,
+            head_dim,
+            False,  # causal
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in tensor core plan: {e}") from e

     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
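To make the "exactly 15 arguments" comments concrete: the surviving tensor-core plan() call and the deleted standard call both pass 15 positional arguments, but the lists differ in the middle. The sketch below simply names them side by side; the labels are descriptive tags taken from the diff, not necessarily the exact parameter names in the cached module.

TENSOR_CORE_PLAN_ARGS = [
    "float_workspace_buffer", "int_workspace_buffer",
    "pin_memory_int_workspace_buffer", "qo_indptr_host", "indptr_cpu",
    "seq_lens_cpu", "total_num_rows", "batch_size", "num_qo_heads",
    "num_kv_heads", "page_size", "is_cuda_graph_enabled", "head_dim",
    "head_dim", "causal",
]
STANDARD_PLAN_ARGS = [
    "float_workspace_buffer", "int_workspace_buffer",
    "pin_memory_int_workspace_buffer", "indptr_cpu", "batch_size",
    "num_qo_heads", "num_kv_heads", "page_size", "is_cuda_graph_enabled",
    "window_left", "logits_soft_cap", "head_dim", "head_dim",
    "empty_q_data", "empty_kv_data",
]
assert len(TENSOR_CORE_PLAN_ARGS) == len(STANDARD_PLAN_ARGS) == 15

Note that window_left is still stored on the wrapper (see the context line above); it just no longer feeds this plan() call once only the tensor-core path remains.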