Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00).
[Bugfix] V1 Memory Profiling: V0 Sampler Integration without Rejection Sampler (#13594)
Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@ -31,6 +31,7 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
@ -1305,11 +1306,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if get_pp_group().is_last_rank:
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
# TODO(woosuk): Consider the memory usage of the sampler.
|
||||
dummy_tensors = lambda v: torch.full(
|
||||
(num_reqs, ), v, device=self.device)
|
||||
dummy_metadata = SamplingMetadata(
|
||||
temperature=dummy_tensors(0.5),
|
||||
all_greedy=False,
|
||||
all_random=False,
|
||||
spec_token_ids=None,
|
||||
top_p=dummy_tensors(0.9),
|
||||
top_k=dummy_tensors(logits.size(1) - 1),
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
|
||||
frequency_penalties=dummy_tensors(0.1),
|
||||
presence_penalties=dummy_tensors(0.1),
|
||||
repetition_penalties=dummy_tensors(0.1),
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
min_tokens={},
|
||||
logit_bias=[None for _ in range(num_reqs)])
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits, sampling_metadata=dummy_metadata)
|
||||
else:
|
||||
logits = None
|
||||
sampler_output = None
|
||||
dummy_metadata = None
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, logits
|
||||
del hidden_states, logits, sampler_output, dummy_metadata
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
|
Reference in New Issue
Block a user