[Fix] Don't deep-copy LogitsProcessors when copying SamplingParams (#3099)

This commit is contained in:
Nick Hill
2024-02-29 11:20:42 -08:00
committed by GitHub
parent 2c08ff23c0
commit 29a8d6a554
2 changed files with 18 additions and 2 deletions

View File

@ -484,8 +484,9 @@ class LLMEngine:
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None
# Defensive copy of SamplingParams, which are used by the sampler
sampling_params = copy.deepcopy(sampling_params)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,

View File

@ -1,4 +1,5 @@
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from functools import cached_property
from typing import Callable, List, Optional, Union
@ -237,6 +238,20 @@ class SamplingParams:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects.
LogitsProcessor objects are excluded because they may contain an
arbitrary, nontrivial amount of data.
See https://github.com/vllm-project/vllm/issues/3087
"""
logit_processor_refs = None if self.logits_processors is None else {
id(lp): lp
for lp in self.logits_processors
}
return copy.deepcopy(self, memo=logit_processor_refs)
def __repr__(self) -> str:
return (
f"SamplingParams(n={self.n}, "