Compare commits

...

4 Commits

Author SHA1 Message Date
4c42267293 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-03-28 02:26:20 +00:00
24f68342b4 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-03-28 02:17:42 +00:00
c5d963835b updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-03-28 01:54:01 +00:00
b313220727 updates
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-03-27 23:51:36 +00:00
3 changed files with 48 additions and 18 deletions

View File

@@ -15,7 +15,6 @@ import torch
 from torch.distributed import ProcessGroup, TCPStore
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
-                                                _shutdown_backend,
                                                 _unregister_process_group,
                                                 is_nccl_available)
 from torch.distributed.rendezvous import rendezvous
@@ -343,5 +342,7 @@ def stateless_destroy_torch_distributed_process_group(
     Destroy ProcessGroup returned by
     stateless_init_torch_distributed_process_group().
     """
+    # Lazy import for non-CUDA backends.
+    from torch.distributed.distributed_c10d import _shutdown_backend
     _shutdown_backend(pg)
     _unregister_process_group(pg.group_name)
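
The hunk above removes `_shutdown_backend` from the module-level import and re-imports it inside `stateless_destroy_torch_distributed_process_group`, so the private c10d helper is only resolved when the destroy path actually runs. A minimal sketch of the same lazy-import pattern; the `ImportError` fallback is an illustrative assumption, not part of the diff:

from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _unregister_process_group


def destroy_process_group(pg: ProcessGroup) -> None:
    # Import lazily so merely importing this module does not require the
    # private helper to exist (e.g. on non-CUDA backends).
    try:
        from torch.distributed.distributed_c10d import _shutdown_backend
    except ImportError:
        _shutdown_backend = None  # assumed fallback; not in the original diff

    if _shutdown_backend is not None:
        _shutdown_backend(pg)
    _unregister_process_group(pg.group_name)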

View File

@@ -104,6 +104,7 @@ if TYPE_CHECKING:
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
+    VLLM_TPU_DISABLE_SAMPLER_DEBUG: bool = False


 def get_default_cache_root():
@@ -673,6 +674,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TPU_BUCKET_PADDING_GAP":
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
     if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
+
+    # Disable sampler path for debugging performance.
+    "VLLM_TPU_DISABLE_SAMPLER_DEBUG":
+    lambda: os.environ.get("VLLM_TPU_DISABLE_SAMPLER_DEBUG", "0") == "1",
 }

 # end-env-vars-definition
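
The env-var registry maps each flag name to a lambda, so the value is read from `os.environ` only when the attribute is accessed. A small usage sketch, assuming a vLLM install that includes this change and that the registry is exposed lazily through `vllm.envs`, as the runner code in the next file suggests:

import os

# Opt in to the debug pathway; any value other than "1" keeps the
# normal sampler path.
os.environ["VLLM_TPU_DISABLE_SAMPLER_DEBUG"] = "1"

import vllm.envs as envs

# The lambda registered above is evaluated on access and returns a bool.
assert envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG is True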

View File

@@ -599,23 +599,35 @@ class TPUModelRunner:
         input_ids = self.input_ids
         inputs_embeds = None
         num_reqs = self.input_batch.num_reqs
-        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
-        # are copied to device in chunks of pre-compiled padded shape to
-        # avoid recompilations.
-        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_input_batch(self.input_batch, logits_indices)
-        # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
-            hidden_states = self.model(
-                input_ids=input_ids,
-                positions=self.position_ids,
-                kv_caches=self.kv_caches,
-                inputs_embeds=inputs_embeds,
-            )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
-        # Remove padding on cpu and keep dynamic op outside of xla graph.
-        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+
+        # Temporary debug pathway.
+        if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
+            with set_forward_context(attn_metadata, self.vllm_config):
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=self.position_ids,
+                    kv_caches=self.kv_caches,
+                    inputs_embeds=inputs_embeds,
+                )
+            selected_token_ids = self.model.compute_logits_no_sampler(
+                hidden_states, logits_indices, None)
+            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+        else:
+            # NOTE (NickLucche) here we sync with TPU: sampling params tensors
+            # are copied to device in chunks of pre-compiled padded shape to
+            # avoid recompilations.
+            tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+                from_input_batch(self.input_batch, logits_indices)
+            with set_forward_context(attn_metadata, self.vllm_config):
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=self.position_ids,
+                    kv_caches=self.kv_caches,
+                    inputs_embeds=inputs_embeds,
+                )
+            selected_token_ids = self.model.sample_from_hidden(
+                hidden_states, tpu_sampling_metadata)
+            selected_token_ids = selected_token_ids.cpu()[:num_reqs]

         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
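
With the flag set, the runner skips building `TPUSupportedSamplingMetadata` entirely and takes greedy tokens straight from the logits, which isolates the sampler's cost when profiling; the padded result is still moved to CPU and sliced to `num_reqs` outside the XLA graph. A toy comparison of the two selection strategies on plain tensors (the shapes and the temperature-based sampling are illustrative assumptions, not the vLLM sampler):

import torch

# Dummy logits: [num_reqs, vocab_size].
logits = torch.randn(4, 32000)

# Debug path: plain greedy argmax, no sampling metadata or extra kernels.
greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)

# Normal path (illustrative): temperature scaling + categorical sampling.
probs = torch.softmax(logits / 0.7, dim=-1)
sampled_ids = torch.multinomial(probs, num_samples=1)

print(greedy_ids.shape, sampled_ids.shape)  # torch.Size([4, 1]) for both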
@@ -929,6 +941,18 @@ class ModelWrapperV1(nn.Module):
         logits = self.model.compute_logits(hidden_states, None)
         return logits

+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def compute_logits_no_sampler(
+        self,
+        hidden_states: torch.Tensor,
+        logits_indices: torch.Tensor,
+        sampling_metadata,
+    ) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
+
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
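
The new `compute_logits_no_sampler` method is decorated with `torch.compile(backend="openxla", fullgraph=True, dynamic=False)`, so the gather, logits projection, and argmax lower into a single static XLA graph. A self-contained sketch of the same pattern follows; the toy `TinyWrapper` module, its linear head, and the sizes are assumptions for illustration, and actually executing it requires a torch_xla install with an XLA device:

from typing import Optional

import torch
import torch.nn as nn


class TinyWrapper(nn.Module):
    # Toy stand-in for the model wrapper: one linear layer acts as the LM head.
    def __init__(self, hidden_size: int = 16, vocab_size: int = 128):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def compute_logits_no_sampler(
        self,
        hidden_states: torch.Tensor,
        logits_indices: torch.Tensor,
        sampling_metadata: Optional[object] = None,
    ) -> torch.Tensor:
        # Keep only the positions that need logits, then pick greedy tokens.
        hidden_states = hidden_states[logits_indices]
        logits = self.lm_head(hidden_states)
        return torch.argmax(logits, dim=-1, keepdim=True)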