Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
Fix CompilationConfig repr (#19091)
Signed-off-by: rzou <zou3519@gmail.com>
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -6,6 +6,7 @@ from typing import Literal, Union
 
 import pytest
 
+from vllm.compilation.backends import VllmBackend
 from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
                          config, get_field)
 from vllm.model_executor.layers.pooler import PoolingType
@@ -44,6 +45,18 @@ def test_config(test_config, expected_error):
         config(test_config)
 
 
+def test_compile_config_repr_succeeds():
+    # setup: VllmBackend mutates the config object
+    config = VllmConfig()
+    backend = VllmBackend(config)
+    backend.configure_post_pass()
+
+    # test that repr(config) succeeds
+    val = repr(config)
+    assert 'VllmConfig' in val
+    assert 'inductor_passes' in val
+
+
 def test_get_field():
 
     @dataclass
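For context on what the new test exercises: `VllmBackend.configure_post_pass()` stores the post-grad pass manager, an ordinary Python object, under `inductor_compile_config["post_grad_custom_post_pass"]`, and Pydantic cannot serialize that object to JSON, so `repr(config)` used to raise. A minimal sketch of this failure class with plain Pydantic (hedged: `PostGradPassManager` and `max_autotune` below are illustrative stand-ins, not vLLM code):

from typing import Any

from pydantic import TypeAdapter
from pydantic_core import PydanticSerializationError

class PostGradPassManager:  # stand-in for the real pass object (assumption)
    def __call__(self, graph):
        return graph

# Roughly what the dict looks like after configure_post_pass() has run:
inductor_compile_config: dict[str, Any] = {
    "max_autotune": True,  # illustrative key
    "post_grad_custom_post_pass": PostGradPassManager(),
}

adapter = TypeAdapter(dict[str, Any])
try:
    adapter.dump_json(inductor_compile_config)
except PydanticSerializationError:
    print("unserializable pass object: this is why repr(config) failed")

# Excluding the offending key, as the fixed __repr__ now does, succeeds:
print(adapter.dump_json(
    inductor_compile_config,
    exclude={"post_grad_custom_post_pass": True}).decode())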
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -648,7 +648,7 @@ class ModelConfig:
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                           tokenizer: str) -> None:
-        """Pull model/tokenizer from S3 to temporary directory when needed.
+        """Pull model/tokenizer from S3 to temporary directory when needed.
 
         Args:
             model: Model name or path
             tokenizer: Tokenizer name or path
@@ -1370,9 +1370,9 @@ class ModelConfig:
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
         """
-        For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
+        For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
         True to enable cross-attention
-        Neuron needs all multimodal data to be in the decoder and does not
+        Neuron needs all multimodal data to be in the decoder and does not
         need to explicitly enable cross-attention
         """
         if (current_platform.is_neuron()
@@ -1794,7 +1794,7 @@ class ParallelConfig:
     """Global rank in distributed setup."""
 
     enable_multimodal_encoder_data_parallel: bool = False
-    """ Use data parallelism instead of tensor parallelism for vision encoder.
+    """ Use data parallelism instead of tensor parallelism for vision encoder.
     Only support LLama4 for now"""
 
     @property
@@ -2272,9 +2272,9 @@ class DeviceConfig:
 
     device: SkipValidation[Union[Device, torch.device]] = "auto"
     """Device type for vLLM execution.
-    This parameter is deprecated and will be
-    removed in a future release.
-    It will now be set automatically based
+    This parameter is deprecated and will be
+    removed in a future release.
+    It will now be set automatically based
     on the current platform."""
     device_type: str = field(init=False)
     """Device type from the current platform. This is set in
@@ -4007,19 +4007,24 @@ class CompilationConfig:
 
     def __repr__(self) -> str:
         exclude = {
-            "static_forward_context",
-            "enabled_custom_ops",
-            "disabled_custom_ops",
-            "compilation_time",
-            "bs_to_padded_graph_size",
-            "pass_config",
-            "traced_files",
+            "static_forward_context": True,
+            "enabled_custom_ops": True,
+            "disabled_custom_ops": True,
+            "compilation_time": True,
+            "bs_to_padded_graph_size": True,
+            "pass_config": True,
+            "traced_files": True,
+            "inductor_compile_config": {
+                "post_grad_custom_post_pass": True,
+            },
         }
         # The cast to string is necessary because Pydantic is mocked in docs
         # builds and sphinx-argparse doesn't know the return type of decode()
         return str(
             TypeAdapter(CompilationConfig).dump_json(
-                self, exclude=exclude, exclude_unset=True).decode())
+                self,
+                exclude=exclude,  # type: ignore[arg-type]
+                exclude_unset=True).decode())
 
     __str__ = __repr__
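The crux of the fix is the shape of `exclude`: the set form can only drop whole top-level fields, while Pydantic also accepts a dict of `{field: True | nested exclude}` that reaches inside a field, including a dict-valued one. That is what lets `__repr__` keep `inductor_compile_config` in the output while dropping only its unserializable `post_grad_custom_post_pass` entry. A hedged sketch under that assumption (`Cfg` is an illustrative stand-in, not the real CompilationConfig):

from dataclasses import dataclass, field
from typing import Any

from pydantic import TypeAdapter

@dataclass
class Cfg:  # stand-in for CompilationConfig (assumption)
    level: int = 0
    inductor_compile_config: dict[str, Any] = field(default_factory=dict)

cfg = Cfg(inductor_compile_config={
    "max_autotune": True,
    "post_grad_custom_post_pass": object(),  # not JSON-serializable
})

# Set form: all-or-nothing, drops max_autotune along with the pass object.
print(TypeAdapter(Cfg).dump_json(
    cfg, exclude={"inductor_compile_config"}).decode())
# -> {"level":0}

# Dict form: drops only the nested key that cannot be serialized.
print(TypeAdapter(Cfg).dump_json(
    cfg,
    exclude={
        "inductor_compile_config": {"post_grad_custom_post_pass": True},
    }).decode())
# -> {"level":0,"inductor_compile_config":{"max_autotune":true}}

The `# type: ignore[arg-type]` in the diff is consistent with this, presumably because the declared parameter type of `exclude` is narrower than this mixed bool-and-dict literal, even though Pydantic accepts it at runtime.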