Compare commits

...

2 Commits

5873877241  [Bugfix] Mistral tool calling when content is list (#18729)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-05-27 09:05:37 -07:00

696259ca01  [Core] Automatically cast multi-modal input dtype (#18756)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-05-27 23:45:48 +08:00
18 changed files with 206 additions and 51 deletions

View File

@@ -1,15 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.messages import (AssistantMessage,
ToolMessage,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.protocol.instruct.tool_calls import (Function,
FunctionCall, Tool,
ToolCall)
from vllm.transformers_utils.tokenizers.mistral import (
make_mistral_chat_completion_request)
# yapf: enable
@pytest.mark.parametrize(
"openai_request,expected_mistral_request",
[(
@@ -78,6 +81,107 @@ from vllm.transformers_utils.tokenizers.mistral import (
)
def test_make_mistral_chat_completion_request(openai_request,
expected_mistral_request):
assert (make_mistral_chat_completion_request(
openai_request["messages"],
openai_request["tools"]) == expected_mistral_request)
actual_request = make_mistral_chat_completion_request(
openai_request["messages"], openai_request["tools"])
assert actual_request == expected_mistral_request
# Tool use with list content and reasoning_content
@pytest.mark.parametrize("openai_request,expected_mistral_request", [(
{
"messages": [
{
"role": "user",
"content": "What's the weather in Paris?",
},
{
"role":
"assistant",
"reasoning_content":
None,
"content":
None,
"tool_calls": [{
"id": "call123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
}],
},
{
"role": "tool",
"content": [{
"type": "text",
"text": "Rainy"
}],
"name": "get_weather",
"tool_call_id": "call123",
},
],
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Gets the current weather in a city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
}
},
"required": ["city"],
},
},
}],
},
ChatCompletionRequest(
messages=[
UserMessage(content="What's the weather in Paris?"),
AssistantMessage(
content=None,
tool_calls=[
ToolCall(
id="call123",
function=FunctionCall(
name="get_weather",
arguments='{"city": "Paris"}',
),
)
],
),
ToolMessage(
content="Rainy",
tool_call_id="call123",
name="get_weather",
),
],
tools=[
Tool(
type="function",
function=Function(
name="get_weather",
description="Gets the current weather in a city.",
parameters={
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
}
},
"required": ["city"],
},
),
)
],
),
)])
def test_make_mistral_chat_completion_request_list_content(
openai_request, expected_mistral_request):
actual_request = make_mistral_chat_completion_request(
openai_request["messages"], openai_request["tools"])
assert actual_request == expected_mistral_request

View File

@@ -210,9 +210,7 @@ class DeepseekVL2MultiModalProcessor(
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
target_dtype = self.info.ctx.model_config.dtype
pixel_values = processed_outputs.pop("pixel_values").to(
target_dtype)
pixel_values = processed_outputs["pixel_values"]
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [

View File

@@ -263,11 +263,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
mm_data,
mm_kwargs,
)
if "pixel_values" in processed_outputs:
# Cast pixel values to model dtype already here,
# so we need to transfer less data to the GPU
processed_outputs["pixel_values"] = processed_outputs[
"pixel_values"].to(self.info.ctx.model_config.dtype)
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:

View File

@@ -746,11 +746,17 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
dtype: Optional[torch.dtype] = None,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
def maybe_cast_dtype(x: torch.Tensor):
# This mimics the behavior of transformers.BatchFeature
return x.to(dtype=dtype) if x.is_floating_point() else x
json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
# NOTE: Cast the dtype before sending it to device
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
json_inputs,
)
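A quick self-contained illustration of the casting rule added here (toy tensors, not part of the diff; the version above closes over dtype, this sketch takes it as a parameter): only floating-point leaves are cast to the requested dtype, mirroring transformers.BatchFeature, so integer inputs such as token indices or grid metadata keep their original dtype.

import torch

def maybe_cast_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Same rule as in the hunk above, written as a free function for illustration.
    return x.to(dtype=dtype) if x.is_floating_point() else x

pixel_values = torch.rand(1, 3, 224, 224)                    # float32 image tensor
image_grid = torch.tensor([[1, 16, 16]], dtype=torch.long)   # integer metadata

print(maybe_cast_dtype(pixel_values, torch.bfloat16).dtype)  # torch.bfloat16
print(maybe_cast_dtype(image_grid, torch.bfloat16).dtype)    # torch.int64 (unchanged)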

View File

@@ -294,8 +294,11 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
inputs_embeds=None,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_runner.model_config.dtype,
device=self.device,
),
**model_execute_kwargs,
)
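The same mechanical change recurs in every runner file below: each MultiModalKwargs.as_kwargs call now passes the model dtype alongside the device. A minimal sketch of the new call shape with toy inputs (the import path, tensor names, and values are assumptions for illustration, not taken from this diff):

import torch
from vllm.multimodal import MultiModalKwargs

# Toy batched multi-modal inputs: a floating-point image tensor plus integer metadata.
batched_inputs = {
    "pixel_values": torch.rand(2, 3, 224, 224),
    "image_sizes": torch.tensor([[224, 224], [224, 224]]),
}

# New call shape: floating-point leaves are cast to the model dtype before the
# (possibly non-blocking) transfer to the execution device.
kwargs = MultiModalKwargs.as_kwargs(
    batched_inputs,
    dtype=torch.float16,
    device="cpu",
)
print(kwargs["pixel_values"].dtype)  # torch.float16
print(kwargs["image_sizes"].dtype)   # torch.int64, left unchanged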

View File

@@ -156,7 +156,11 @@ def make_mistral_chat_completion_request(
#
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
for message in messages:
if message.get("role") == "assistant":
# Remove reasoning_content as unsupported by Mistral
_ = message.pop("reasoning_content", None) # type: ignore
# Convert list text content to string
if message.get("role") in ("assistant", "tool"):
content = message.get("content")
if isinstance(content, list):
content = "\n".join(chunk.get("text") for chunk in content)

View File

@@ -929,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
# Run the encoder.
# `curr_group_outputs` is either of the following:
@@ -1874,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(

View File

@@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
# Run the encoder.
# `curr_group_outputs` is either of the following:
@@ -1435,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
batch_size)
return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
device=self.device)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:

View File

@@ -297,8 +297,11 @@ class CPUEncoderDecoderModelRunner(
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
"intermediate_tensors":
intermediate_tensors,
}

View File

@@ -628,7 +628,10 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device)
model_input.multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
)
execute_model_kwargs = {}
if previous_hidden_states is not None:
execute_model_kwargs.update(

View File

@@ -50,8 +50,11 @@ class CPUPoolingModelRunner(
model_input.input_tokens,
"positions":
model_input.input_positions,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,
"intermediate_tensors":
intermediate_tensors,

View File

@@ -202,9 +202,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
)
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)

View File

@@ -1845,8 +1845,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
**model_kwargs,
)

View File

@@ -70,8 +70,11 @@ class MultiStepNeuronModelRunner(NeuronModelRunner):
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
output = self.model.sample(

View File

@@ -49,8 +49,11 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
output = self.model.sample(

View File

@@ -378,9 +378,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
elif current_platform.use_transformers_neuronx():
# [TODO] validate on-device sampling
@@ -389,9 +391,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
# Compute the logits only if the on-device sampling is turned off as

View File

@@ -119,10 +119,14 @@ class PoolingModelRunner(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,
**seqlen_agnostic_kwargs)
**seqlen_agnostic_kwargs,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):

View File

@@ -562,9 +562,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device))
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states