Compare commits: woosuk/tes ... v0.10.1.1 (6 commits)
Mirror of https://github.com/vllm-project/vllm.git

Commits compared (SHA1):
1da94e673c
d8b736f913
3a8708f60a
aab549870d
ba6928cf13
befedf86a8
vllm/entrypoints/constants.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Shared constants for vLLM entrypoints.
+"""
+
+# HTTP header limits for h11 parser
+# These constants help mitigate header abuse attacks
+H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304  # 4 MB
+H11_MAX_HEADER_COUNT_DEFAULT = 256
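These defaults cap how much buffered header data and how many header fields the h11 HTTP parser accepts before rejecting a request. As a rough illustration of the limit being configured here (a standalone sketch, not part of this change, exercising h11's Connection API directly with a deliberately tiny limit):

```python
import h11

# Tiny limit so the overflow is easy to trigger in a demo.
conn = h11.Connection(our_role=h11.SERVER, max_incomplete_event_size=1024)
# A request whose header block never terminates and exceeds the limit.
conn.receive_data(b"GET / HTTP/1.1\r\nX-Big: " + b"a" * 2048)
try:
    conn.next_event()
except h11.RemoteProtocolError as exc:
    print("rejected oversized header block:", exc)
```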
@@ -14,6 +14,8 @@ from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
+                                        H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
 from vllm.entrypoints.ssl import SSLCertRefresher
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
@@ -26,6 +28,11 @@ async def serve_http(app: FastAPI,
                      sock: Optional[socket.socket],
                      enable_ssl_refresh: bool = False,
                      **uvicorn_kwargs: Any):
+    """
+    Start a FastAPI app using Uvicorn, with support for custom Uvicorn config
+    options. Supports http header limits via h11_max_incomplete_event_size and
+    h11_max_header_count.
+    """
     logger.info("Available routes are:")
     for route in app.routes:
         methods = getattr(route, "methods", None)
@@ -36,7 +43,21 @@ async def serve_http(app: FastAPI,

         logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

+    # Extract header limit options if present
+    h11_max_incomplete_event_size = uvicorn_kwargs.pop(
+        "h11_max_incomplete_event_size", None)
+    h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None)
+
+    # Set safe defaults if not provided
+    if h11_max_incomplete_event_size is None:
+        h11_max_incomplete_event_size = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
+    if h11_max_header_count is None:
+        h11_max_header_count = H11_MAX_HEADER_COUNT_DEFAULT
+
     config = uvicorn.Config(app, **uvicorn_kwargs)
+    # Set header limits
+    config.h11_max_incomplete_event_size = h11_max_incomplete_event_size
+    config.h11_max_header_count = h11_max_header_count
     config.load()
     server = uvicorn.Server(config)
     _add_shutdown_handlers(app, server)
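For orientation, a hypothetical caller sketch (not taken from this diff): the two limits travel in through **uvicorn_kwargs, are popped off before uvicorn.Config is built, and are then set as attributes on the resulting config. Parameters other than those visible in the hunk (host, port, the app object) are assumptions here.

```python
async def run(app):
    # Hypothetical invocation; host/port are forwarded to uvicorn.Config
    # through **uvicorn_kwargs like any other uvicorn option.
    await serve_http(
        app,
        sock=None,
        host="0.0.0.0",
        port=8000,
        h11_max_incomplete_event_size=8 * 1024 * 1024,  # override 4 MB default
        h11_max_header_count=512,                        # override 256 default
    )
```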
@@ -1894,6 +1894,8 @@ async def run_server_worker(listen_address,
         ssl_certfile=args.ssl_certfile,
         ssl_ca_certs=args.ssl_ca_certs,
         ssl_cert_reqs=args.ssl_cert_reqs,
+        h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
+        h11_max_header_count=args.h11_max_header_count,
         **uvicorn_kwargs,
     )

@@ -20,6 +20,8 @@ from vllm.config import config
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          validate_chat_template)
+from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
+                                        H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)
 from vllm.entrypoints.openai.serving_models import LoRAModulePath
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.logger import init_logger
@@ -172,6 +174,12 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
     enable_log_outputs: bool = False
     """If set to True, enable logging of model outputs (generations)
     in addition to the input logging that is enabled by default."""
+    h11_max_incomplete_event_size: int = H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
+    """Maximum size (bytes) of an incomplete HTTP event (header or body) for
+    h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB)."""
+    h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT
+    """Maximum number of HTTP headers allowed in a request for h11 parser.
+    Helps mitigate header abuse. Default: 256."""

     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
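A hypothetical check of how the two new fields surface on the command line (not part of the diff; make_arg_parser is the existing helper that builds the OpenAI frontend parser, and the dash-separated flag spellings are an assumption based on how the other dataclass-backed options are exposed, so verify against --help):

```python
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args([
    "--h11-max-incomplete-event-size", "8388608",
    "--h11-max-header-count", "512",
])
assert args.h11_max_incomplete_event_size == 8 * 1024 * 1024
assert args.h11_max_header_count == 512
```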
@@ -20,7 +20,15 @@ from openai.types.chat.chat_completion_message import (
 from openai.types.responses import (ResponseFunctionToolCall,
                                     ResponseInputItemParam, ResponseOutputItem,
                                     ResponsePrompt, ResponseReasoningItem,
-                                    ResponseStatus, ResponseTextConfig)
+                                    ResponseStatus)
+
+# Backward compatibility for OpenAI client versions
+try:  # For older openai versions (< 1.100.0)
+    from openai.types.responses import ResponseTextConfig
+except ImportError:  # For newer openai versions (>= 1.100.0)
+    from openai.types.responses import (ResponseFormatTextConfig as
+                                        ResponseTextConfig)
+
 from openai.types.responses.response import ToolChoice
 from openai.types.responses.tool import Tool
 from openai.types.shared import Metadata, Reasoning
@@ -208,15 +208,10 @@ class Qwen3CoderToolParser(ToolParser):
                     "valid JSON object in tool '%s', will try other "
                     "methods to parse it.", param_value, param_name,
                     func_name)
-            try:
-                converted_value = eval(param_value)
-                return converted_value
-            except Exception:
-                logger.warning(
-                    "Parsed value '%s' of parameter '%s' cannot be "
-                    "converted via Python `eval()` in tool '%s', "
-                    "degenerating to string.", param_value, param_name,
-                    func_name)
-                return param_value
+            logger.warning(
+                "Parameter '%s' has unknown type '%s'. "
+                "The value will be treated as a string.", param_name,
+                param_type)
+            return param_value

         # Extract function name
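The net effect of this hunk is that eval() is no longer run on untrusted parameter values; values with an unknown declared type are now passed through unchanged as strings, with a warning. If structured parsing of such values were still wanted, a safer route (an illustration only, not what this change does) is json.loads with an ast.literal_eval fallback, neither of which executes arbitrary code:

```python
import ast
import json

def parse_untrusted_value(raw: str):
    """Illustrative helper: parse a literal without executing code."""
    for parse in (json.loads, ast.literal_eval):
        try:
            return parse(raw)
        except (ValueError, SyntaxError):
            continue
    return raw  # degrade to string, mirroring the new behaviour

print(parse_untrusted_value("[1, 2, 3]"))      # [1, 2, 3]
print(parse_untrusted_value("not a literal"))  # not a literal
```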
@@ -762,11 +762,11 @@ class FusedMoE(CustomOp):
         self.global_num_experts = num_experts + num_redundant_experts

         # we padding globally so EP buffer allocation works
-        if (quant_config and quant_config.get_name() == "mxfp4"
-                and (current_platform.is_rocm()
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
-            hidden_size = round_up(hidden_size, 256)
+        if quant_config and quant_config.get_name() == "mxfp4":
+            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
+                should_use_flashinfer_mxfp4)
+            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+                hidden_size = round_up(hidden_size, 256)

         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
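The refactor keeps the padding behaviour identical and only routes the decision through the new should_use_flashinfer_mxfp4() helper, imported lazily so the layer code does not need flashinfer at module import time. For reference, the rounding itself (round_up re-implemented here to mirror vllm.utils.round_up, with an illustrative hidden size):

```python
def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x, as in vllm.utils.round_up.
    return ((x + multiple - 1) // multiple) * multiple

# e.g. a 2880-wide hidden dimension is padded to 3072 when the
# ROCm / FlashInfer mxfp4 path is active (2880 is only an example value).
assert round_up(2880, 256) == 3072
```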
@@ -6,6 +6,7 @@ import torch
 from torch.nn.parameter import Parameter

 from vllm import envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
@@ -26,12 +27,38 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
                         next_power_of_2, round_up)
+from vllm.utils.flashinfer import has_flashinfer

-if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
-    # from flashinfer.fused_moe import cutlass_fused_moe
-    from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
-                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+logger = init_logger(__name__)
+
+
+def _should_use_flashinfer_mxfp4_bf16():
+    """Determine if FlashInfer MXFP4 BF16 should be used."""
+    # If explicitly set, respect the setting
+    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
+        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+
+    # Enable by default on SM100 if MXFP8 is not explicitly enabled
+    if (current_platform.is_device_capability(100) and has_flashinfer()
+            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
+        logger.info_once(
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "For faster performance, consider setting "
+            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
+            "though this may impact accuracy.")
+        return True
+
+    return False
+
+
+def _should_use_flashinfer_mxfp4_mxfp8():
+    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
+    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+
+
+def should_use_flashinfer_mxfp4():
+    return (_should_use_flashinfer_mxfp4_mxfp8()
+            or _should_use_flashinfer_mxfp4_bf16())


 class Mxfp4Config(QuantizationConfig):
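A standalone sketch (plain booleans, not vLLM code) of the precedence these helpers encode: an explicit VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 setting always wins, otherwise the BF16 backend is enabled by default on SM100 (Blackwell) when FlashInfer is available and the MXFP8 flag has not been set explicitly:

```python
def resolve_bf16(bf16_value: bool, bf16_explicit: bool, mxfp8_explicit: bool,
                 is_sm100: bool, flashinfer_available: bool) -> bool:
    # Mirrors _should_use_flashinfer_mxfp4_bf16 with the environment and
    # platform checks replaced by plain booleans.
    if bf16_explicit:
        return bf16_value
    return is_sm100 and flashinfer_available and not mxfp8_explicit

# Nothing set on a Blackwell box with FlashInfer installed -> BF16 by default.
assert resolve_bf16(False, False, False, True, True) is True
# Explicitly disabled -> respected even on Blackwell.
assert resolve_bf16(False, True, False, True, True) is False
```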
@@ -87,12 +114,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.moe = moe
         self.use_marlin = self._should_use_marlin()

+        if current_platform.is_device_capability(100) and not has_flashinfer():
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
-        if current_platform.is_cuda() and \
-            not current_platform.has_device_capability(100):
-            if not current_platform.is_device_capability(90):
+        if current_platform.is_cuda() and \
+            not current_platform.is_device_capability(100):
+            if not current_platform.has_device_capability(90):
                 # marlin kernel has better performance on ampere
                 return True
         if not has_triton_kernels():
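The capability-check swap matters because, as used here, has_device_capability() means "at least this compute capability" while is_device_capability() means "exactly this capability". With the fix, Marlin is only forced on CUDA devices below SM90 (e.g. Ampere), while exact SM100 devices are left to the Blackwell-specific path. A toy sketch of that semantics (the helper names are vLLM's; the implementations below are illustrative stand-ins):

```python
DEVICE_CAPABILITY = 80  # pretend we are on Ampere (SM80); illustrative only

def has_device_capability(cap: int) -> bool:
    return DEVICE_CAPABILITY >= cap   # "at least"

def is_device_capability(cap: int) -> bool:
    return DEVICE_CAPABILITY == cap   # "exactly"

# SM80: not exactly SM100, and below SM90 -> the Marlin path is selected.
use_marlin = (not is_device_capability(100)) and (not has_device_capability(90))
assert use_marlin is True
```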
@@ -138,8 +171,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # for to hold non-uniform sharded tensor as well as swizzling
             # other padding to increase performance
|
|||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
prepare_moe_fp4_layer_for_marlin(layer)
|
prepare_moe_fp4_layer_for_marlin(layer)
|
||||||
elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
elif should_use_flashinfer_mxfp4():
|
||||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
|
||||||
layer.gemm1_alpha = Parameter(torch.tensor(
|
layer.gemm1_alpha = Parameter(torch.tensor(
|
||||||
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
@ -478,11 +510,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
logical_replica_count), (
|
logical_replica_count), (
|
||||||
"MXFP4 are not supported with this configuration.")
|
"MXFP4 are not supported with this configuration.")
|
||||||
|
|
||||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
if should_use_flashinfer_mxfp4():
|
||||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
|
||||||
assert not self.moe.use_ep, (
|
assert not self.moe.use_ep, (
|
||||||
"EP is not supported for flashinfer mxfp4 moe backend yet.")
|
"EP is not supported for flashinfer mxfp4 moe backend yet.")
|
||||||
if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
|
if _should_use_flashinfer_mxfp4_bf16():
|
||||||
assert x.dtype == torch.bfloat16
|
assert x.dtype == torch.bfloat16
|
||||||
x_quant = x
|
x_quant = x
|
||||||
x_scale = None
|
x_scale = None
|
||||||
|
@@ -21,7 +21,7 @@ logger = init_logger(__name__)


 class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
     # enable full CUDA Graph support for decode-only capture
-    attn_cudagraph_support: ClassVar[
+    cudagraph_support: ClassVar[
         AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE