diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py
index 407a824d81..1e5d9d923d 100644
--- a/tests/tpu/test_moe_pallas.py
+++ b/tests/tpu/test_moe_pallas.py
@@ -6,6 +6,7 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`.
 """
 import pytest
 import torch
+import torch_xla
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -77,7 +78,7 @@ def test_pallas_moe(
         expert_map=e_map,
         renormalize=False,
     )
-    xm.mark_step()
+    torch_xla.sync(wait=False)
 
     # Compare outputs
     torch.testing.assert_close(
diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py
index 05751badc7..665cf8cd26 100644
--- a/tests/v1/tpu/test_topk_topp_sampler.py
+++ b/tests/v1/tpu/test_topk_topp_sampler.py
@@ -4,6 +4,7 @@ import math
 
 import pytest
 import torch
+import torch_xla
 
 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
@@ -63,7 +64,7 @@ def test_topp_result_sums_past_p():
         probs.masked_fill_(logits_masked.isinf(), 0)
         masked_prob_sum = probs.sum(dim=-1)
 
-        xm.mark_step()
+        torch_xla.sync()
 
     # Perform assertion on CPU.
     assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
@@ -82,7 +83,7 @@ def test_topp_basic():
                                    k=torch.tensor([3, 3]),
                                    p=torch.tensor([0.79, 0.79]))
 
-    xm.mark_step()
+    torch_xla.sync()
 
     # Expect the smallest elements to be dropped.
     expected_result = logits.clone().cpu()
@@ -104,7 +105,7 @@ def test_topp_select_all():
                                    k=torch.tensor([3, 3]),
                                    p=torch.tensor([1.0, 1.0]))
 
-    xm.mark_step()
+    torch_xla.sync()
 
     assert torch.allclose(logits.cpu(), result.cpu())
 
@@ -122,7 +123,7 @@ def test_topp_with_ties():
                                    k=torch.tensor([4]),
                                    p=torch.tensor([0.2]))
 
-    xm.mark_step()
+    torch_xla.sync()
 
     # All tie values are included in the top-p set. Tie breaking is left
     # to be done during final sampling (all tie tokens have equal
@@ -146,7 +147,7 @@ def test_both_topk_topp():
                                    k=torch.tensor([1, 3]),
                                    p=torch.tensor([0.79, 0.79]))
 
-    xm.mark_step()
+    torch_xla.sync()
 
     # Since for the first batch k=1, expect only the largest element gets
     # selected.
diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py
index 07dc337a1c..5896da5165 100644
--- a/vllm/lora/punica_wrapper/punica_tpu.py
+++ b/vllm/lora/punica_wrapper/punica_tpu.py
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import torch.nn.functional as F
-import torch_xla.core.xla_model as xm
+import torch_xla
 
 from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
 from vllm.lora.punica_wrapper.utils import convert_mapping
@@ -323,7 +323,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         extra_vocab_size: int,
     ):
         # Make sure we don't accidentally collect outside operations
-        xm.mark_step()
+        torch_xla.sync()
 
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index d1bdec21fd..4b7bcd37d4 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -211,16 +211,15 @@ class DefaultModelLoader(BaseModelLoader):
         from vllm.platforms.tpu import USE_TPU_COMMONS
 
         if not USE_TPU_COMMONS:
-            # In PyTorch XLA, we should call `xm.mark_step`
+            # In PyTorch XLA, we should call `torch_xla.sync`
             # frequently so that not too many ops are accumulated
-            # in the XLA program. import torch_xla.core.xla_model
-            # as xm
-            import torch_xla.core.xla_model as xm
+            # in the XLA program.
+            import torch_xla
 
             def _xla_weights_iterator(iterator: Generator):
                 for weights in iterator:
                     yield weights
-                    xm.mark_step()
+                    torch_xla.sync(wait=False)
 
             weights_iterator = _xla_weights_iterator(weights_iterator)
 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index dd11b1dcbe..4cbf991a14 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -10,6 +10,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 # TPU XLA related
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
 import torch_xla.runtime as xr
@@ -846,10 +847,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **mm_kwargs_group)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
 
             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
@@ -952,7 +953,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         # Prepare inputs, the requests might be split into multiple
         # executions, combine the result of each execution.
         start_index = 0
@@ -969,7 +970,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             end_index = self._prepare_inputs(scheduler_output, start_index)
             input_ids, inputs_embeds = self._get_model_inputs(
                 self.input_ids, mm_embeds)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             # Run the decoder
             with set_forward_context(
                     attn_metadata,
@@ -1183,7 +1184,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Sync all pending XLA execution during model initialization and weight
         # loading.
-        xm.mark_step()
+        torch_xla.sync(wait=False)
         xm.wait_device_ops()
         if not hasattr(self, "model"):
             self.model = model
@@ -1267,10 +1268,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping,
                           lora_requests) -> None:
-        xm.mark_step()  # Captures input updates
+        torch_xla.sync(wait=False)  # Captures input updates
         super()._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                   lora_requests)
-        xm.mark_step()  # Captures metadata updates
+        torch_xla.sync(wait=False)  # Captures metadata updates
 
     def _precompile_mm_encoder(self) -> None:
         if not self.supports_mm_inputs:
@@ -1297,10 +1298,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_items,
             )
             # Run multimodal encoder.
-            xm.mark_step()
+            torch_xla.sync(wait=False)
             mm_embeds = self.model.get_multimodal_embeddings(
                 **batched_dummy_mm_inputs)
-            xm.mark_step()
+            torch_xla.sync(wait=False)
 
             num_patches = mm_embeds[0].shape[0]
             items_size = num_patches * num_items
@@ -1325,7 +1326,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             placeholders_ids = placeholders_ids.to(self.device)
             a, b = self._get_model_inputs(placeholders_ids, [mm_embeds])
             assert a is None
-            xm.mark_step()
+            torch_xla.sync(wait=False)
 
         # Pre-compile `get_input_embeddings` when mm_embeddings are not
         # present. Chunk is only made of text, no mm_placeholders.
@@ -1336,7 +1337,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): placeholders_ids = placeholders_ids.to(self.device) a, b = self._get_model_inputs(placeholders_ids, []) assert a is None - xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() end = time.perf_counter() @@ -1532,11 +1533,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Isolate encoder graph from post-processing to minimize # impact of recompilation until it's fixed. start = time.perf_counter() - xm.mark_step() + torch_xla.sync(wait=False) dummy_encoder_outputs = \ self.model.get_multimodal_embeddings( **batched_dummy_mm_inputs) - xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() end = time.perf_counter() logger.info( @@ -1559,7 +1560,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._dummy_run(num_tokens, self.num_reqs_most_model_len, self.num_blocks_per_most_len_req) - xm.mark_step() + torch_xla.sync(wait=False) xm.wait_device_ops() self.encoder_cache.clear() gc.collect() @@ -1927,11 +1928,11 @@ def replace_set_lora(model): # to a tensor doesn't seem to work anymore. This might be fixed with a # later release of torch_xla. self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) - xm.mark_step() + torch_xla.sync(wait=False) def _tpu_reset_lora(self, index: int): self._original_reset_lora(index) - xm.mark_step() + torch_xla.sync(wait=False) for _, module in model.named_modules(): if isinstance(module, BaseLayerWithLoRA):
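The pattern applied throughout this diff, shown as a minimal standalone sketch. It is illustrative only and assumes a torch_xla release in which torch_xla.sync(wait=...) is available as the public successor to xm.mark_step(); the device, tensor shapes, and loop below are made up for the example and are not part of the diff.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

for _ in range(3):
    x = torch.randn(8, 8, device=device)
    y = x @ x  # traced lazily; nothing has executed on the device yet
    # Previously: xm.mark_step() cut the lazy trace here and dispatched it.
    # Now: torch_xla.sync() plays the same role; wait=False keeps the call
    # non-blocking (matching mark_step's default), while wait=True, or a
    # following xm.wait_device_ops(), blocks until the device finishes.
    torch_xla.sync(wait=False)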