Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[V1] Cache uses_mrope in GPUModelRunner (#12969)
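This change caches `model_config.uses_mrope` as an instance attribute on `GPUModelRunner` at construction time, so the hot per-step paths read a plain attribute instead of repeatedly going through `self.model_config`. It also hoists the `positions` selection out of the `set_forward_context` blocks into a plain if/else above them.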
@@ -92,6 +92,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
         self.mm_registry = MULTIMODAL_REGISTRY
+        self.uses_mrope = model_config.uses_mrope

         # NOTE: Initialized input mapper is only used for processing dummy
         # multimodal data into multimodal kwargs for GPU memory profiling.
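The cached flag is read several times per scheduling step. A minimal sketch of the pattern (these stand-in classes and the property body are assumptions for illustration, not vLLM's actual implementation):

    class ModelConfig:
        """Stand-in config; the property body here is hypothetical."""

        def __init__(self, rope_scaling):
            self.rope_scaling = rope_scaling

        @property
        def uses_mrope(self):
            # Re-evaluated on every access when read via the property.
            return (self.rope_scaling is not None
                    and "mrope_section" in self.rope_scaling)

    class GPUModelRunner:
        def __init__(self, model_config):
            self.model_config = model_config
            # Evaluate once here; per-step code then reads a plain
            # instance attribute instead of re-running the property
            # through self.model_config.
            self.uses_mrope = model_config.uses_mrope

    runner = GPUModelRunner(ModelConfig({"mrope_section": [16, 24, 24]}))
    assert runner.uses_mrope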
@@ -147,7 +148,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             device=self.device)

         # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
-        if self.model_config.uses_mrope:
+        if self.uses_mrope:
             # NOTE: `mrope_positions` is implemented with one additional dummy
             # position on purpose to make it non-contiguous so that it can work
             # with torch compile.
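The NOTE above can be made concrete. A small sketch (shapes assumed for illustration): with one extra column in the persistent buffer, even the widest slice the runner ever takes is a non-contiguous view, so torch.compile always sees the same non-contiguous layout instead of flipping between contiguous and non-contiguous inputs across batch sizes.

    import torch

    max_num_tokens = 8
    # One extra dummy column beyond what is ever sliced out.
    mrope_positions = torch.zeros(3, max_num_tokens + 1, dtype=torch.int64)

    widest = mrope_positions[:, :max_num_tokens]  # largest slice used
    typical = mrope_positions[:, :4]              # a partial batch

    # Both views skip the dummy column, so neither is contiguous.
    assert not widest.is_contiguous()
    assert not typical.is_contiguous()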
@@ -284,7 +285,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )

         # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
-        if self.model_config.uses_mrope:
+        if self.uses_mrope:
             image_grid_thw = []
             video_grid_thw = []
             second_per_grid_ts = []
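For context: in Qwen2-VL-style models, `image_grid_thw` and `video_grid_thw` hold the per-item (temporal, height, width) patch-grid sizes, and `second_per_grid_ts` the seconds each temporal grid step covers; the M-RoPE position computation consumes these.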
@@ -411,7 +412,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # Calculate M-RoPE positions.
         # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
-        if self.model_config.uses_mrope:
+        if self.uses_mrope:
             self._calc_mrope_positions(scheduler_output)

         # Get token indices.
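As a rough illustration of what `_calc_mrope_positions` produces (this is not its implementation): M-RoPE keeps three position indices per token, one each for the temporal, height, and width axes, and for plain text tokens all three collapse to the ordinary 1-D position, so text-only requests behave like standard RoPE.

    import torch

    # Illustration only: text tokens get identical indices on all
    # three M-RoPE axes (temporal, height, width).
    num_text_tokens = 5
    pos_1d = torch.arange(num_text_tokens)
    mrope_pos = pos_1d.unsqueeze(0).expand(3, -1)
    # tensor([[0, 1, 2, 3, 4],
    #         [0, 1, 2, 3, 4],
    #         [0, 1, 2, 3, 4]])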
@@ -458,7 +459,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        if self.model_config.uses_mrope:
+        if self.uses_mrope:
             # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
             self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                 self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
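The staging pattern used here, sketched standalone (buffer names and sizes assumed): a pinned host buffer makes `copy_(..., non_blocking=True)` a genuinely asynchronous host-to-device transfer, and only the live prefix of the persistent buffers is copied each step.

    import torch

    if torch.cuda.is_available():
        max_num_tokens = 1024
        # Pinned host staging buffer plus a persistent device buffer.
        ids_cpu = torch.empty(max_num_tokens, dtype=torch.int32,
                              pin_memory=True)
        ids_gpu = torch.empty(max_num_tokens, dtype=torch.int32,
                              device="cuda")

        n = 256  # tokens scheduled this step
        ids_cpu[:n] = torch.arange(n, dtype=torch.int32)
        # Async copy of just the live prefix.
        ids_gpu[:n].copy_(ids_cpu[:n], non_blocking=True)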
@@ -817,13 +818,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
             inputs_embeds = None
+        if self.uses_mrope:
+            positions = self.mrope_positions[:, :num_input_tokens]
+        else:
+            positions = self.positions[:num_input_tokens]

         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata, self.vllm_config):
-            positions = self.mrope_positions[:, :num_input_tokens] \
-                if self.model_config.uses_mrope \
-                else self.positions[:num_input_tokens]
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=positions,
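Besides switching to the cached flag, this hunk replaces the backslash-continued conditional expression with a plain if/else and hoists it above the `set_forward_context` block, so the branch runs before the forward context is entered rather than inside it.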
@@ -1001,10 +1003,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             input_ids = self.input_ids[:num_tokens]
             inputs_embeds = None
+        if self.uses_mrope:
+            positions = self.mrope_positions[:, :num_tokens]
+        else:
+            positions = self.positions[:num_tokens]
         with set_forward_context(None, self.vllm_config):
-            positions = self.mrope_positions[:, :num_tokens] \
-                if self.model_config.uses_mrope \
-                else self.positions[:num_tokens]
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
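The same refactor is applied at this second site, which, given the local `model` variable and the `None` attention metadata, appears to be the dummy/profiling run path; both the real forward and the capture-time forward therefore select `positions` the same way.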