mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[TPU] Deprecate xm.mark_step
in favor of `torch_xla.sync
(#25254)
Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -77,7 +78,7 @@ def test_pallas_moe(
|
||||
expert_map=e_map,
|
||||
renormalize=False,
|
||||
)
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
# Compare outputs
|
||||
torch.testing.assert_close(
|
||||
|
@ -4,6 +4,7 @@ import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_xla
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
|
||||
probs.masked_fill_(logits_masked.isinf(), 0)
|
||||
masked_prob_sum = probs.sum(dim=-1)
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# Perform assertion on CPU.
|
||||
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
|
||||
@ -82,7 +83,7 @@ def test_topp_basic():
|
||||
k=torch.tensor([3, 3]),
|
||||
p=torch.tensor([0.79, 0.79]))
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# Expect the smallest elements to be dropped.
|
||||
expected_result = logits.clone().cpu()
|
||||
@ -104,7 +105,7 @@ def test_topp_select_all():
|
||||
k=torch.tensor([3, 3]),
|
||||
p=torch.tensor([1.0, 1.0]))
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
assert torch.allclose(logits.cpu(), result.cpu())
|
||||
|
||||
@ -122,7 +123,7 @@ def test_topp_with_ties():
|
||||
k=torch.tensor([4]),
|
||||
p=torch.tensor([0.2]))
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# All tie values are included in the top-p set. Tie breaking is left
|
||||
# to be done during final sampling (all tie tokens have equal
|
||||
@ -146,7 +147,7 @@ def test_both_topk_topp():
|
||||
k=torch.tensor([1, 3]),
|
||||
p=torch.tensor([0.79, 0.79]))
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# Since for the first batch k=1, expect only the largest element gets
|
||||
# selected.
|
||||
|
@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla
|
||||
|
||||
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
from vllm.lora.punica_wrapper.utils import convert_mapping
|
||||
@ -323,7 +323,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
|
||||
extra_vocab_size: int,
|
||||
):
|
||||
# Make sure we don't accidentally collect outside operations
|
||||
xm.mark_step()
|
||||
torch_xla.sync()
|
||||
|
||||
# Pad the prompt mapping to avoid running into recompiles on the TPU
|
||||
# TODO: Should this happen inside mapping internally? If so how can we
|
||||
|
@ -211,16 +211,15 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
from vllm.platforms.tpu import USE_TPU_COMMONS
|
||||
|
||||
if not USE_TPU_COMMONS:
|
||||
# In PyTorch XLA, we should call `xm.mark_step`
|
||||
# In PyTorch XLA, we should call `torch_xla.sync`
|
||||
# frequently so that not too many ops are accumulated
|
||||
# in the XLA program. import torch_xla.core.xla_model
|
||||
# as xm
|
||||
import torch_xla.core.xla_model as xm
|
||||
# in the XLA program.
|
||||
import torch_xla
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
|
@ -10,6 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# TPU XLA related
|
||||
import torch_xla
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
@ -846,10 +847,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
||||
# (feature_size, hidden_size) in case the feature size is dynamic
|
||||
# depending on the input multimodal items.
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
curr_group_outputs = self.model.get_multimodal_embeddings(
|
||||
**mm_kwargs_group)
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs,
|
||||
@ -952,7 +953,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||
else:
|
||||
mm_embeds = []
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
# Prepare inputs, the requests might be split into multiple
|
||||
# executions, combine the result of each execution.
|
||||
start_index = 0
|
||||
@ -969,7 +970,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
end_index = self._prepare_inputs(scheduler_output, start_index)
|
||||
input_ids, inputs_embeds = self._get_model_inputs(
|
||||
self.input_ids, mm_embeds)
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
# Run the decoder
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
@ -1183,7 +1184,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# Sync all pending XLA execution during model initialization and weight
|
||||
# loading.
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
xm.wait_device_ops()
|
||||
if not hasattr(self, "model"):
|
||||
self.model = model
|
||||
@ -1267,10 +1268,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
|
||||
lora_requests) -> None:
|
||||
xm.mark_step() # Captures input updates
|
||||
torch_xla.sync(wait=False) # Captures input updates
|
||||
super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
|
||||
lora_requests)
|
||||
xm.mark_step() # Captures metadata updates
|
||||
torch_xla.sync(wait=False) # Captures metadata updates
|
||||
|
||||
def _precompile_mm_encoder(self) -> None:
|
||||
if not self.supports_mm_inputs:
|
||||
@ -1297,10 +1298,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_items,
|
||||
)
|
||||
# Run multimodal encoder.
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
mm_embeds = self.model.get_multimodal_embeddings(
|
||||
**batched_dummy_mm_inputs)
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
num_patches = mm_embeds[0].shape[0]
|
||||
items_size = num_patches * num_items
|
||||
|
||||
@ -1325,7 +1326,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
a, b = self._get_model_inputs(placeholders_ids,
|
||||
[mm_embeds])
|
||||
assert a is None
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
# Pre-compile `get_input_embeddings` when mm_embeddings are not
|
||||
# present. Chunk is only made of text, no mm_placeholders.
|
||||
@ -1336,7 +1337,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
placeholders_ids = placeholders_ids.to(self.device)
|
||||
a, b = self._get_model_inputs(placeholders_ids, [])
|
||||
assert a is None
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
@ -1532,11 +1533,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Isolate encoder graph from post-processing to minimize
|
||||
# impact of recompilation until it's fixed.
|
||||
start = time.perf_counter()
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
dummy_encoder_outputs = \
|
||||
self.model.get_multimodal_embeddings(
|
||||
**batched_dummy_mm_inputs)
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info(
|
||||
@ -1559,7 +1560,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self._dummy_run(num_tokens, self.num_reqs_most_model_len,
|
||||
self.num_blocks_per_most_len_req)
|
||||
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
xm.wait_device_ops()
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
@ -1927,11 +1928,11 @@ def replace_set_lora(model):
|
||||
# to a tensor doesn't seem to work anymore. This might be fixed with a
|
||||
# later release of torch_xla.
|
||||
self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
def _tpu_reset_lora(self, index: int):
|
||||
self._original_reset_lora(index)
|
||||
xm.mark_step()
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, BaseLayerWithLoRA):
|
||||
|
Reference in New Issue
Block a user