Mirror of https://github.com/vllm-project/vllm.git

Compare commits: v0.9.1rc1...sampler-en (4 commits)
Commits:
- 4c42267293
- 24f68342b4
- c5d963835b
- b313220727
@@ -15,7 +15,6 @@ import torch
 from torch.distributed import ProcessGroup, TCPStore
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
-                                                _shutdown_backend,
                                                 _unregister_process_group,
                                                 is_nccl_available)
 from torch.distributed.rendezvous import rendezvous
@@ -343,5 +342,7 @@ def stateless_destroy_torch_distributed_process_group(
     Destroy ProcessGroup returned by
     stateless_init_torch_distributed_process_group().
     """
+    # Lazy import for non-CUDA backends.
+    from torch.distributed.distributed_c10d import _shutdown_backend
     _shutdown_backend(pg)
     _unregister_process_group(pg.group_name)
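The two hunks above move the private `_shutdown_backend` import from module scope into the destroy helper, presumably so that merely importing the module cannot fail on torch builds (e.g. non-CUDA ones) that do not expose the symbol. A minimal, self-contained sketch of that lazy-import pattern (the function name here is illustrative, not vllm's API):

from torch.distributed import ProcessGroup


def destroy_stateless_group(pg: ProcessGroup) -> None:
    """Tear down a stateless ProcessGroup without touching torch's global state."""
    # Defer the private-API imports to call time; they are only needed when a
    # group is actually being destroyed.
    from torch.distributed.distributed_c10d import (_shutdown_backend,
                                                    _unregister_process_group)
    _shutdown_backend(pg)
    _unregister_process_group(pg.group_name)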
@@ -104,6 +104,7 @@ if TYPE_CHECKING:
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
+    VLLM_TPU_DISABLE_SAMPLER_DEBUG: bool = False
 
 
 def get_default_cache_root():
@@ -673,6 +674,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TPU_BUCKET_PADDING_GAP":
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
     if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
+
+    # Disable sampler path for debugging performance.
+    "VLLM_TPU_DISABLE_SAMPLER_DEBUG":
+    lambda: os.environ.get("VLLM_TPU_DISABLE_SAMPLER_DEBUG", "0") == "1",
 }
 
 # end-env-vars-definition
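The new flag is parsed strictly: only the literal string "1" enables it; anything else, including leaving the variable unset, resolves to False. A small standalone sketch of that resolution logic (resolve_flag is an illustrative helper, not vllm code):

import os


def resolve_flag(name: str) -> bool:
    # Mirrors the lambda above: only the literal string "1" turns the flag on.
    return os.environ.get(name, "0") == "1"


os.environ["VLLM_TPU_DISABLE_SAMPLER_DEBUG"] = "1"
assert resolve_flag("VLLM_TPU_DISABLE_SAMPLER_DEBUG") is True

os.environ["VLLM_TPU_DISABLE_SAMPLER_DEBUG"] = "true"  # not accepted as truthy
assert resolve_flag("VLLM_TPU_DISABLE_SAMPLER_DEBUG") is False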
@@ -599,23 +599,35 @@ class TPUModelRunner:
         input_ids = self.input_ids
         inputs_embeds = None
         num_reqs = self.input_batch.num_reqs
-        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
-        # are copied to device in chunks of pre-compiled padded shape to
-        # avoid recompilations.
-        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_input_batch(self.input_batch, logits_indices)
-        # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
-            hidden_states = self.model(
-                input_ids=input_ids,
-                positions=self.position_ids,
-                kv_caches=self.kv_caches,
-                inputs_embeds=inputs_embeds,
-            )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
-        # Remove padding on cpu and keep dynamic op outside of xla graph.
-        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+
+        # Temporary debug pathway.
+        if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
+            with set_forward_context(attn_metadata, self.vllm_config):
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=self.position_ids,
+                    kv_caches=self.kv_caches,
+                    inputs_embeds=inputs_embeds,
+                )
+            selected_token_ids = self.model.compute_logits_no_sampler(
+                hidden_states, logits_indices, None)
+            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+        else:
+            # NOTE (NickLucche) here we sync with TPU: sampling params tensors
+            # are copied to device in chunks of pre-compiled padded shape to
+            # avoid recompilations.
+            tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+                from_input_batch(self.input_batch, logits_indices)
+            with set_forward_context(attn_metadata, self.vllm_config):
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=self.position_ids,
+                    kv_caches=self.kv_caches,
+                    inputs_embeds=inputs_embeds,
+                )
+            selected_token_ids = self.model.sample_from_hidden(
+                hidden_states, tpu_sampling_metadata)
+            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
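Conceptually, the debug branch replaces the full TPU sampler with plain greedy selection: compute logits only at the positions given by logits_indices, take the argmax, and keep the dynamic [:num_reqs] slice on the CPU so the XLA graph stays static. A toy, framework-level sketch of that selection step (all shapes and names here are illustrative):

import torch

# Pretend hidden states for a padded batch of 8 token positions.
hidden_states = torch.randn(8, 16)               # [padded_tokens, hidden_size]
lm_head = torch.nn.Linear(16, 32, bias=False)    # vocab_size = 32
logits_indices = torch.tensor([3, 7])            # last position of each real request

# Gather only the positions that need logits, then pick the argmax per request.
logits = lm_head(hidden_states[logits_indices])                   # [num_reqs, vocab]
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)   # [num_reqs, 1]
print(selected_token_ids.shape)  # torch.Size([2, 1])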
@@ -929,6 +941,18 @@ class ModelWrapperV1(nn.Module):
         logits = self.model.compute_logits(hidden_states, None)
         return logits
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def compute_logits_no_sampler(
+        self,
+        hidden_states: torch.Tensor,
+        logits_indices: torch.Tensor,
+        sampling_metadata,
+    ) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
+
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
 
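Two details in this method are worth noting. The @torch.compile(backend="openxla", fullgraph=True, dynamic=False) decorator keeps the whole index-logits-argmax sequence inside a single static XLA graph, and keepdim=True preserves a trailing dimension, presumably so the debug path returns the same rank-2 shape the regular sample_from_hidden path produces and the caller's .cpu()[:num_reqs] slicing works unchanged. A quick shape check with dummy values:

import torch

logits = torch.randn(4, 32000)                          # [num_reqs, vocab_size]
flat_ids = torch.argmax(logits, dim=-1)                 # shape [4]
kept_ids = torch.argmax(logits, dim=-1, keepdim=True)   # shape [4, 1]
print(flat_ids.shape, kept_ids.shape)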