[Doc] Fix top-level API links/docs (#18621)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2025-05-24 00:46:56 +08:00
committed by GitHub
parent 8ddd1cf26a
commit 273cb3b4d9
9 changed files with 88 additions and 103 deletions

View File

@ -129,16 +129,17 @@ class BenchmarkDataset(ABC):
Args:
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
LoRA is selected. max_loras (Optional[int]): The maximum number of
LoRAs available. If None, LoRA is not used. lora_path
(Optional[str]): Path to the LoRA parameters on disk. If None, LoRA
is not used.
LoRA is selected.
max_loras (Optional[int]): The maximum number of LoRAs available.
If `None`, LoRA is not used.
lora_path (Optional[str]): Path to the LoRA parameters on disk.
If `None`, LoRA is not used.
Returns:
tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
element is a LoRARequest (or None if not applicable) and the second
element is the tokenizer associated with the LoRA request (or the
base tokenizer).
A tuple with the following elements:
- A new [LoRARequest][] (or `None` if not applicable).
- The tokenizer associated with the LoRA request
(or the base tokenizer).
"""
if max_loras is None or lora_path is None:
return None, tokenizer
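For orientation, the behaviour this docstring describes can be sketched as follows. This is a simplified stand-in, not the benchmark's actual helper; the real code may also load a per-adapter tokenizer from `lora_path`:

```python
import random
from typing import Optional

from transformers import PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest


def get_random_lora_request(tokenizer: PreTrainedTokenizerBase,
                            max_loras: Optional[int],
                            lora_path: Optional[str]):
    # Without both a LoRA count and a path, fall back to the base tokenizer.
    if max_loras is None or lora_path is None:
        return None, tokenizer
    # Pick a random adapter id in [1, max_loras]; all adapters are assumed
    # to share the same on-disk path in this sketch.
    lora_id = random.randint(1, max_loras)
    lora_request = LoRARequest(lora_name=str(lora_id),
                               lora_int_id=lora_id,
                               lora_path=lora_path)
    return lora_request, tokenizer
```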
@ -167,7 +168,7 @@ class BenchmarkDataset(ABC):
Args:
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text.
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
Returns:
@ -184,7 +185,8 @@ class BenchmarkDataset(ABC):
Args:
requests (List[SampleRequest]): The current list of sampled
requests. num_requests (int): The target number of requests.
requests.
num_requests (int): The target number of requests.
"""
if len(requests) < num_requests:
random.seed(self.random_seed)

View File

@ -4552,7 +4552,7 @@ def contains_object_print(text):
text (str): The text to check
Returns:
bool: True if a match is found, False otherwise
result (bool): `True` if a match is found, `False` otherwise.
"""
pattern = r'at 0x[a-fA-F0-9]{2,16}>'
match = re.search(pattern, text)
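For reference, a self-contained illustration of the check this docstring describes; the wrapper function here only mirrors the documented behaviour:

```python
import re


def contains_object_print(text: str) -> bool:
    # Matches the hex address in a default `object.__repr__`,
    # e.g. "<Foo object at 0x7f2c1a3b9d30>".
    pattern = r'at 0x[a-fA-F0-9]{2,16}>'
    return bool(re.search(pattern, text))


assert contains_object_print("<Foo object at 0x7f2c1a3b9d30>")
assert not contains_object_print("plain text with no object repr")
```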

View File

@ -210,10 +210,11 @@ class KVConnectorBase_V1(ABC):
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if external KV cache tokens will be loaded
asynchronously (between scheduler steps).
A tuple with the following elements:
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
pass
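As a hedged illustration of the contract, a connector without any external KV cache could return the tuple like this. The method name `get_num_new_matched_tokens` and its signature are assumptions about the v1 connector API, not confirmed by this diff:

```python
class NullKVConnector(KVConnectorBase_V1):
    # Only the method relevant to the docstring above is sketched;
    # the remaining abstract methods are omitted for brevity.
    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        # Nothing beyond the locally computed tokens can be loaded,
        # and nothing is loaded asynchronously between scheduler steps.
        return 0, False
```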

View File

@ -118,11 +118,11 @@ class PyNcclPipe(KVPipeBase):
"""
Create the metadata as a dictionary based on the input tensor.
Parameters:
- tensor: The input tensor or None if no tensor is provided.
Args:
tensor: The input tensor, or `None` if no tensor is provided.
Returns:
- metadata: A dictionary with the following keys:
metadata: A dictionary with the following keys:
- "dtype": The data type of the tensor or None.
- "shape": The shape of the tensor or None.
"""
@ -135,13 +135,13 @@ class PyNcclPipe(KVPipeBase):
"""
Create a buffer to receive the tensor based on the provided metadata.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape", describing
the tensor's data type and shape.
Args:
metadata: A dictionary with keys "dtype" and "shape",
describing the tensor's data type and shape.
Returns:
- buffer: A tensor of the specified type and shape, allocated on
self.device.
buffer: A tensor of the specified type and shape,
allocated on `self.device`.
"""
return torch.empty(metadata["shape"],
dtype=metadata["dtype"],
@ -151,8 +151,8 @@ class PyNcclPipe(KVPipeBase):
"""
Send the metadata dictionary to the target rank.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape".
Args:
metadata: A dictionary with keys "dtype" and "shape".
"""
self.group.send_obj(metadata, self.target_rank_for_send)
@ -161,8 +161,8 @@ class PyNcclPipe(KVPipeBase):
Receive the metadata dictionary from the target rank.
Returns:
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
metadata: A dictionary with keys "dtype" and "shape"
describing the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
@ -171,9 +171,9 @@ class PyNcclPipe(KVPipeBase):
The actual implementation of sending the tensor and its metadata to the
target rank.
Parameters:
- tensor: The input tensor to be sent, or None if no tensor is
being sent.
Args:
tensor: The input tensor to be sent, or `None` if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
self._send_metadata(metadata)
@ -187,7 +187,7 @@ class PyNcclPipe(KVPipeBase):
the target rank.
Returns:
- buffer: The received tensor, or None if no tensor is received.
buffer: The received tensor, or `None` if no tensor is received.
"""
metadata = self._recv_metadata()
if metadata["dtype"] is None:
@ -227,8 +227,8 @@ class PyNcclPipe(KVPipeBase):
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
Parameters:
- tensor: The tensor to send, or None if no tensor is being sent.
Args:
tensor: The tensor to send, or `None` if no tensor is being sent.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
@ -250,8 +250,8 @@ class PyNcclPipe(KVPipeBase):
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
- tensor: The received tensor, or None if no tensor is received.
Returns:
tensor: The received tensor, or `None` if no tensor is received.
"""
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)

View File

@ -130,26 +130,16 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The {class}`~vllm.LLM` class wraps this class for offline batched inference
and the {class}`AsyncLLMEngine` class wraps this class for online serving.
The [LLM][vllm.LLM] class wraps this class for offline batched inference
and the [AsyncLLMEngine][] class wraps this class for online serving.
The config arguments are derived from {class}`~vllm.EngineArgs`. (See
{ref}`engine-args`)
The config arguments are derived from [EngineArgs][vllm.EngineArgs]. (See
[engine-args][])
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
speculative_config (Optional): The configuration related to speculative
decoding.
vllm_config: The configuration for initializing and running vLLM.
executor_class: The model executor class for managing distributed
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
"""
@ -695,11 +685,12 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
prompt: The prompt to the LLM. See
[PromptType][vllm.inputs.PromptType]
for more details about the format of each input.
params: Parameters for sampling or pooling.
{class}`~vllm.SamplingParams` for text generation.
{class}`~vllm.PoolingParams` for pooling.
[SamplingParams][vllm.SamplingParams] for text generation.
[PoolingParams][vllm.PoolingParams] for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
lora_request: The LoRA request to add.
@ -711,10 +702,11 @@ class LLMEngine:
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `n` number of {class}`~vllm.Sequence` objects.
- Create a {class}`~vllm.SequenceGroup` object
from the list of {class}`~vllm.Sequence`.
- Add the {class}`~vllm.SequenceGroup` object to the scheduler.
- Create `n` number of [Sequence][vllm.Sequence] objects.
- Create a [SequenceGroup][vllm.SequenceGroup] object
from the list of [Sequence][vllm.Sequence].
- Add the [SequenceGroup][vllm.SequenceGroup] object to the
scheduler.
Example:
>>> # initialize engine
@ -861,9 +853,7 @@ class LLMEngine:
request_id: The ID(s) of the request to abort.
Details:
- Refer to the
{meth}`~vllm.core.scheduler.Scheduler.abort_seq_group`
from class {class}`~vllm.core.scheduler.Scheduler`.
- Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][].
Example:
>>> # initialize engine and add a request with request_id

View File

@ -116,7 +116,8 @@ class LLM:
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See {class}`~vllm.config.ParallelConfig`
disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files
@ -128,12 +129,12 @@ class LLM:
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the level of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
**kwargs: Arguments for {class}`~vllm.EngineArgs`. (See
{ref}`engine-args`)
**kwargs: Arguments for [EngineArgs][vllm.EngineArgs]. (See
[engine-args][])
Note:
This class is intended to be used for offline inference. For online
serving, use the {class}`~vllm.AsyncLLMEngine` class instead.
serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
"""
DEPRECATE_LEGACY: ClassVar[bool] = True
@ -142,7 +143,7 @@ class LLM:
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
"""
A flag to toggle whether to deprecate positional arguments in
{meth}`LLM.__init__`.
[LLM.__init__][].
"""
@classmethod
@ -403,7 +404,7 @@ class LLM:
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See {class}`~vllm.inputs.PromptType`
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
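For reference, a minimal offline-generation call exercising these arguments could look like this (model name purely illustrative):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)
```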
@ -669,7 +670,7 @@ class LLM:
Generate responses for a chat conversation.
The chat conversation is converted into a text prompt using the
tokenizer and calls the {meth}`generate` method to generate the
tokenizer and calls the [generate][] method to generate the
responses.
Multi-modal inputs can be passed in the same way you would pass them
@ -678,8 +679,8 @@ class LLM:
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
@ -689,27 +690,27 @@ class LLM:
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: ``"Who are you?"``
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: ``[{"type": "text", "text": "Who are you?"}]``
- "string" will render the content as a string.
Example: `"Who are you?"`
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: `[{"type": "text", "text": "Who are you?"}]`
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
``True`` if ``add_generation_prompt`` is also ``True``.
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
A list of ``RequestOutput`` objects containing the generated
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
"""
list_of_messages: list[list[ChatCompletionMessageParam]]
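A hedged usage sketch of `LLM.chat` with a single conversation (the model name is illustrative; any chat-tuned model with a chat template works):

```python
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Who are you?"},
]
outputs = llm.chat(conversation)
print(outputs[0].outputs[0].text)
```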
@ -908,7 +909,7 @@ class LLM:
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See {class}`~vllm.inputs.PromptType`
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
@ -997,7 +998,7 @@ class LLM:
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See {class}`~vllm.inputs.PromptType`
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
@ -1007,7 +1008,7 @@ class LLM:
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
A list of `EmbeddingRequestOutput` objects containing the
embedding vectors in the same order as the input prompts.
"""
if self.llm_engine.model_config.task != "embed":
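A hedged usage sketch of `LLM.embed` (the model name is illustrative and must be an embedding model):

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
outputs = llm.embed(["Hello, world!"])
embedding = outputs[0].outputs.embedding  # list[float]
```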
@ -1041,7 +1042,7 @@ class LLM:
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See {class}`~vllm.inputs.PromptType`
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
@ -1049,7 +1050,7 @@ class LLM:
generation, if any.
Returns:
A list of ``ClassificationRequestOutput`` objects containing the
A list of `ClassificationRequestOutput` objects containing the
class probabilities in the same order as the input prompts.
"""
if self.llm_engine.model_config.task != "classify":
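A similar hedged sketch for `LLM.classify` (model name illustrative; must be a sequence-classification model):

```python
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
outputs = llm.classify(["vLLM is great!"])
probs = outputs[0].outputs.probs  # one probability per class
```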
@ -1159,11 +1160,11 @@ class LLM:
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs ``<text,text_pair>``.
"""Generate similarity scores for all pairs `<text,text_pair>`.
The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N``
times to pair with the ``text_2`` sentences.
The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
In the `1 -> N` case the `text_1` sentence will be replicated `N`
times to pair with the `text_2` sentences.
The input pairs are used to build a list of prompts for the
cross encoder model. This class automatically batches the prompts,
considering the memory constraint. For the best performance, put all
@ -1171,9 +1172,9 @@ class LLM:
Args:
text_1: can be a single prompt or a list of prompts, in which
case it has to have the same length as the ``text_2`` list
case it must have the same length as the `text_2` list.
text_2: The texts to pair with the query to form the input
to the LLM. See {class}`~vllm.inputs.PromptType` for
to the LLM. See [PromptType][vllm.inputs.PromptType] for
more details about the format of each prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
@ -1181,7 +1182,7 @@ class LLM:
generation, if any.
Returns:
A list of ``ScoringRequestOutput`` objects containing the
A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts.
"""
runner_type = self.llm_engine.model_config.runner_type
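A hedged usage sketch of `LLM.score` pairing one query with several passages (model name illustrative; must be a cross-encoder/scoring model):

```python
from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Eiffel Tower is in Paris."])
scores = [output.outputs.score for output in outputs]
```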
@ -1282,13 +1283,13 @@ class LLM:
def wake_up(self, tags: Optional[list[str]] = None):
"""
Wake up the engine from sleep mode. See the {meth}`sleep` method
Wake up the engine from sleep mode. See the [sleep][] method
for more details.
Args:
tags: An optional list of tags to reallocate the engine memory
for specific memory allocations. Values must be in
("weights", "kv_cache",). If None, all memory is reallocated.
`("weights", "kv_cache")`. If None, all memory is reallocated.
wake_up should be called with all tags (or None) before the
engine is used again.
"""

View File

@ -12,7 +12,7 @@ The global {class}`~MultiModalRegistry` is used by model runners to
dispatch data processing according to the target model.
Info:
{ref}`mm-processing`
[mm-processing][]
"""
__all__ = [

View File

@ -215,7 +215,7 @@ class MultiModalRegistry:
invoked to transform the data into a dictionary of model inputs.
Info:
{ref}`mm-processing`
[mm-processing][]
"""
def wrapper(model_cls: N) -> N:
@ -260,7 +260,7 @@ class MultiModalRegistry:
Create a multi-modal processor for a specific model and tokenizer.
Info:
{ref}`mm-processing`
[mm-processing][]
"""
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")

View File

@ -391,15 +391,6 @@ class PoolingRequestOutput(Generic[_O]):
prompt_token_ids, finished)
def __repr__(self):
"""
Returns a string representation of a PoolingRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the pooling request's results.
Returns:
str: A string representation of the PoolingRequestOutput instance.
"""
return (f"{type(self).__name__}(request_id={self.request_id!r}, "
f"outputs={self.outputs!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "