[TPU] Deprecate `xm.mark_step` in favor of `torch_xla.sync` (#25254)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Nicolò Lucchesi
2025-09-22 12:12:57 +02:00
committed by GitHub
parent a66d131381
commit 4cf71cc88a
5 changed files with 31 additions and 29 deletions
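
The change is a mechanical swap of the deprecated `xm.mark_step()` call for the module-level `torch_xla.sync()` API. A minimal sketch of the before/after pattern, assuming a torch_xla release that ships `torch_xla.sync` (the tensor and shapes below are illustrative only):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.ones(4, 4, device=device) * 2  # traced lazily into the XLA graph

# Deprecated API: cut and dispatch the accumulated XLA graph.
xm.mark_step()

# Replacement: the same graph cut; wait=False keeps it non-blocking,
# matching mark_step's behavior. Follow with xm.wait_device_ops() (as the
# model runner does) when a blocking barrier is needed.
torch_xla.sync(wait=False)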

View File

@@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`.
"""
import pytest
import torch
import torch_xla
# yapf conflicts with isort for this block
# yapf: disable
@@ -77,7 +78,7 @@ def test_pallas_moe(
expert_map=e_map,
renormalize=False,
)
xm.mark_step()
torch_xla.sync(wait=False)
# Compare outputs
torch.testing.assert_close(

View File

@@ -4,6 +4,7 @@ import math
import pytest
import torch
import torch_xla
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
@@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
probs.masked_fill_(logits_masked.isinf(), 0)
masked_prob_sum = probs.sum(dim=-1)
xm.mark_step()
torch_xla.sync()
# Perform assertion on CPU.
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
@@ -82,7 +83,7 @@ def test_topp_basic():
k=torch.tensor([3, 3]),
p=torch.tensor([0.79, 0.79]))
xm.mark_step()
torch_xla.sync()
# Expect the smallest elements to be dropped.
expected_result = logits.clone().cpu()
@@ -104,7 +105,7 @@ def test_topp_select_all():
k=torch.tensor([3, 3]),
p=torch.tensor([1.0, 1.0]))
xm.mark_step()
torch_xla.sync()
assert torch.allclose(logits.cpu(), result.cpu())
@@ -122,7 +123,7 @@ def test_topp_with_ties():
k=torch.tensor([4]),
p=torch.tensor([0.2]))
xm.mark_step()
torch_xla.sync()
# All tie values are included in the top-p set. Tie breaking is left
# to be done during final sampling (all tie tokens have equal
@@ -146,7 +147,7 @@ def test_both_topk_topp():
k=torch.tensor([1, 3]),
p=torch.tensor([0.79, 0.79]))
xm.mark_step()
torch_xla.sync()
# Since for the first batch k=1, expect only the largest element gets
# selected.

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.lora.punica_wrapper.utils import convert_mapping
@@ -323,7 +323,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
extra_vocab_size: int,
):
# Make sure we don't accidentally collect outside operations
xm.mark_step()
torch_xla.sync()
# Pad the prompt mapping to avoid running into recompiles on the TPU
# TODO: Should this happen inside mapping internally? If so how can we

View File

@@ -211,16 +211,15 @@ class DefaultModelLoader(BaseModelLoader):
from vllm.platforms.tpu import USE_TPU_COMMONS
if not USE_TPU_COMMONS:
# In PyTorch XLA, we should call `xm.mark_step`
# In PyTorch XLA, we should call `torch_xla.sync`
# frequently so that not too many ops are accumulated
# in the XLA program. import torch_xla.core.xla_model
# as xm
import torch_xla.core.xla_model as xm
# in the XLA program.
import torch_xla
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
torch_xla.sync(wait=False)
weights_iterator = _xla_weights_iterator(weights_iterator)
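
The wrapper above keeps the traced program small while weights stream in; a self-contained sketch of the same pattern with a stand-in checkpoint iterator (the dummy generator and tensor shapes are assumptions for illustration, not the loader's real API):

import torch
import torch_xla

def dummy_checkpoint():
    # Stand-in for the real checkpoint/safetensors weight iterator.
    for i in range(3):
        yield f"layer.{i}.weight", torch.randn(16, 16)

def xla_weights_iterator(iterator):
    for weights in iterator:
        yield weights
        # Cut the XLA graph after each weight group so ops do not
        # accumulate into one large program during loading.
        torch_xla.sync(wait=False)

for name, tensor in xla_weights_iterator(dummy_checkpoint()):
    pass  # the real loader copies each tensor into the model here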

View File

@@ -10,6 +10,7 @@ import numpy as np
import torch
import torch.nn as nn
# TPU XLA related
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
@@ -846,10 +847,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
xm.mark_step()
torch_xla.sync(wait=False)
curr_group_outputs = self.model.get_multimodal_embeddings(
**mm_kwargs_group)
xm.mark_step()
torch_xla.sync(wait=False)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
@@ -952,7 +953,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
xm.mark_step()
torch_xla.sync(wait=False)
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
start_index = 0
@@ -969,7 +970,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
end_index = self._prepare_inputs(scheduler_output, start_index)
input_ids, inputs_embeds = self._get_model_inputs(
self.input_ids, mm_embeds)
xm.mark_step()
torch_xla.sync(wait=False)
# Run the decoder
with set_forward_context(
attn_metadata,
@@ -1183,7 +1184,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Sync all pending XLA execution during model initialization and weight
# loading.
xm.mark_step()
torch_xla.sync(wait=False)
xm.wait_device_ops()
if not hasattr(self, "model"):
self.model = model
@@ -1267,10 +1268,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
lora_requests) -> None:
xm.mark_step() # Captures input updates
torch_xla.sync(wait=False) # Captures input updates
super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
lora_requests)
xm.mark_step() # Captures metadata updates
torch_xla.sync(wait=False) # Captures metadata updates
def _precompile_mm_encoder(self) -> None:
if not self.supports_mm_inputs:
@@ -1297,10 +1298,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_items,
)
# Run multimodal encoder.
xm.mark_step()
torch_xla.sync(wait=False)
mm_embeds = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
xm.mark_step()
torch_xla.sync(wait=False)
num_patches = mm_embeds[0].shape[0]
items_size = num_patches * num_items
@@ -1325,7 +1326,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
a, b = self._get_model_inputs(placeholders_ids,
[mm_embeds])
assert a is None
xm.mark_step()
torch_xla.sync(wait=False)
# Pre-compile `get_input_embeddings` when mm_embeddings are not
# present. Chunk is only made of text, no mm_placeholders.
@@ -1336,7 +1337,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
placeholders_ids = placeholders_ids.to(self.device)
a, b = self._get_model_inputs(placeholders_ids, [])
assert a is None
xm.mark_step()
torch_xla.sync(wait=False)
xm.wait_device_ops()
end = time.perf_counter()
@@ -1532,11 +1533,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Isolate encoder graph from post-processing to minimize
# impact of recompilation until it's fixed.
start = time.perf_counter()
xm.mark_step()
torch_xla.sync(wait=False)
dummy_encoder_outputs = \
self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
xm.mark_step()
torch_xla.sync(wait=False)
xm.wait_device_ops()
end = time.perf_counter()
logger.info(
@@ -1559,7 +1560,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._dummy_run(num_tokens, self.num_reqs_most_model_len,
self.num_blocks_per_most_len_req)
xm.mark_step()
torch_xla.sync(wait=False)
xm.wait_device_ops()
self.encoder_cache.clear()
gc.collect()
@@ -1927,11 +1928,11 @@ def replace_set_lora(model):
# to a tensor doesn't seem to work anymore. This might be fixed with a
# later release of torch_xla.
self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias)
xm.mark_step()
torch_xla.sync(wait=False)
def _tpu_reset_lora(self, index: int):
self._original_reset_lora(index)
xm.mark_step()
torch_xla.sync(wait=False)
for _, module in model.named_modules():
if isinstance(module, BaseLayerWithLoRA):
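
Throughout the runner, compile-sensitive regions (such as the multimodal encoder) are bracketed by graph cuts and, where completion matters, followed by an explicit device barrier. A rough sketch of that bracketing with a stand-in function (names, shapes, and the helper itself are illustrative, not the runner's real API):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def run_isolated(fn, *args):
    # Cut whatever has been traced so far, so it is not fused into fn's graph.
    torch_xla.sync(wait=False)
    out = fn(*args)
    # Cut again so later ops start from a fresh graph.
    torch_xla.sync(wait=False)
    return out

device = xm.xla_device()
dummy = torch.zeros(2, 8, device=device)
result = run_isolated(lambda t: t * 3, dummy)
xm.wait_device_ops()  # block until all dispatched graphs have executed
print(result.cpu().sum())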