mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Fix] Don't deep-copy LogitsProcessors when copying SamplingParams (#3099)
This commit is contained in:
@ -484,8 +484,9 @@ class LLMEngine:
|
||||
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
|
||||
if lora_request else 0) if prefix_pos is not None else None
|
||||
|
||||
# Defensive copy of SamplingParams, which are used by the sampler
|
||||
sampling_params = copy.deepcopy(sampling_params)
|
||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||
# this doesn't deep-copy LogitsProcessor objects
|
||||
sampling_params = sampling_params.clone()
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Callable, List, Optional, Union
|
||||
@ -237,6 +238,20 @@ class SamplingParams:
|
||||
return SamplingType.RANDOM_SEED
|
||||
return SamplingType.RANDOM
|
||||
|
||||
def clone(self) -> "SamplingParams":
|
||||
"""Deep copy excluding LogitsProcessor objects.
|
||||
|
||||
LogitsProcessor objects are excluded because they may contain an
|
||||
arbitrary, nontrivial amount of data.
|
||||
See https://github.com/vllm-project/vllm/issues/3087
|
||||
"""
|
||||
|
||||
logit_processor_refs = None if self.logits_processors is None else {
|
||||
id(lp): lp
|
||||
for lp in self.logits_processors
|
||||
}
|
||||
return copy.deepcopy(self, memo=logit_processor_refs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"SamplingParams(n={self.n}, "
|
||||
|
Reference in New Issue
Block a user