[Intel GPU] Fix xpu decode input (#9145)

This commit is contained in:
Kunshang Ji
2024-10-08 11:51:14 +08:00
committed by GitHub
parent 04c12f8157
commit 80b57f00d5

View File

@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
@ -136,7 +137,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = []
seq_lens = None
multi_modal_kwargs = None
return self.model_input_cls(
@ -390,6 +391,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
# Lazy initialization.
self.model: nn.Module # Set after init_Model
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
def load_model(self) -> None:
with DeviceMemoryProfiler() as m:
self.model = get_model(
@ -524,12 +529,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators,
cache=self.sampling_metadata_cache)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,