diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 19ce8c0672..35a5fa0c2e 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -638,7 +638,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
| `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ |
+| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ |
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ |
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index f2c09d3e84..ee546e7af8 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -451,7 +451,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
trust_remote_code=True),
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
- extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"}, # noqa: E501
+ extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501
trust_remote_code=True),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501
trust_remote_code=True,
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index c22d871ab2..2d785c30fd 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -27,12 +27,14 @@ import math
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
+from itertools import chain
from typing import Annotated, Any, Callable, Literal, Optional, Union
import numpy as np
import torch
import torch.types
from torch import nn
+from torch.nn.init import trunc_normal_
from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar
@@ -47,10 +49,11 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
+from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
- NestedTensors)
+ MultiModalKwargsItems, NestedTensors)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
ImageProcessorItems, ImageSize,
ModalityData, ModalityDataItems,
@@ -218,6 +221,187 @@ class Resampler2_5(BaseResampler):
return x
+class Resampler4_5(Resampler2_5):
+
+ def __init__(self,
+ num_queries: int,
+ embed_dim: int,
+ num_heads: int,
+ kv_dim: Optional[int] = None,
+ norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+ max_size: tuple[int, int] = (70, 70),
+ max_temporal_size: int = 36000,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "") -> None:
+ super().__init__(num_queries,
+ embed_dim,
+ num_heads,
+ kv_dim,
+ norm_layer,
+ max_size,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ trunc_normal_(self.query, std=.02)
+ self.max_temporal_size = max_temporal_size
+ self._set_temporal_pos_cache(self.max_temporal_size)
+ self.apply(self._init_weights)
+
+ def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int,
+ pos: np.ndarray):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+ def _set_temporal_pos_cache(self,
+ max_temporal_size: int,
+ device: torch.types.Device = "cpu") -> None:
+ temporal_size = np.arange(max_temporal_size, dtype=np.float32)
+ pos_embed = torch.from_numpy(
+ self.get_1d_sincos_pos_embed_from_temporal_size(
+ self.embed_dim, temporal_size)).float().to(device)
+ self.register_buffer("temporal_pos_embed", pos_embed, persistent=False)
+
+ def _adjust_temporal_pos_cache(self,
+ max_temporal_size: int,
+ device: torch.types.Device = "cpu"):
+ if max_temporal_size > self.max_temporal_size:
+ self.max_temporal_size = max_temporal_size
+ self._set_temporal_pos_cache(self.max_temporal_size, device)
+
+ def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ tgt_sizes: torch.Tensor,
+ # temporal_ids for high refresh rate videos
+ temporal_ids=None
+ ) -> torch.Tensor:
+ assert x.shape[0] == tgt_sizes.shape[0]
+ bs = x.shape[0]
+
+ device = x.device
+ dtype = x.dtype
+
+ patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+
+ self._adjust_pos_cache(tgt_sizes, device=device)
+
+ temporal_pos_emb = False
+ temporal_ids_flatten = None
+ if temporal_ids is not None:
+ # example: [[-1], [-1], [2, 6, 9]]
+ temporal_ids_flatten = list(chain.from_iterable(temporal_ids))
+ max_temporal_size = max(temporal_ids_flatten, default=0)
+ if max_temporal_size > -1:
+ temporal_pos_emb = True
+ if max_temporal_size > self.max_temporal_size:
+ self._adjust_temporal_pos_cache(max_temporal_size, device)
+
+ max_patch_len = patch_len.max().item()
+ assert isinstance(max_patch_len, int)
+
+ key_padding_mask = torch.zeros((bs, max_patch_len),
+ dtype=torch.bool,
+ device=device)
+
+ x, _ = self.kv_proj(x) # B * L * D
+ x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
+ q = self.ln_q(self.query) # Q * D
+
+ pos_embed_2d = []
+ pos_embed_temporal = []
+ for i in range(bs):
+ tgt_h, tgt_w = tgt_sizes[i]
+ if temporal_pos_emb:
+ if temporal_ids_flatten[i] == -1:
+ pos_embed_temporal.append(
+ torch.zeros(self.embed_dim, dtype=dtype,
+ device=device))
+ else:
+ pos_embed_temporal.append(self.temporal_pos_embed[
+ temporal_ids_flatten[i]].to(dtype)) # D
+
+ pos_embed_2d.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
+ (tgt_h * tgt_w, -1)).to(dtype)) # patches * D
+ key_padding_mask[i, patch_len[i]:] = True
+
+ pos_embed_2d = torch.nn.utils.rnn.pad_sequence(
+ pos_embed_2d, batch_first=True,
+ padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D
+
+ k = x
+ v = x + pos_embed_2d
+ if pos_embed_temporal:
+ k += torch.stack(pos_embed_temporal, dim=0)
+ bs = len(temporal_ids)
+ merge_k = []
+ merge_v = []
+ merge_key_padding_mask = []
+
+ start = 0
+ for tp in temporal_ids:
+ end = start + len(tp)
+ # L * (end-start) * D -> (end-start) * L * D
+ # -> 1 * L*(end-start) * D
+ merge_k.append(k[:, start:end, :].permute(1, 0, 2).reshape(
+ -1, self.embed_dim))
+ merge_v.append(v[:, start:end, :].permute(1, 0, 2).reshape(
+ -1, self.embed_dim))
+ merge_key_padding_mask.append(
+ key_padding_mask[start:end, :].reshape(-1, 1))
+
+ start = end
+
+ k = torch.nn.utils.rnn.pad_sequence(merge_k,
+ batch_first=True,
+ padding_value=0.0).permute(
+ 1, 0, 2) # L*(end-start)
+ v = torch.nn.utils.rnn.pad_sequence(merge_v,
+ batch_first=True,
+ padding_value=0.0).permute(
+ 1, 0, 2) # L*(end-start)
+ key_padding_mask = torch.nn.utils.rnn.pad_sequence(
+ merge_key_padding_mask, batch_first=True,
+ padding_value=True).squeeze(-1)
+
+ out = self.attn(
+ self._repeat(q, bs), # Q * B * D
+ k, # L * B * D + L * B * D
+ v,
+ key_padding_mask=key_padding_mask,
+ )[0]
+ # out: Q * B * D
+ x = out.permute(1, 0, 2) # B * Q * D
+
+ x = self.ln_post(x)
+ x = x @ self.proj
+ return x
+
+
def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]:
version_float = getattr(config, "version", None)
@@ -354,9 +538,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
mm_limits = {"image": None}
- if self.get_model_version() == (2,
- 6) or self.get_model_version() == (4,
- 0):
+ if self.get_model_version() in {(2, 6), (4, 0), (4, 5)}:
mm_limits["video"] = None
return mm_limits
@@ -637,8 +819,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys: set[str],
) -> dict[str, NestedTensors]:
# This processor supports zipping prompt and mm_data together
- if self.info.get_model_version() == (
- 2, 6) or self.info.get_model_version() == (4, 0):
+ if self.info.get_model_version() in {(2, 6), (4, 0), (4, 5)}:
inputs = super()._call_hf_processor(
prompt=prompts, # type: ignore
mm_data=mm_data,
@@ -816,7 +997,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
# and config class
self.config = config
self.multimodal_config = multimodal_config
- self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(vllm_config=vllm_config,
@@ -1364,11 +1544,9 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
prefix: str = "",
) -> nn.Module:
quant_config = self._maybe_ignore_quant_config(quant_config)
- model = Idefics2VisionTransformer(
- config.vision_config,
- quant_config=quant_config,
- prefix=prefix,
- use_data_parallel=self.use_data_parallel)
+ model = Idefics2VisionTransformer(config.vision_config,
+ quant_config=quant_config,
+ prefix=prefix)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
@@ -1436,11 +1614,121 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
return loader.load_weights(weights)
+class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ assert self.version == (4, 5)
+
+ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+ if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
+ return None
+ return quant_config
+
+ def init_llm(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ) -> nn.Module:
+ return Qwen3ForCausalLM(vllm_config=vllm_config, prefix=prefix)
+
+ def init_vision_module(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> nn.Module:
+ quant_config = self._maybe_ignore_quant_config(quant_config)
+ model = Idefics2VisionTransformer(config.vision_config,
+ quant_config=quant_config,
+ prefix=prefix)
+ if self.config.drop_vision_last_layer:
+ model.encoder.layers = model.encoder.layers[:-1]
+ return model
+
+ def init_resampler(
+ self,
+ embed_dim: int,
+ vision_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> nn.Module:
+ quant_config = self._maybe_ignore_quant_config(quant_config)
+ with set_default_torch_dtype(torch.float16):
+ # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
+ resampler = Resampler4_5(num_queries=self.config.query_num,
+ embed_dim=embed_dim,
+ num_heads=embed_dim // 128,
+ kv_dim=vision_dim,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ return resampler.to(device=current_platform.device_type,
+ dtype=torch.get_default_dtype())
+
+ def get_vision_hidden_states(
+ self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+ pixel_values = data["pixel_values"]
+ tgt_sizes = data["tgt_sizes"]
+ temporal_ids = data.get('temporal_ids', None)
+
+ B = len(pixel_values)
+ P = pixel_values[0].shape[-2]
+ L = max(item.shape[-1] for item in pixel_values)
+ device = pixel_values[0].device
+ dtype = pixel_values[0].dtype
+
+ all_pixel_values = torch.zeros((B, 3, P, L),
+ dtype=dtype,
+ device=device)
+ all_temporal_ids = None if temporal_ids is None else flatten_2d_lists(
+ temporal_ids)
+ for i, pixel_values_item in enumerate(pixel_values):
+ L_item = pixel_values_item.shape[-1]
+ all_pixel_values[i, ..., :L_item] = pixel_values_item
+
+ num_patches = tgt_sizes.prod(-1)
+ max_patches = num_patches.max().item()
+ assert isinstance(max_patches, int)
+
+ patch_attn_mask = torch.zeros((B, max_patches),
+ dtype=torch.bool,
+ device=device)
+ for i, num_patches_item in enumerate(num_patches):
+ patch_attn_mask[i, :num_patches_item] = True
+
+ vision_embedding = self.vpm(
+ all_pixel_values,
+ patch_attention_mask=patch_attn_mask.unsqueeze(1),
+ tgt_sizes=tgt_sizes,
+ )
+
+ return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self,
+ skip_prefixes=["apm.", "audio", "tts"])
+ return loader.load_weights(weights)
+
+
_SUPPORT_VERSION = {
(2, 0): MiniCPMV2_0,
(2, 5): MiniCPMV2_5,
(2, 6): MiniCPMV2_6,
(4, 0): MiniCPMV4_0,
+ (4, 5): MiniCPMV4_5,
}
diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py
index e0ef7f0999..d09c5fa924 100644
--- a/vllm/transformers_utils/chat_templates/registry.py
+++ b/vllm/transformers_utils/chat_templates/registry.py
@@ -20,6 +20,16 @@ def _get_qwen_chat_template_fallback(
return CHAT_TEMPLATES_DIR / "template_basic.jinja"
+def _get_minicpmv_chat_template_fallback(
+ tokenizer_name_or_path: str) -> Optional[Path]:
+ # MiniCPM-V-4.5 version uses a dedicated template
+ if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path:
+ return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja"
+
+ # Other versions use chatml template
+ return CHAT_TEMPLATES_DIR / "template_chatml.jinja"
+
+
# yapf: disable
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
"blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
@@ -27,6 +37,7 @@ _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
"florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",
+ "minicpmv": _get_minicpmv_chat_template_fallback,
"paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"qwen": _get_qwen_chat_template_fallback,
}
diff --git a/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja
new file mode 100644
index 0000000000..661ebd1cf5
--- /dev/null
+++ b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja
@@ -0,0 +1,93 @@
+{#- Chat template for MiniCPM-V 4.5: ChatML-style turns with optional tool calling and <think> reasoning blocks. `enable_thinking` defaults to false. -#}
+{%- set enable_thinking = enable_thinking | default(false) %}
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0].role == 'system' %}
+        {{- messages[0].content + '\n\n' }}
+    {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0].role == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+
+{#- Find the last user message that is a real query, not a wrapped tool response. -#}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{%- for message in messages[::-1] %}
+    {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+        {%- set ns.multi_step_tool = false %}
+        {%- set ns.last_query_index = index %}
+    {%- endif %}
+{%- endfor %}
+
+{%- for message in messages %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {%- set content = message.content %}
+        {%- set reasoning_content = '' %}
+        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
+            {%- set reasoning_content = message.reasoning_content %}
+        {%- else %}
+            {#- Split inline reasoning from the visible answer. -#}
+            {%- if '</think>' in message.content %}
+                {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
+                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+            {%- endif %}
+        {%- endif %}
+        {#- Reasoning is only re-emitted for assistant turns after the last real user query. -#}
+        {%- if loop.index0 > ns.last_query_index %}
+            {%- if loop.last or (not loop.last and reasoning_content) %}
+                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+            {%- else %}
+                {{- '<|im_start|>' + message.role + '\n' + content }}
+            {%- endif %}
+        {%- else %}
+            {{- '<|im_start|>' + message.role + '\n' + content }}
+        {%- endif %}
+
+        {%- if message.tool_calls %}
+            {%- for tool_call in message.tool_calls %}
+                {%- if (loop.first and content) or (not loop.first) %}
+                    {{- '\n' }}
+                {%- endif %}
+                {%- if tool_call.function %}
+                    {%- set tool_call = tool_call.function %}
+                {%- endif %}
+                {{- '<tool_call>\n{"name": "' }}
+                {{- tool_call.name }}
+                {{- '", "arguments": ' }}
+                {%- if tool_call.arguments is string %}
+                    {{- tool_call.arguments }}
+                {%- else %}
+                    {{- tool_call.arguments | tojson }}
+                {%- endif %}
+                {{- '}\n</tool_call>' }}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {#- Consecutive tool messages are grouped into a single user turn. -#}
+        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+    {%- if enable_thinking is defined and enable_thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+    {%- endif %}
+    {%- if enable_thinking is defined and enable_thinking is true %}
+        {{- '<think>\n' }}
+    {%- endif %}
+{%- endif %}
\ No newline at end of file