[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)

Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
Bram Wasti
2025-10-16 14:40:25 -07:00
committed by GitHub
parent 23583ee28c
commit b2f78cbad4
20 changed files with 61 additions and 61 deletions

View File

@ -5,11 +5,11 @@
namespace vllm {
// vllm_kernel_override_batch_invariant(); returns true
// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
inline bool vllm_kernel_override_batch_invariant() {
// vllm_is_batch_invariant(); returns true
// if env VLLM_BATCH_INVARIANT=1
inline bool vllm_is_batch_invariant() {
static bool cached = []() {
std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
std::string env_key = "VLLM_BATCH_INVARIANT";
const char* val = std::getenv(env_key.c_str());
return (val && std::atoi(val) != 0) ? 1 : 0;
}();

View File

@ -426,7 +426,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr % req_alignment_bytes == 0;
bool offsets_are_multiple_of_vector_width =
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
!batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
@ -474,7 +474,7 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
LAUNCH_FUSED_POLY_NORM(8);
} else {

View File

@ -254,7 +254,7 @@ void fused_add_rms_norm_static_fp8_quant(
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
!batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8);

View File

@ -39,7 +39,7 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
# m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
# m.setenv("VLLM_BATCH_INVARIANT", "1")
outputs: list[tuple[str, list]] = []
for test_preemption in [False, True]:

View File

@ -19,14 +19,14 @@ hopper_only = pytest.mark.skipif(
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode():
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
os.environ["VLLM_BATCH_INVARIANT"] = "1"
yield
# Restore original value after test
if old_value is None:
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
os.environ.pop("VLLM_BATCH_INVARIANT", None)
else:
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
os.environ["VLLM_BATCH_INVARIANT"] = old_value
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
# For batch invariance, disable custom all-reduce to ensure deterministic
# all-reduce operations (custom all-reduce may not be deterministic)
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
disable_custom_ar = vllm_kernel_override_batch_invariant()
disable_custom_ar = vllm_is_batch_invariant()
if disable_custom_ar:
print(f"\n{'=' * 80}")
@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
os.environ["VLLM_ATTENTION_BACKEND"] = backend
# CRITICAL: Disable batch invariance for this test
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
os.environ["VLLM_BATCH_INVARIANT"] = "0"
try:
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
finally:
# Restore original value
if old_value is None:
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
os.environ.pop("VLLM_BATCH_INVARIANT", None)
else:
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
os.environ["VLLM_BATCH_INVARIANT"] = old_value
@hopper_only
@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
disable_custom_ar = vllm_kernel_override_batch_invariant()
disable_custom_ar = vllm_is_batch_invariant()
if disable_custom_ar:
print(f"\n{'=' * 80}")

View File

@ -21,7 +21,7 @@ from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
@ -423,7 +423,7 @@ class ModelConfig:
video_pruning_rate: float | None,
) -> None:
# Enable batch invariance settings if requested
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
self.enforce_eager = True
# Set the default seed to 0 in V1.

View File

@ -15,7 +15,7 @@ import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, get_open_ports_list
@ -565,7 +565,7 @@ class ParallelConfig:
from vllm.executor.executor_base import ExecutorBase
# Enable batch invariance settings if requested
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
self.disable_custom_all_reduce = True
if (

View File

@ -20,7 +20,7 @@ import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.utils import cuda_device_count_stateless, update_environment_variables
@ -74,7 +74,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
is_symmetric_memory_enabled,
)
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
return False
if not is_symmetric_memory_enabled():

View File

@ -10,7 +10,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
)
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
@ -103,7 +103,7 @@ class SymmMemCommunicator:
return
self.force_multimem = force_multimem
self.disabled = False
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
self.disabled = True
def should_use_symm_mem(self, inp: torch.Tensor):

View File

@ -741,8 +741,8 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return AttentionBlockSize(block_m=16, block_n=16)
def vllm_kernel_override_batch_invariant():
env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
def vllm_is_batch_invariant():
env_key = "VLLM_BATCH_INVARIANT"
is_overridden = False
val = os.getenv(env_key, "0")
try:
@ -797,7 +797,7 @@ def override_envs_for_invariance():
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
override_envs_for_invariance()
enable_batch_invariant_mode()

View File

@ -16,7 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
@ -841,7 +841,7 @@ def get_moe_configs(
"""
# Avoid optimizing for the batch invariant case. Use default config
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
return None
# First look up if an optimized configuration is available in the configs
@ -976,7 +976,7 @@ def get_default_config(
dtype: str | None,
block_shape: list[int] | None = None,
) -> dict[str, int]:
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
@ -1136,7 +1136,7 @@ def fused_topk_bias(
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_kernel_override_batch_invariant()
use_sorted = vllm_is_batch_invariant()
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
@ -1200,7 +1200,7 @@ def grouped_topk(
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_kernel_override_batch_invariant()
use_sorted = vllm_is_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]

View File

@ -10,7 +10,7 @@ import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@ -25,7 +25,7 @@ def rms_norm(
) -> torch.Tensor:
from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
@ -45,7 +45,7 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual

View File

@ -15,7 +15,7 @@ from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
@ -356,7 +356,7 @@ class Fp8LinearMethod(LinearMethodBase):
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
self.use_marlin = False
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# If batch invariant mode is enabled, dequantize and use BF16 compute
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
# Dequantize FP8 weights to BF16
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)

View File

@ -35,7 +35,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
@ -308,7 +308,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
max_num_splits = 1
def schedule(
@ -484,7 +484,7 @@ class FlashAttentionImpl(AttentionImpl):
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()
self.batch_invariant_enabled = vllm_is_batch_invariant()
if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
raise NotImplementedError(
@ -963,7 +963,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux=s_aux,
num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
num_splits=1 if vllm_is_batch_invariant() else 0,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@ -988,7 +988,7 @@ def cascade_attention(
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
num_splits=1 if vllm_is_batch_invariant() else 0,
)
# Merge prefix and suffix outputs, and store the result in output.

View File

@ -25,7 +25,7 @@ from vllm.attention.backends.abstract import (
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
@ -291,7 +291,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape)
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
self.decode_fixed_split_size = 2048
self.prefill_fixed_split_size = 4096
self.disable_split_kv = True
@ -404,7 +404,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
self._workspace_buffer = torch.zeros(
buffer_size, dtype=torch.uint8, device=self.device

View File

@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.utils import cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (
@ -863,7 +863,7 @@ def get_kernel_options(
kernel_options: dict[str, int | bool] = {
"FORCE_USE_FLEX_ATTENTION": True,
}
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
kernel_options["BLOCK_M"] = 16
kernel_options["BLOCK_N"] = 16
kernel_options["IS_DIVISIBLE"] = False

View File

@ -212,7 +212,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@ -1283,7 +1283,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse
kwargs["return_attn_probs"] = return_softmax_lse
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
kwargs["num_splits"] = 1
attn_out = self.flash_attn_varlen_func(

View File

@ -19,7 +19,7 @@ from vllm.attention.utils.fa_utils import (
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
@ -110,7 +110,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
self.max_num_splits = 1
def _schedule_decode(
@ -181,7 +181,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
max_num_splits = 1
metadata = FlashAttnMLADecodeMetadata(

View File

@ -15,7 +15,7 @@ from vllm.attention.ops.flashmla import (
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
num_splits = attn_metadata.decode.num_splits
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
device = q.device
dtype = torch.int32

View File

@ -14,7 +14,7 @@ from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
@ -163,7 +163,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
# For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4
num_kv_splits = 1 if vllm_is_batch_invariant() else 4
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(