[Doc] Convert Sphinx directives ( {class}, {meth}, {attr}, ...) to MkDocs format for better documentation linking (#18663)

Signed-off-by: Zerohertz <ohg3417@gmail.com>
This commit is contained in:
Hyogeun Oh (오효근)
2025-05-27 17:44:20 +09:00
committed by GitHub
parent 6881107948
commit a68e293cb9
37 changed files with 360 additions and 247 deletions

View File

@ -39,7 +39,8 @@ class CompilerInterface:
Gather all the relevant information from the vLLM config,
to compute a hash so that we can cache the compiled model.
See {meth}`VllmConfig.compute_hash` to check what information
See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
to check what information
is already considered by default. This function should only
consider the information that is specific to the compiler.
"""

View File

@ -2986,7 +2986,7 @@ class PoolerConfig:
pooling_type: Optional[str] = None
"""
The pooling method of the pooling model. This should be a key in
{class}`vllm.model_executor.layers.pooler.PoolingType`.
[`vllm.model_executor.layers.pooler.PoolingType`][].
"""
normalize: Optional[bool] = None
@ -3697,23 +3697,27 @@ class CompilationConfig:
"""Configuration for compilation. It has three parts:
- Top-level Compilation control:
- {attr}`level`
- {attr}`debug_dump_path`
- {attr}`cache_dir`
- {attr}`backend`
- {attr}`custom_ops`
- {attr}`splitting_ops`
- [`level`][vllm.config.CompilationConfig.level]
- [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path]
- [`cache_dir`][vllm.config.CompilationConfig.cache_dir]
- [`backend`][vllm.config.CompilationConfig.backend]
- [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- CudaGraph capture:
- {attr}`use_cudagraph`
- {attr}`cudagraph_capture_sizes`
- {attr}`cudagraph_num_of_warmups`
- {attr}`cudagraph_copy_inputs`
- {attr}`full_cuda_graph`
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
- [`cudagraph_num_of_warmups`]
[vllm.config.CompilationConfig.cudagraph_num_of_warmups]
- [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
- Inductor compilation:
- {attr}`use_inductor`
- {attr}`compile_sizes`
- {attr}`inductor_compile_config`
- {attr}`inductor_passes`
- [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
- [`inductor_compile_config`]
[vllm.config.CompilationConfig.inductor_compile_config]
- [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes]
- custom inductor passes
Why we have different sizes for cudagraph and inductor:

View File

@ -167,4 +167,7 @@ class HTTPConnection:
global_http_connection = HTTPConnection()
"""The global {class}`HTTPConnection` instance used by vLLM."""
"""
The global [`HTTPConnection`][vllm.connections.HTTPConnection] instance used
by vLLM.
"""

View File

@ -475,7 +475,8 @@ class _AsyncLLMEngine(LLMEngine):
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Async version of {meth}`add_request`."""
"""Async version of
[`add_request`][vllm.engine.llm_engine.LLMEngine.add_request]."""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
@ -582,20 +583,21 @@ async def build_guided_decoding_logits_processor_async(
class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for {class}`LLMEngine`.
"""An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
This class is used to wrap the {class}`LLMEngine` class to make it
asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The {class}`LLMEngine` is kicked by the
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the {class}`LLMEngine` to the caller.
This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to
make it asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked
by the generate method when there are requests in the waiting queue. The
generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine]
to the caller.
Args:
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for {class}`LLMEngine`.
**kwargs: Arguments for {class}`LLMEngine`.
*args: Arguments for [`LLMEngine`][vllm.LLMEngine].
**kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine].
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
@ -985,8 +987,9 @@ class AsyncLLMEngine(EngineClient):
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
for more details about the format of each input.
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
@ -1003,7 +1006,7 @@ class AsyncLLMEngine(EngineClient):
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
{meth}`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
[`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step]
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
@ -1075,8 +1078,9 @@ class AsyncLLMEngine(EngineClient):
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
for more details about the format of each input.
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
@ -1089,15 +1093,15 @@ class AsyncLLMEngine(EngineClient):
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
{meth}`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
- If the engine is not running, start the background loop,
which iteratively invokes
[`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
```

View File

@ -130,11 +130,11 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The [LLM][vllm.LLM] class wraps this class for offline batched inference
and the [AsyncLLMEngine][] class wraps this class for online serving.
The [`LLM`][vllm.LLM] class wraps this class for offline batched inference
and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine]
class wraps this class for online serving.
The config arguments are derived from [EngineArgs][vllm.EngineArgs]. (See
[engine-args][])
The config arguments are derived from [`EngineArgs`][vllm.EngineArgs].
Args:
vllm_config: The configuration for initializing and running vLLM.

View File

@ -492,8 +492,9 @@ class MQLLMEngineClient(EngineClient):
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
for more details about the format of each input.
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
@ -561,8 +562,9 @@ class MQLLMEngineClient(EngineClient):
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See {class}`~vllm.inputs.PromptType`
for more details about the format of each input.
prompt: The prompt to the LLM. See
[`PromptType`][vllm.inputs.PromptType] for more details about
the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.

View File

@ -42,19 +42,22 @@ HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
class MQLLMEngine:
"""A multiprocessing wrapper for {class}`LLMEngine`.
"""A multiprocessing wrapper for
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
This class is used to wrap the {class}`LLMEngine` class to enable use
This class is used to wrap the
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
in concurrnet manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
The {class}`LLMEngine` generate or encode process is kicked off when a new
RPCProcessRequest is received by the input_socket.
The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
process is kicked off when a new RPCProcessRequest is received by the
input_socket.
The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
{class}`LLMEngine.step()`, and sends the RequestOutputs back over
the output_socket.
[`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends
the RequestOutputs back over the output_socket.
If use_async_sockets is set, the logic associated with reading new
requests from the socket and sending data to the socket is passed
@ -65,8 +68,8 @@ class MQLLMEngine:
ipc_path: Base path for zeromq interprocess messaging
use_async_sockets: Whether to make send/recv async with GPU
log_requests: Whether to log the requests.
*args: Arguments for {class}`LLMEngine`.
**kwargs: Arguments for {class}`LLMEngine`.
*args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
**kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine].
"""
def __init__(self,

View File

@ -56,8 +56,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
scheduled computation.
Args:
seq_group: the outputs are associated with this {class}`SequenceGroup`
outputs: the {class}`SequenceGroupOutput`s for all scheduler steps
seq_group: the outputs are associated with this
[`SequenceGroup`][vllm.sequence.SequenceGroup]
outputs: the
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s
for all scheduler steps
"""
for output in outputs:
# Concatenate single-step prompt logprob processing results.

View File

@ -19,17 +19,21 @@ logger = init_logger(__name__)
def single_step_process_prompt_logprob(
sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
output: CompletionSequenceGroupOutput) -> None:
"""Process prompt logprobs associated with the {class}`SequenceGroupOutput`
for a given step.
"""Process prompt logprobs associated with the
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step.
Do nothing if the output has no prompt logprobs.
Account for the fact that transformers do not compute first-token logprobs.
Args:
sg_output_proc: {class}`SequenceGroupOutputProcessor` instance
seq_group: the output is associated with this {class}`SequenceGroup`
output: the {class}`SequenceGroupOutput` for a single scheduler step
sg_output_proc:
[`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor]
instance
seq_group: the output is associated with this
[`SequenceGroup`][vllm.sequence.SequenceGroup]
output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
for a single scheduler step
"""
prompt_logprobs = output.prompt_logprobs
@ -103,8 +107,11 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
scheduled computation.
Args:
seq_group: the output is associated with this {class}`SequenceGroup`
outputs: the {class}`SequenceGroupOutput` for a single scheduler step
seq_group: the output is associated with this
[`SequenceGroup`][vllm.sequence.SequenceGroup]
outputs: the
[`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
for a single scheduler step
"""
assert len(outputs) == 1, "Single step should only have 1 output."
output = outputs[0]

View File

@ -129,8 +129,7 @@ class LLM:
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the level of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
**kwargs: Arguments for [EngineArgs][vllm.EngineArgs]. (See
[engine-args][])
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
Note:
This class is intended to be used for offline inference. For online
@ -494,7 +493,7 @@ class LLM:
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
{exc}`TimeoutError` on timeout. `None` means wait indefinitely.
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.

View File

@ -582,7 +582,8 @@ class OpenAIServing:
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
A simpler implementation of {meth}`_tokenize_prompt_input_or_inputs`
A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes single input.
"""
return next(
@ -603,7 +604,8 @@ class OpenAIServing:
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
A simpler implementation of {meth}`_tokenize_prompt_input_or_inputs`
A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes multiple inputs.
"""
for text in prompt_inputs:

View File

@ -74,7 +74,7 @@ class ExecutorBase(ABC):
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
{exc}`TimeoutError` on timeout. `None` means wait indefinitely.
[`TimeoutError`][] on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.

View File

@ -10,8 +10,9 @@ from .registry import (DummyData, InputContext, InputProcessingContext,
INPUT_REGISTRY = InputRegistry()
"""
The global {class}`~InputRegistry` which is used by {class}`~vllm.LLMEngine`
to dispatch data processing according to the target model.
The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
target model.
"""
__all__ = [

View File

@ -80,22 +80,24 @@ SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
- A text prompt ({class}`str` or {class}`TextPrompt`)
- A tokenized prompt ({class}`TokensPrompt`)
- An embeddings prompt ({class}`EmbedsPrompt`)
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. {class}`ExplicitEncoderDecoderPrompt`
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type {class}`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt`
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""
@ -126,18 +128,20 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, may be formatted
according to any of the {class}`SingletonPrompt` schemas,
according to any of the
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an {class}`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
Note that an
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
may not be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
{class}`SingletonPrompt` instances.
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
"""
encoder_prompt: _T1_co
@ -152,11 +156,11 @@ PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
- A text prompt ({class}`str` or {class}`TextPrompt`)
- A tokenized prompt ({class}`TokensPrompt`)
- An embeddings prompt ({class}`EmbedsPrompt`)
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
({class}`ExplicitEncoderDecoderPrompt`)
([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
"""
@ -189,7 +193,8 @@ def token_inputs(
prompt: Optional[str] = None,
cache_salt: Optional[str] = None,
) -> TokenInputs:
"""Construct {class}`TokenInputs` from optional values."""
"""Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
if prompt is not None:
@ -221,7 +226,8 @@ def embeds_inputs(
prompt_embeds: torch.Tensor,
cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values."""
"""Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
values."""
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
if cache_salt is not None:
@ -232,7 +238,7 @@ def embeds_inputs(
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
The inputs in {class}`~vllm.LLMEngine` before they are
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
@ -240,11 +246,12 @@ This specifies the data required for decoder-only models.
class EncoderDecoderInputs(TypedDict):
"""
The inputs in {class}`~vllm.LLMEngine` before they are
passed to the model executor.
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
are passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder: Union[TokenInputs, "MultiModalInputs"]
"""The inputs for the encoder portion."""
@ -254,13 +261,13 @@ class EncoderDecoderInputs(TypedDict):
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
A processed {class}`SingletonPrompt` which can be passed to
{class}`vllm.sequence.Sequence`.
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
passed to [`vllm.sequence.Sequence`][].
"""
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to {data}`vllm.inputs.InputProcessor`.
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
"""
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
@ -277,7 +284,8 @@ def build_explicit_enc_dec_prompt(
return ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt,
mm_processor_kwargs=mm_processor_kwargs)
mm_processor_kwargs=mm_processor_kwargs,
)
def zip_enc_dec_prompts(
@ -288,7 +296,8 @@ def zip_enc_dec_prompts(
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
{class}`ExplicitEncoderDecoderPrompt` instances.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
instances.
``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
dictionary will be used for every encoder/decoder prompt. If an iterable is
@ -299,10 +308,11 @@ def zip_enc_dec_prompts(
if isinstance(mm_processor_kwargs, dict):
return [
build_explicit_enc_dec_prompt(
encoder_prompt, decoder_prompt,
cast(dict[str, Any], mm_processor_kwargs))
for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts)
encoder_prompt,
decoder_prompt,
cast(dict[str, Any], mm_processor_kwargs),
) for (encoder_prompt,
decoder_prompt) in zip(enc_prompts, dec_prompts)
]
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,

View File

@ -23,13 +23,13 @@ class ParsedTokens(TypedDict):
@overload
def parse_and_batch_prompt(
prompt: Union[str, list[str]]) -> Sequence[ParsedText]:
prompt: Union[str, list[str]], ) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[list[int], list[list[int]]]) -> Sequence[ParsedTokens]:
prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]:
...
@ -86,7 +86,7 @@ class ParsedTokensPrompt(TypedDict):
class ParsedEmbedsPrompt(TypedDict):
type: Literal['embeds']
type: Literal["embeds"]
content: EmbedsPrompt
@ -133,7 +133,7 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
prompt: PromptType, ) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt

View File

@ -67,11 +67,11 @@ class InputPreprocessor:
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]:
'''
"""
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model config is unavailable.
'''
"""
if not self.model_config.is_encoder_decoder:
logger.warning_once(
@ -79,14 +79,14 @@ class InputPreprocessor:
"this is not an encoder/decoder model.")
return None
if (self.model_config is None or self.model_config.hf_config is None):
if self.model_config is None or self.model_config.hf_config is None:
logger.warning_once(
"Using None for decoder start token id because "
"model config is not available.")
return None
dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None)
"decoder_start_token_id", None)
if dec_start_token_id is None:
logger.warning_once(
"Falling back on <BOS> for decoder start token "
@ -97,7 +97,7 @@ class InputPreprocessor:
return dec_start_token_id
def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
'''
"""
Specifically for encoder/decoder models:
generate a default decoder prompt for when
the user specifies only the encoder prompt.
@ -126,7 +126,7 @@ class InputPreprocessor:
Returns:
* prompt_token_ids
'''
"""
bos_token_id = self.get_bos_token_id()
assert bos_token_id is not None
@ -224,7 +224,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""Async version of {meth}`_tokenize_prompt`."""
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer = self.get_tokenizer_group()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
@ -287,7 +290,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""Async version of {meth}`_process_multimodal`."""
"""
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
@ -472,7 +478,7 @@ class InputPreprocessor:
Returns:
* {class}`SingletonInputs` instance
* [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
"""
parsed = parse_singleton_prompt(prompt)
@ -508,7 +514,10 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
"""Async version of {meth}`_prompt_to_llm_inputs`."""
"""
Async version of
[`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
"""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
@ -644,7 +653,9 @@ class InputPreprocessor:
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
Process an input prompt into an {class}`EncoderDecoderInputs` instance.
Process an input prompt into an
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance.
There are two types of input prompts:
singleton prompts which carry only the
@ -670,7 +681,8 @@ class InputPreprocessor:
Returns:
* {class}`EncoderDecoderInputs` instance
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance
"""
encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs]
@ -710,7 +722,10 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> EncoderDecoderInputs:
"""Async version of {meth}`_process_encoder_decoder_prompt`."""
"""
Async version of
[`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
"""
encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs]
@ -778,7 +793,8 @@ class InputPreprocessor:
) -> DecoderOnlyInputs:
"""
For decoder-only models:
Process an input prompt into an {class}`DecoderOnlyInputs` instance.
Process an input prompt into a
[`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.
Arguments:
@ -789,7 +805,7 @@ class InputPreprocessor:
Returns:
* {class}`DecoderOnlyInputs` instance
* [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
"""
prompt_comps = self._prompt_to_llm_inputs(
@ -812,7 +828,10 @@ class InputPreprocessor:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
) -> DecoderOnlyInputs:
"""Async version of {meth}`_process_decoder_only_prompt`."""
"""
Async version of
[`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
"""
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
@ -863,7 +882,10 @@ class InputPreprocessor:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
return_mm_hashes: bool = False,
) -> ProcessorInputs:
"""Async version of {meth}`preprocess`."""
"""
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if self.model_config.is_encoder_decoder:
assert not return_mm_hashes, (
"Multimodal hashes for encoder-decoder models should not be ",

View File

@ -38,7 +38,7 @@ class InputContext:
) -> _C:
"""
Get the HuggingFace configuration
({class}`transformers.PretrainedConfig`) of the model,
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
@ -79,7 +79,7 @@ class InputContext:
) -> _P:
"""
Get the HuggingFace processor
({class}`transformers.ProcessorMixin`) of the model,
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:

View File

@ -68,22 +68,22 @@ class _VllmLogger(Logger):
"""
Note:
This class is just to provide type information.
We actually patch the methods directly on the {class}`logging.Logger`
We actually patch the methods directly on the [`logging.Logger`][]
instance to avoid conflicting with other libraries such as
`intel_extension_for_pytorch.utils._logger`.
"""
def info_once(self, msg: str, *args: Hashable) -> None:
"""
As {meth}`info`, but subsequent calls with the same message
are silently dropped.
As [`info`][logging.Logger.info], but subsequent calls with
the same message are silently dropped.
"""
_print_info_once(self, msg, *args)
def warning_once(self, msg: str, *args: Hashable) -> None:
"""
As {meth}`warning`, but subsequent calls with the same message
are silently dropped.
As [`warning`][logging.Logger.warning], but subsequent calls with
the same message are silently dropped.
"""
_print_warning_once(self, msg, *args)

View File

@ -228,17 +228,19 @@ class Sampler(nn.Module):
) -> Optional[SamplerOutput]:
"""
Single-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Pythonize sampling result & logprobs tensor
Multi-step scheduling:
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the {class}`SamplerOutput` structure
* Perform GPU-side sampling computation & compute
GPU-side logprobs tensor
* Defer Pythonization of sampling result & logprobs
tensor
* Encapsulate arguments required for deferred Pythonization
in the
[`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput]
structure
Args:
logits: (num_tokens, vocab_size).

View File

@ -226,9 +226,11 @@ class SupportsPP(Protocol):
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[Tensor, "IntermediateTensors"]:
"""
Accept {class}`IntermediateTensors` when PP rank > 0.
Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when
PP rank > 0.
Return {class}`IntermediateTensors` only for the last PP rank.
Return [`IntermediateTensors`][vllm.sequence.IntermediateTensors] only
for the last PP rank.
"""
...

View File

@ -965,7 +965,7 @@ def select_tiling(
class MolmoProcessorWrapper:
"""
Wraps {class}`MolmoProcessor` so that it can be called directly.
Wraps `MolmoProcessor` so that it can be called directly.
The original definition can be found here:
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py

View File

@ -67,14 +67,14 @@ class PixtralImagePixelInputs(TypedDict):
"""
Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
The result of stacking {attr}`ImageEncoding.tokens` from each prompt.
The result of stacking `ImageEncoding.tokens` from each prompt.
"""
class PixtralProcessorAdapter:
"""
Provide a HF-compatible interface for
{class}`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:

View File

@ -382,7 +382,8 @@ def _get_tokenizer_without_image_pad(
tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
"""
The logic of adding image pad tokens should only be applied in
{class}`QwenVLProcessor`, so they are patched out here.
[`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor],
so they are patched out here.
The definition of the wrapped tokenizer can be found here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py

View File

@ -383,7 +383,7 @@ class _ModelRegistry:
`model_cls` can be either:
- A {class}`torch.nn.Module` class directly referencing the model.
- A [`torch.nn.Module`][] class directly referencing the model.
- A string in the format `<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error

View File

@ -66,7 +66,7 @@ class WeightsMapper:
class AutoWeightsLoader:
"""
Helper class to load weights into a {class}`torch.nn.Module`. It is able
Helper class to load weights into a [`torch.nn.Module`][]. It is able
to automatically detect child modules and parameters while iterating over
the weights only once.

View File

@ -8,11 +8,12 @@ from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""
The global {class}`~MultiModalRegistry` is used by model runners to
dispatch data processing according to the target model.
The global [`MultiModalRegistry`][vllm.multimodal.registry.MultiModalRegistry]
is used by model runners to dispatch data processing according to the target
model.
Info:
[mm-processing][]
[mm_processing](../../../design/mm_processing.html)
"""
__all__ = [

View File

@ -29,14 +29,14 @@ _T = TypeVar("_T")
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
"""
A {class}`transformers.image_utils.ImageInput` representing a single image
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.
"""
HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
list[np.ndarray], list["torch.Tensor"]]
"""
A {class}`transformers.image_utils.VideoInput` representing a single video
A `transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace `VideoProcessor`.
"""
@ -48,7 +48,7 @@ item, which can be passed to a HuggingFace `AudioProcessor`.
ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
"""
A {class}`transformers.image_utils.ImageInput` representing a single image
A `transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
@ -58,7 +58,7 @@ these are directly passed to the model without HF processing.
VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
"""
A {class}`transformers.image_utils.VideoInput` representing a single video
A `transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace `VideoProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
@ -108,7 +108,8 @@ MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.
The built-in modalities are defined by {class}`MultiModalDataBuiltins`.
The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""
@ -169,7 +170,8 @@ Uses a list instead of a tensor if the dimensions of each element do not match.
def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
"""Equality check between {data}`NestedTensors` objects."""
"""Equality check between
[`NestedTensors`][vllm.multimodal.inputs.NestedTensors] objects."""
if isinstance(a, torch.Tensor):
return isinstance(b, torch.Tensor) and torch.equal(a, b)
elif isinstance(b, torch.Tensor):
@ -189,7 +191,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
{meth}`MultiModalKwargs.batch`.
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
"""
@ -197,7 +199,7 @@ A dictionary containing nested tensors which have been batched via
class MultiModalFieldElem:
"""
Represents a keyword argument corresponding to a multi-modal item
in {class}`MultiModalKwargs`.
in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
"""
modality: str
@ -208,13 +210,15 @@ class MultiModalFieldElem:
key: str
"""
The key of this field in {class}`MultiModalKwargs`,
The key of this field in
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
i.e. the name of the keyword argument to be passed to the model.
"""
data: NestedTensors
"""
The tensor data of this field in {class}`MultiModalKwargs`,
The tensor data of this field in
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
i.e. the value of the keyword argument to be passed to the model.
"""
@ -237,7 +241,8 @@ class MultiModalFieldElem:
class BaseMultiModalField(ABC):
"""
Defines how to interpret tensor data belonging to a keyword argument in
{class}`MultiModalKwargs` for multiple multi-modal items, and vice versa.
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
multi-modal items, and vice versa.
"""
def _field_factory(self, *, modality: str, key: str):
@ -262,10 +267,12 @@ class BaseMultiModalField(ABC):
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
"""
Construct {class}`MultiModalFieldElem` instances to represent
the provided data.
Construct
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
instances to represent the provided data.
This is the inverse of {meth}`reduce_data`.
This is the inverse of
[`reduce_data`][vllm.multimodal.inputs.BaseMultiModalField.reduce_data].
"""
raise NotImplementedError
@ -275,9 +282,11 @@ class BaseMultiModalField(ABC):
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
"""
Merge the data from multiple instances of {class}`MultiModalFieldElem`.
Merge the data from multiple instances of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem].
This is the inverse of {meth}`build_elems`.
This is the inverse of
[`build_elems`][vllm.multimodal.inputs.BaseMultiModalField.build_elems].
"""
field_types = [type(item.field) for item in elems]
if len(set(field_types)) > 1:
@ -290,7 +299,7 @@ class BaseMultiModalField(ABC):
class MultiModalBatchedField(BaseMultiModalField):
"""
Info:
[MultiModalFieldConfig.batched][]
[`MultiModalFieldConfig.batched`][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
"""
def build_elems(
@ -320,8 +329,8 @@ class MultiModalBatchedField(BaseMultiModalField):
class MultiModalFlatField(BaseMultiModalField):
"""
Info:
[MultiModalFieldConfig.flat][]
[MultiModalFieldConfig.flat_from_sizes][]
[`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
[`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes]
"""
slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
dim: int = 0
@ -362,7 +371,7 @@ class MultiModalFlatField(BaseMultiModalField):
class MultiModalSharedField(BaseMultiModalField):
"""
Info:
[MultiModalFieldConfig.shared][]
[`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared]
"""
batch_size: int
@ -508,7 +517,7 @@ class MultiModalFieldConfig:
```
Info:
[MultiModalFieldConfig.flat][]
[`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat]
"""
if size_per_item.ndim != 1:
@ -572,8 +581,10 @@ class MultiModalFieldConfig:
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
"""
A collection of {class}`MultiModalFieldElem`
corresponding to a data item in {class}`MultiModalDataItems`.
A collection of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
corresponding to a data item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
@staticmethod
@ -592,11 +603,13 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
{meth}`~torch.nn.Module.forward`.
[`torch.nn.Module.forward`][].
The metadata `items` enables us to obtain the keyword arguments
corresponding to each data item in {class}`MultiModalDataItems`, via
{meth}`get_item` and {meth}`get_items`.
corresponding to each data item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems], via
[`get_item`][vllm.multimodal.inputs.MultiModalKwargs.get_item] and
[`get_items`][vllm.multimodal.inputs.MultiModalKwargs.get_items].
"""
@staticmethod
@ -635,7 +648,9 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
@staticmethod
def from_items(items: Sequence[MultiModalKwargsItem]):
"""Construct a new {class}`MultiModalKwargs` from multiple items."""
"""Construct a new
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs]
from multiple items."""
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for item in items:
for key, elem in item.items():
@ -800,7 +815,7 @@ A dictionary containing placeholder ranges for each modality.
class MultiModalInputs(TypedDict):
"""
Represents the outputs of
{class}`vllm.multimodal.processing.BaseMultiModalProcessor`,
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
ready to be passed to vLLM internals.
"""
@ -836,7 +851,8 @@ class MultiModalInputs(TypedDict):
class MultiModalEncDecInputs(MultiModalInputs):
"""
Represents the outputs of {class}`vllm.multimodal.EncDecMultiModalProcessor`
Represents the outputs of
[`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
ready to be passed to vLLM internals.
"""

View File

@ -28,7 +28,8 @@ else:
class ModalityDataItems(ABC, Generic[_T, _I]):
"""
Represents data items for a modality in {class}`MultiModalDataItems`.
Represents data items for a modality in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
def __init__(self, data: _T, modality: str) -> None:
@ -251,15 +252,15 @@ _D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
"""
As {data}`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized
such that each entry corresponds to a list.
As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but
normalized such that each entry corresponds to a list.
"""
def get_count(self, modality: str, *, strict: bool = True) -> int:
"""
Get the number of data items belonging to a modality.
If `strict=False`, return `0` instead of raising {exc}`KeyError`
If `strict=False`, return `0` instead of raising [`KeyError`][]
even if the modality is not found.
"""
if modality not in self:
@ -305,8 +306,8 @@ ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
class MultiModalDataParser:
"""
Parses {data}`~vllm.multimodal.inputs.MultiModalDataDict` into
{class}`MultiModalDataItems`.
Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
Args:
target_sr (float, optional): Enables automatic resampling of audio

View File

@ -114,13 +114,14 @@ class PromptUpdateDetails(Generic[_S]):
is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
"""
Given {attr}`full`, return a boolean mask of shape `(len(full),)`
indicating which positions of `full` to assign embeddings to.
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions
of `full` to assign embeddings to.
`None` (default) means to assign embeddings to all positions of `full`.
The embeddings are obtained by calling
{class}`SupportsMultiModal.get_multimodal_embeddings`.
[`SupportsMultiModal.get_multimodal_embeddings`][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings].
"""
@staticmethod
@ -159,13 +160,15 @@ PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
The token sequence or text that are part of the update.
If only part of the content corresponds to feature placeholders, you can
use {class}`PromptUpdateDetails` to specify which part.
use [`PromptUpdateDetails`][vllm.multimodal.processing.PromptUpdateDetails] to
specify which part.
"""
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
PromptUpdateInfo]
"""
Given the index of the processed item within {attr}`modality`,
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the corresponding token sequence (or text).
For convenience, you can directly pass in the token sequence (or text)
@ -260,8 +263,10 @@ class PromptInsertion(PromptUpdate):
insertion: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within {attr}`modality`,
output the token sequence (or text) to insert right after {attr}`target`.
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to insert right after
[`target`][vllm.multimodal.processing.PromptUpdate.target].
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
@ -332,8 +337,10 @@ class PromptReplacement(PromptUpdate):
replacement: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within {attr}`modality`,
output the token sequence (or text) to replace {attr}`target`.
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to replace
[`target`][vllm.multimodal.processing.PromptUpdate.target].
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
@ -387,14 +394,16 @@ _M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
"""Convenience function to apply [full_groupby][] based on modality."""
"""Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
based on modality."""
return full_groupby(values, key=lambda x: x.modality)
@dataclass
class _BoundPromptSequence:
"""
A {data}`_PromptSeq` bound to a tokenizer to automatically
A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound
to a tokenizer to automatically
convert between token sequence and text representations.
"""
tokenizer: AnyTokenizer = field(repr=False)
@ -446,9 +455,11 @@ class _BoundPromptContent:
@dataclass
class BoundPromptUpdate:
"""
A {class}`PromptUpdate` bound to a tokenizer to automatically convert
{attr}`target` and the result of {meth}`get_content` between
token sequence and text representations.
A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound
to a tokenizer to automatically convert
[`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of
[`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content]
between token sequence and text representations.
"""
_origin: PromptUpdate
tokenizer: AnyTokenizer = field(repr=False)
@ -482,7 +493,8 @@ class BoundPromptUpdate:
def get_content(self, item_idx: int) -> _BoundPromptContent:
"""
Given the index of the processed item within {attr}`modality`,
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
output the token sequence (or text) to update.
"""
content = self.content
@ -1019,7 +1031,8 @@ class ProcessingCache:
) -> None:
"""
Put a processed multi-modal item into the cache
according to its dependencies (see {meth}`get`).
according to its dependencies
(see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
"""
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
@ -1091,7 +1104,8 @@ _I = TypeVar("_I", bound=BaseProcessingInfo)
MultiModalHashes = dict[str, list[str]]
"""
A collection of hashes with a similar structure as {class}`MultiModalKwargs`.
A collection of hashes with a similar structure as
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
"""
@ -1099,7 +1113,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
Abstract base class to process multi-modal inputs to be used in vLLM.
Not to be confused with {class}`transformers.ProcessorMixin`.
Not to be confused with `transformers.ProcessorMixin`.
"""
def __init__(self,
@ -1126,10 +1140,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _get_data_parser(self) -> MultiModalDataParser:
"""
Construct a parser to preprocess multi-modal data items
before passing them to {meth}`_get_hf_mm_data`.
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
You can support additional modalities by creating a subclass
of {class}`MultiModalDataParser` that has additional subparsers.
of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
that has additional subparsers.
"""
return MultiModalDataParser()
@ -1138,8 +1154,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
"""
Normalize {class}`MultiModalDataDict` to {class}`MultiModalDataItems`
before passing them to {meth}`_get_hf_mm_data`.
Normalize
[`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
supported_mm_limits = self.info.get_supported_mm_limits()
@ -1191,7 +1210,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
inputs.
Moreover, this information is critical to determine the token positions
in order to construct {class}`~vllm-multimodal.input.PlaceholderRange`
in order to construct
[`PlaceholderRange`][vllm.multimodal.inputs.PlaceholderRange]
for each multi-modal item.
"""
raise NotImplementedError
@ -1315,7 +1335,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Most HF processors accept prompt text but not prompt tokens.
If the HF processor adds or removes tokens that are not related to
multi-modal data, you should override this method so it is consistent
with the output of {meth}`_apply_hf_processor_text_only` on the
with the output of
[`_apply_hf_processor_text_only`][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_text_only]
on the
corresponding text.
"""
return prompt_tokens
@ -1330,7 +1352,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Since HF processor requires that text and multi-modal items
correspond to each other, we generate dummy text using
{class}`DummyInputsBuilder` to go along with the multi-modal data.
[`DummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
to go along with the multi-modal data.
"""
mm_counts = mm_items.get_all_counts()

View File

@ -25,7 +25,7 @@ logger = init_logger(__name__)
class ProcessorInputs:
"""
Represents the keyword arguments to
{meth}`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
"""
prompt: Union[str, list[int]]
mm_data: MultiModalDataDict

View File

@ -29,7 +29,11 @@ _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class ProcessingInfoFactory(Protocol[_I_co]):
"""Constructs a {class}`MultiModalProcessor` instance from the context."""
"""
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
instance from the context.
"""
def __call__(
self,
@ -40,7 +44,9 @@ class ProcessingInfoFactory(Protocol[_I_co]):
class DummyInputsBuilderFactory(Protocol[_I]):
"""
Constructs a {class}`BaseDummyInputsBuilder` instance from the context.
Constructs a
[`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
instance from the context.
"""
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
@ -48,7 +54,11 @@ class DummyInputsBuilderFactory(Protocol[_I]):
class MultiModalProcessorFactory(Protocol[_I]):
"""Constructs a {class}`MultiModalProcessor` instance from the context."""
"""
Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
instance from the context.
"""
def __call__(
self,
@ -155,8 +165,6 @@ class MultiModalRegistry:
"""
Get the maximum number of tokens from each modality
for profiling the memory usage of a model.
See {meth}`MultiModalPlugin.get_max_multimodal_tokens` for more details.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
@ -170,8 +178,6 @@ class MultiModalRegistry:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
See {meth}`MultiModalPlugin.get_max_multimodal_tokens` for more details.
"""
return sum(self.get_max_tokens_by_modality(model_config).values())
@ -213,9 +219,6 @@ class MultiModalRegistry:
When the model receives multi-modal data, the provided function is
invoked to transform the data into a dictionary of model inputs.
Info:
[mm-processing][]
"""
def wrapper(model_cls: N) -> N:
@ -258,9 +261,6 @@ class MultiModalRegistry:
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
Info:
[mm-processing][]
"""
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")

View File

@ -259,7 +259,8 @@ class MediaConnector:
global_media_connector = MediaConnector()
"""The global {class}`MediaConnector` instance used by vLLM."""
"""The global [`MediaConnector`][vllm.multimodal.utils.MediaConnector]
instance used by vLLM."""
fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image

View File

@ -84,7 +84,7 @@ class DeviceCapability(NamedTuple):
def to_int(self) -> int:
"""
Express device capability as an integer ``<major><minor>``.
Express device capability as an integer `<major><minor>`.
It is assumed that the minor version is always a single digit.
"""
@ -206,10 +206,11 @@ class Platform:
"""
Test whether this platform is compatible with a device capability.
The ``capability`` argument can either be:
The `capability` argument can either be:
- A tuple ``(major, minor)``.
- An integer ``<major><minor>``. (See {meth}`DeviceCapability.to_int`)
- A tuple `(major, minor)`.
- An integer `<major><minor>`. (See
[`DeviceCapability.to_int`][vllm.platforms.interface.DeviceCapability.to_int])
"""
current_capability = cls.get_device_capability(device_id=device_id)
if current_capability is None:

View File

@ -27,7 +27,7 @@ VLLM_INVALID_TOKEN_ID = -1
def array_full(token_id: int, count: int):
"""{class}`array` equivalent of [numpy.full][]."""
"""[`array`][] equivalent of [numpy.full][]."""
return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
@ -192,8 +192,8 @@ class SequenceData(msgspec.Struct,
def from_prompt_token_counts(
*token_counts: tuple[int, int]) -> "SequenceData":
"""
Construct a {class}`SequenceData` instance by concatenating
prompt token sequences.
Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
by concatenating prompt token sequences.
Each tuple represents one token sequence, expressed in the form
`(token_id, count)`.
@ -216,8 +216,8 @@ class SequenceData(msgspec.Struct,
prompt_embeds: Optional[torch.Tensor] = None,
) -> "SequenceData":
"""
Construct a {class}`SequenceData` instance from prompt and output
token sequences.
Construct a [`SequenceData`][vllm.sequence.SequenceData] instance
from prompt and output token sequences.
"""
prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
prompt_token_ids)
@ -452,9 +452,11 @@ class SequenceData(msgspec.Struct,
class Sequence:
"""Stores the data, status, and block information of a sequence.
The sequence is constructed from the {data}`DecoderOnlyInputs`
(for decoder-only) or {data}`EncoderDecoderInputs` (for encoder-decoder)
instance passed in through the `inputs` constructor argument.
The sequence is constructed from the
[`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only)
or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
(for encoder-decoder) instance passed in through the `inputs`
constructor argument.
Args:
seq_id: The ID of the sequence.

View File

@ -1005,7 +1005,7 @@ def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike {class}`itertools.groupby`, groups are not broken by
Unlike [`itertools.groupby`][], groups are not broken by
non-contiguous data.
"""
groups = defaultdict[_K, list[_V]](list)
@ -1926,7 +1926,8 @@ class _PlaceholderBase:
Disallows downstream usage of placeholder modules.
We need to explicitly override each dunder method because
{meth}`__getattr__` is not called when they are accessed.
[`__getattr__`][vllm.utils._PlaceholderBase.__getattr__]
is not called when they are accessed.
Info:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)

View File

@ -10,7 +10,7 @@ def sanity_check_mm_encoder_outputs(
) -> None:
"""
Perform sanity checks for the result of
{meth}`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`.
[`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
"""
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
@ -39,7 +39,7 @@ def scatter_mm_placeholders(
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
{class}`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
[`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].
Args:
embeds: The multimodal embeddings.

View File

@ -733,12 +733,13 @@ def _pythonize_sampler_output(
logprobs_tensor: Optional[torch.Tensor],
cache: Optional[PythonizationCache],
) -> None:
""" This function is only called when the output tensors are ready.
See {class}`ModelOutput`.
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
""" This function is only called when the output tensors are ready.
See [`ModelOutput`][vllm.worker.multi_step_model_runner.ModelOutput].
Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
adding a Pythonized output data structure
({class}`CompletionSequenceGroupOutput`) for each {class}`SequenceGroup`.
([`CompletionSequenceGroupOutput`][vllm.sequence.CompletionSequenceGroupOutput])
for each [`SequenceGroup`][vllm.sequence.SequenceGroup].
Args:
model_input