[Core] Always use tensor cores for Flashinfer Decode Wrapper (#23214)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Author: Pavani Majety
Date: 2025-08-21 13:02:11 -07:00
Committed by: GitHub
Parent: 3496274663
Commit: 1d353b6352
5 changed files with 31 additions and 64 deletions
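In short, every call site that builds a FlashInfer BatchDecodeWithPagedKVCacheWrapper now passes use_tensor_cores=True instead of deriving it from the GQA group size, and the VLLM_FLASHINFER_FORCE_TENSOR_CORES override is removed. A minimal sketch of the call-site change, assuming the flashinfer package and a CUDA device are available; the buffer size and head counts below are illustrative only:

import torch
import flashinfer

# Illustrative values; real buffers and head counts come from the model config.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
kv_layout = "NHD"
num_qo_heads, num_kv_heads = 32, 8

# Before: tensor cores only when the GQA group size exceeded 4.
wrapper_old = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    kv_layout,
    use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
)

# After: tensor cores are always requested for the decode wrapper.
wrapper_new = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    kv_layout,
    use_tensor_cores=True,
)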

View File

@@ -110,7 +110,7 @@ def benchmark_decode(
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
-        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
+        use_tensor_cores=True,
     )
     wrapper.plan(
         kv_indptr,
View File

@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
     wrapper = flashinfer.\
         BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
-                use_tensor_cores=(
-                    (num_query_heads//num_kv_heads) > 4)
-                )
+                use_tensor_cores=True)
     wrapper.plan(
         kv_indptr,
         kv_indices,
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
     assert num_query_heads % num_kv_heads == 0
     max_kv_len = max(kv_lens)
     scale = head_size**-0.5
-    use_tensor_cores = (num_query_heads // num_kv_heads) > 4
+    use_tensor_cores = True
     kv_cache_dtype = torch.float8_e4m3fn
     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)

View File

@@ -136,9 +136,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     # Baseline Decode
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer,
-        kv_layout,
-        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4))
+        workspace_buffer, kv_layout, use_tensor_cores=True)
     wrapper.plan(kv_indptr,
                  kv_indices,
                  kv_last_page_lens,

View File

@@ -42,7 +42,6 @@ if TYPE_CHECKING:
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
-    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -465,11 +464,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
     if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,
-    # If set, vllm will force flashinfer to use tensor cores;
-    # otherwise will use heuristic based on model architecture.
-    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
-    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
@@ -1221,7 +1215,6 @@ def compute_hash() -> str:
         "VLLM_USE_AITER_UNIFIED_ATTENTION",
         "VLLM_ATTENTION_BACKEND",
         "VLLM_USE_FLASHINFER_SAMPLER",
-        "VLLM_FLASHINFER_FORCE_TENSOR_CORES",
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
         "VLLM_USE_TRTLLM_FP4_GEMM",

View File

@@ -13,7 +13,6 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
-import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
@@ -228,8 +227,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.q_data_type = self.kv_cache_dtype
         else:
             self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.use_tensor_cores = (envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or
-                                 (self.num_qo_heads // self.num_kv_heads > 4))
         self._cascade_wrapper = None  # Wrapper for cascade attention
@@ -308,7 +305,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 paged_kv_indptr_buffer=paged_kv_indptr,
                 paged_kv_indices_buffer=paged_kv_indices,
                 paged_kv_last_page_len_buffer=paged_kv_last_page_len,
-                use_tensor_cores=self.use_tensor_cores)
+                # Tensor cores are enabled by default because the perf would be
+                # at least as good as CUDA cores for all attention ops on the
+                # latest GPUs.
+                use_tensor_cores=True,
+            )
             # save the decode wrapper
             if use_cudagraph:
@@ -984,52 +985,29 @@ def fast_plan_decode(
         self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
                                                non_blocking=True)
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-        try:
-            # Make sure we pass exactly 15 arguments for tensor core version
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_host,
-                indptr_cpu,
-                seq_lens_cpu,
-                batch_size,  # total_num_rows
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                head_dim,
-                head_dim,
-                False,  # causal
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in tensor core plan: {e}") from e
-    else:
-        try:
-            # Make sure we pass exactly 15 arguments for standard version
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                indptr_cpu,
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                window_left,
-                logits_soft_cap,
-                head_dim,
-                head_dim,
-                torch.empty(0, dtype=q_data_type),
-                torch.empty(0, dtype=kv_data_type),
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in standard plan: {e}") from e
+    try:
+        # Make sure we pass exactly 15 arguments for tensor core version
+        self._plan_info = self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_host,
+            indptr_cpu,
+            seq_lens_cpu,
+            batch_size,  # total_num_rows
+            batch_size,
+            num_qo_heads,
+            num_kv_heads,
+            page_size,
+            self.is_cuda_graph_enabled,
+            head_dim,
+            head_dim,
+            False,  # causal
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in tensor core plan: {e}") from e
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
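The fast_plan_decode change above follows directly from the wrapper change: with use_tensor_cores always True, the non-tensor-core planning branch is unreachable, so only the tensor-core argument list survives. A hypothetical condensation of the control flow (names are illustrative, not the vLLM source):

def plan_before(self, tensor_core_args, standard_args):
    # Old behavior: the planner was picked at runtime from the heuristic/env flag.
    if self.use_tensor_cores:
        return self._cached_module.plan(*tensor_core_args)
    return self._cached_module.plan(*standard_args)

def plan_after(self, tensor_core_args):
    # New behavior: use_tensor_cores is always True, so a single path remains.
    return self._cached_module.plan(*tensor_core_args)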