mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Intel GPU] Fix xpu decode input (#9145)
This commit is contained in:
@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadataCache
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
@ -136,7 +137,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
|
||||
(input_tokens, input_positions,
|
||||
attn_metadata) = self._prepare_decode(
|
||||
self.seq_group_metadata_list)
|
||||
seq_lens = []
|
||||
seq_lens = None
|
||||
multi_modal_kwargs = None
|
||||
|
||||
return self.model_input_cls(
|
||||
@ -390,6 +391,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
||||
# Lazy initialization.
|
||||
self.model: nn.Module # Set after init_Model
|
||||
|
||||
self.sampling_metadata_cache: SamplingMetadataCache = \
|
||||
SamplingMetadataCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
def load_model(self) -> None:
|
||||
with DeviceMemoryProfiler() as m:
|
||||
self.model = get_model(
|
||||
@ -524,12 +529,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
# Sampling metadata is only required for the final pp group
|
||||
generators = self.get_generators(finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||
model_input.seq_lens,
|
||||
model_input.query_lens,
|
||||
self.device,
|
||||
pin_memory=False,
|
||||
generators=generators)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
model_input.seq_lens,
|
||||
model_input.query_lens,
|
||||
self.device,
|
||||
pin_memory=False,
|
||||
generators=generators,
|
||||
cache=self.sampling_metadata_cache)
|
||||
|
||||
return dataclasses.replace(model_input,
|
||||
sampling_metadata=sampling_metadata,
|
||||
|
Reference in New Issue
Block a user