mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)
Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
@ -5,11 +5,11 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// vllm_kernel_override_batch_invariant(); returns true
|
||||
// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
|
||||
inline bool vllm_kernel_override_batch_invariant() {
|
||||
// vllm_is_batch_invariant(); returns true
|
||||
// if env VLLM_BATCH_INVARIANT=1
|
||||
inline bool vllm_is_batch_invariant() {
|
||||
static bool cached = []() {
|
||||
std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
|
||||
std::string env_key = "VLLM_BATCH_INVARIANT";
|
||||
const char* val = std::getenv(env_key.c_str());
|
||||
return (val && std::atoi(val) != 0) ? 1 : 0;
|
||||
}();
|
||||
|
@ -426,7 +426,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
wt_ptr % req_alignment_bytes == 0;
|
||||
bool offsets_are_multiple_of_vector_width =
|
||||
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
|
||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
|
||||
!batch_invariant_launch) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
@ -474,7 +474,7 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
|
||||
LAUNCH_FUSED_POLY_NORM(8);
|
||||
} else {
|
||||
|
@ -254,7 +254,7 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
|
||||
!batch_invariant_launch) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
|
@ -39,7 +39,7 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
# m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
|
||||
# m.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
|
||||
outputs: list[tuple[str, list]] = []
|
||||
for test_preemption in [False, True]:
|
||||
|
@ -19,14 +19,14 @@ hopper_only = pytest.mark.skipif(
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_batch_invariant_mode():
|
||||
"""Automatically enable batch invariant kernel overrides for all tests."""
|
||||
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
|
||||
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||
yield
|
||||
# Restore original value after test
|
||||
if old_value is None:
|
||||
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
|
||||
os.environ.pop("VLLM_BATCH_INVARIANT", None)
|
||||
else:
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = old_value
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
|
||||
# For batch invariance, disable custom all-reduce to ensure deterministic
|
||||
# all-reduce operations (custom all-reduce may not be deterministic)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
|
||||
disable_custom_ar = vllm_kernel_override_batch_invariant()
|
||||
disable_custom_ar = vllm_is_batch_invariant()
|
||||
|
||||
if disable_custom_ar:
|
||||
print(f"\n{'=' * 80}")
|
||||
@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
|
||||
# CRITICAL: Disable batch invariance for this test
|
||||
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
|
||||
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "0"
|
||||
|
||||
try:
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
||||
finally:
|
||||
# Restore original value
|
||||
if old_value is None:
|
||||
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
|
||||
os.environ.pop("VLLM_BATCH_INVARIANT", None)
|
||||
else:
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = old_value
|
||||
|
||||
|
||||
@hopper_only
|
||||
@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
|
||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
|
||||
disable_custom_ar = vllm_kernel_override_batch_invariant()
|
||||
disable_custom_ar = vllm_is_batch_invariant()
|
||||
|
||||
if disable_custom_ar:
|
||||
print(f"\n{'=' * 80}")
|
||||
|
@ -21,7 +21,7 @@ from vllm.config.scheduler import RunnerType
|
||||
from vllm.config.utils import assert_hashable, config, getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import (
|
||||
@ -423,7 +423,7 @@ class ModelConfig:
|
||||
video_pruning_rate: float | None,
|
||||
) -> None:
|
||||
# Enable batch invariance settings if requested
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
self.enforce_eager = True
|
||||
|
||||
# Set the default seed to 0 in V1.
|
||||
|
@ -15,7 +15,7 @@ import vllm.envs as envs
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cuda_device_count_stateless, get_open_ports_list
|
||||
@ -565,7 +565,7 @@ class ParallelConfig:
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
|
||||
# Enable batch invariance settings if requested
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
self.disable_custom_all_reduce = True
|
||||
|
||||
if (
|
||||
|
@ -20,7 +20,7 @@ import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.utils import cuda_device_count_stateless, update_environment_variables
|
||||
|
||||
@ -74,7 +74,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
return False
|
||||
|
||||
if not is_symmetric_memory_enabled():
|
||||
|
@ -10,7 +10,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -103,7 +103,7 @@ class SymmMemCommunicator:
|
||||
return
|
||||
self.force_multimem = force_multimem
|
||||
self.disabled = False
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
self.disabled = True
|
||||
|
||||
def should_use_symm_mem(self, inp: torch.Tensor):
|
||||
|
@ -741,8 +741,8 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
|
||||
return AttentionBlockSize(block_m=16, block_n=16)
|
||||
|
||||
|
||||
def vllm_kernel_override_batch_invariant():
|
||||
env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
|
||||
def vllm_is_batch_invariant():
|
||||
env_key = "VLLM_BATCH_INVARIANT"
|
||||
is_overridden = False
|
||||
val = os.getenv(env_key, "0")
|
||||
try:
|
||||
@ -797,7 +797,7 @@ def override_envs_for_invariance():
|
||||
|
||||
def init_batch_invariance():
|
||||
# this will hit all the csrc overrides as well
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
override_envs_for_invariance()
|
||||
enable_batch_invariant_mode()
|
||||
|
||||
|
@ -16,7 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
@ -841,7 +841,7 @@ def get_moe_configs(
|
||||
"""
|
||||
|
||||
# Avoid optimizing for the batch invariant case. Use default config
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
return None
|
||||
|
||||
# First look up if an optimized configuration is available in the configs
|
||||
@ -976,7 +976,7 @@ def get_default_config(
|
||||
dtype: str | None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
@ -1136,7 +1136,7 @@ def fused_topk_bias(
|
||||
) + e_score_correction_bias.unsqueeze(0)
|
||||
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = vllm_kernel_override_batch_invariant()
|
||||
use_sorted = vllm_is_batch_invariant()
|
||||
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if renormalize:
|
||||
@ -1200,7 +1200,7 @@ def grouped_topk(
|
||||
) # [n, n_group]
|
||||
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = vllm_kernel_override_batch_invariant()
|
||||
use_sorted = vllm_is_batch_invariant()
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
|
@ -10,7 +10,7 @@ import vllm.envs as envs
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
rms_norm_batch_invariant,
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@ -25,7 +25,7 @@ def rms_norm(
|
||||
) -> torch.Tensor:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
return rms_norm_batch_invariant(x, weight, variance_epsilon)
|
||||
out = torch.empty_like(x)
|
||||
ops.rms_norm(
|
||||
@ -45,7 +45,7 @@ def fused_add_rms_norm(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
return rms_norm_batch_invariant(
|
||||
x + residual, weight, variance_epsilon
|
||||
), x + residual
|
||||
|
@ -15,7 +15,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
@ -356,7 +356,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# If batch invariant mode is enabled, dequantize and use BF16 compute
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
# Dequantize FP8 weights to BF16
|
||||
weight_fp8 = layer.weight.to(torch.bfloat16)
|
||||
weight_scale = layer.weight_scale.to(torch.bfloat16)
|
||||
|
@ -35,7 +35,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
@ -308,7 +308,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
max_num_splits = 1
|
||||
|
||||
def schedule(
|
||||
@ -484,7 +484,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.attn_type = attn_type
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
# Cache the batch invariant result for use in forward passes
|
||||
self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()
|
||||
self.batch_invariant_enabled = vllm_is_batch_invariant()
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
|
||||
raise NotImplementedError(
|
||||
@ -963,7 +963,7 @@ def cascade_attention(
|
||||
# s_aux is incorporated into prefix_lse inside the GPU kernel,
|
||||
# enabling its effect during the final attention merge.
|
||||
s_aux=s_aux,
|
||||
num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
|
||||
num_splits=1 if vllm_is_batch_invariant() else 0,
|
||||
)
|
||||
|
||||
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
@ -988,7 +988,7 @@ def cascade_attention(
|
||||
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
|
||||
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
|
||||
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
|
||||
num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
|
||||
num_splits=1 if vllm_is_batch_invariant() else 0,
|
||||
)
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
|
@ -25,7 +25,7 @@ from vllm.attention.backends.abstract import (
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
@ -291,7 +291,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self._prefill_wrapper = None # Wrapper for prefill/append
|
||||
self._decode_wrapper = None # Wrapper for decode (general shape)
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
self.decode_fixed_split_size = 2048
|
||||
self.prefill_fixed_split_size = 4096
|
||||
self.disable_split_kv = True
|
||||
@ -404,7 +404,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
def _get_workspace_buffer(self):
|
||||
if self._workspace_buffer is None:
|
||||
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
|
||||
self._workspace_buffer = torch.zeros(
|
||||
buffer_size, dtype=torch.uint8, device=self.device
|
||||
|
@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.utils import cdiv, is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
@ -863,7 +863,7 @@ def get_kernel_options(
|
||||
kernel_options: dict[str, int | bool] = {
|
||||
"FORCE_USE_FLEX_ATTENTION": True,
|
||||
}
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
kernel_options["BLOCK_M"] = 16
|
||||
kernel_options["BLOCK_N"] = 16
|
||||
kernel_options["IS_DIVISIBLE"] = False
|
||||
|
@ -212,7 +212,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@ -1283,7 +1283,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
# ROCm leverages the upstream flash_attn, which takes a parameter
|
||||
# called "return_attn_probs" instead of return_softmax_lse
|
||||
kwargs["return_attn_probs"] = return_softmax_lse
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
kwargs["num_splits"] = 1
|
||||
|
||||
attn_out = self.flash_attn_varlen_func(
|
||||
|
@ -19,7 +19,7 @@ from vllm.attention.utils.fa_utils import (
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
@ -110,7 +110,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
# pre-allocated during capture.
|
||||
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
self.max_num_splits = 1
|
||||
|
||||
def _schedule_decode(
|
||||
@ -181,7 +181,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
max_num_splits = 1
|
||||
|
||||
metadata = FlashAttnMLADecodeMetadata(
|
||||
|
@ -15,7 +15,7 @@ from vllm.attention.ops.flashmla import (
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
|
||||
num_splits = attn_metadata.decode.num_splits
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
device = q.device
|
||||
dtype = torch.int32
|
||||
|
||||
|
@ -14,7 +14,7 @@ from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
@ -163,7 +163,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
|
||||
|
||||
# For batch invariance, use only 1 split to ensure deterministic reduction
|
||||
num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4
|
||||
num_kv_splits = 1 if vllm_is_batch_invariant() else 4
|
||||
|
||||
# TODO(lucas) Allocate ahead of time
|
||||
attn_logits = torch.empty(
|
||||
|
Reference in New Issue
Block a user